diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..ae19ba1d115ecf5fd4f007490451a0e40c2abc5c --- /dev/null +++ b/.flake8 @@ -0,0 +1,10 @@ +[flake8] +enable-extensions = G +select = B,C,E,F,G,P,SIM1,T4,W,B9 +max-line-length = 120 +# C408 ignored because we like the dict keyword argument syntax +# E501 is not flexible enough, we're using B950 instead +ignore = + E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E226,E265 +exclude = + third_party diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c22aca22a9605eb17bfe7ef6d2fe9906c841ac4a --- /dev/null +++ b/.gitattributes @@ -0,0 +1,41 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +assets/*.gif filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +transformer_engine_torch-1.12.0+cu121-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text +transformer_engine.whl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7a3b55f6b251c798c04f90f4bc28b76cb4ab7740 --- /dev/null +++ b/.gitignore @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Misc +outputs/ +checkpoints/* +!checkpoints/README.md +datasets/* +!datasets/README.md +apex/ + +# Data types +*.jit +*.pt +*.hdr +*.webp +*.pgm +*.tiff +*.tif +*.tar +*.tar.gz +*.gz +*.pkl +*.pt +*.bin +*.pickle +*.txt + +# Other uncheckable file types +*.zip +*.exe +*.dll +*.swp +*.vscode +*.DS_Store +*.pyc +*Thumbs.db +*.patch + +# Credential information that should never be checked in +credentials +*.secret + +# ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------ + +# Byte-compiled / optimized / DLL files +**/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +results/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.config +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Third party +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# ruff +.ruff_cache + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +CLIP +.devcontainer/devcontainer.json + +# Coverage +.coverage +coverage.xml + +# JUnit Reports +report.xml + +# CI-CD +temp/ +envs.txt +manifest.json + + +# locks and t5 temp files +*.locks* +*.no_exist* +*models--t5* + +# OneLogger +wandb/ +onelogger.err +onelogger.log diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..1e5dc8b549e189fe9f514555bb470eccb4a7e3c8 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,27 @@ +[submodule "gui/dependencies/pybind11"] + path = gui/dependencies/pybind11 + url = https://github.com/Tom94/pybind11 +[submodule "gui/dependencies/glfw"] + path = gui/dependencies/glfw + url = https://github.com/Tom94/glfw +[submodule "gui/dependencies/args"] + path = gui/dependencies/args + url = https://github.com/Taywee/args +[submodule "gui/dependencies/tinylogger"] + path = gui/dependencies/tinylogger + url = https://github.com/Tom94/tinylogger +[submodule "gui/dependencies/imgui"] + path = gui/dependencies/imgui + url = https://github.com/ocornut/imgui.git +[submodule "gui/dependencies/dlss"] + path = gui/dependencies/dlss + url = https://github.com/NVIDIA/DLSS +[submodule "gui/dependencies/OpenXR-SDK"] + path = gui/dependencies/OpenXR-SDK + url = https://github.com/KhronosGroup/OpenXR-SDK.git +[submodule "gui/dependencies/zlib"] + path = gui/dependencies/zlib + url = https://github.com/Tom94/zlib +[submodule "gui/dependencies/fmt"] + path = gui/dependencies/fmt + url = https://github.com/fmtlib/fmt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d82ca7ba148ca42f4adaf78648b672d73beee30a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +default_language_version: + python: python3.10 +repos: + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + - --max-line-length=120 + - --ignore=E501,F401,E203,E402,E265,E741,F841,F821,F811,W503,E231,E225,E702 + exclude: ^dist/|^third_party/ + + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + args: [--line-length=120] + exclude: ^dist/|^third_party/ + + - repo: https://github.com/timothycrosley/isort + rev: 5.12.0 + hooks: + - id: isort + args: [--line-length=120] + + - repo: https://github.com/MarcoGorelli/absolufy-imports + rev: v0.3.1 + hooks: + - id: absolufy-imports + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + exclude: ^tests/.*/fixtures/.* + args: [--markdown-linebreak-ext=md] + - id: end-of-file-fixer + exclude: ^tests/.*/fixtures/.* + - id: check-added-large-files + args: ['--maxkb=2000'] diff --git a/ATTRIBUTIONS.md b/ATTRIBUTIONS.md new file mode 100644 index 0000000000000000000000000000000000000000..0a6c6e4b9bcfedecd5a5191599016ca59772dd24 --- /dev/null +++ b/ATTRIBUTIONS.md @@ -0,0 +1,2861 @@ +# Open Source License Attribution + + Cosmos uses Open Source components. You can find the details of these open-source projects along with license information below, sorted alphabetically. + We are grateful to the developers for their contributions to open source and acknowledge these below. + +## Better-Profanity - [MIT License](https://github.com/snguyenthanh/better_profanity/blob/master/LICENSE) + + ``` + + Copyright (c) 2018 The Python Packaging Authority + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## FFmpeg - [FFMPEG License](https://github.com/FFmpeg/FFmpeg/blob/master/LICENSE.md) + + ``` + # License + + Most files in FFmpeg are under the GNU Lesser General Public License version 2.1 + or later (LGPL v2.1+). Read the file `COPYING.LGPLv2.1` for details. Some other + files have MIT/X11/BSD-style licenses. In combination the LGPL v2.1+ applies to + FFmpeg. + + Some optional parts of FFmpeg are licensed under the GNU General Public License + version 2 or later (GPL v2+). See the file `COPYING.GPLv2` for details. None of + these parts are used by default, you have to explicitly pass `--enable-gpl` to + configure to activate them. In this case, FFmpeg's license changes to GPL v2+. 
+ + Specifically, the GPL parts of FFmpeg are: + + - libpostproc + - optional x86 optimization in the files + - `libavcodec/x86/flac_dsp_gpl.asm` + - `libavcodec/x86/idct_mmx.c` + - `libavfilter/x86/vf_removegrain.asm` + - the following building and testing tools + - `compat/solaris/make_sunver.pl` + - `doc/t2h.pm` + - `doc/texi2pod.pl` + - `libswresample/tests/swresample.c` + - `tests/checkasm/*` + - `tests/tiny_ssim.c` + - the following filters in libavfilter: + - `signature_lookup.c` + - `vf_blackframe.c` + - `vf_boxblur.c` + - `vf_colormatrix.c` + - `vf_cover_rect.c` + - `vf_cropdetect.c` + - `vf_delogo.c` + - `vf_eq.c` + - `vf_find_rect.c` + - `vf_fspp.c` + - `vf_histeq.c` + - `vf_hqdn3d.c` + - `vf_kerndeint.c` + - `vf_lensfun.c` (GPL version 3 or later) + - `vf_mcdeint.c` + - `vf_mpdecimate.c` + - `vf_nnedi.c` + - `vf_owdenoise.c` + - `vf_perspective.c` + - `vf_phase.c` + - `vf_pp.c` + - `vf_pp7.c` + - `vf_pullup.c` + - `vf_repeatfields.c` + - `vf_sab.c` + - `vf_signature.c` + - `vf_smartblur.c` + - `vf_spp.c` + - `vf_stereo3d.c` + - `vf_super2xsai.c` + - `vf_tinterlace.c` + - `vf_uspp.c` + - `vf_vaguedenoiser.c` + - `vsrc_mptestsrc.c` + + Should you, for whatever reason, prefer to use version 3 of the (L)GPL, then + the configure parameter `--enable-version3` will activate this licensing option + for you. Read the file `COPYING.LGPLv3` or, if you have enabled GPL parts, + `COPYING.GPLv3` to learn the exact legal terms that apply in this case. + + There are a handful of files under other licensing terms, namely: + + * The files `libavcodec/jfdctfst.c`, `libavcodec/jfdctint_template.c` and + `libavcodec/jrevdct.c` are taken from libjpeg, see the top of the files for + licensing details. Specifically note that you must credit the IJG in the + documentation accompanying your program if you only distribute executables. + You must also indicate any changes including additions and deletions to + those three files in the documentation. + * `tests/reference.pnm` is under the expat license. + + + ## External libraries + + FFmpeg can be combined with a number of external libraries, which sometimes + affect the licensing of binaries resulting from the combination. + + ### Compatible libraries + + The following libraries are under GPL version 2: + - avisynth + - frei0r + - libcdio + - libdavs2 + - librubberband + - libvidstab + - libx264 + - libx265 + - libxavs + - libxavs2 + - libxvid + + When combining them with FFmpeg, FFmpeg needs to be licensed as GPL as well by + passing `--enable-gpl` to configure. + + The following libraries are under LGPL version 3: + - gmp + - libaribb24 + - liblensfun + + When combining them with FFmpeg, use the configure option `--enable-version3` to + upgrade FFmpeg to the LGPL v3. + + The VMAF, mbedTLS, RK MPI, OpenCORE and VisualOn libraries are under the Apache License + 2.0. That license is incompatible with the LGPL v2.1 and the GPL v2, but not with + version 3 of those licenses. So to combine these libraries with FFmpeg, the + license version needs to be upgraded by passing `--enable-version3` to configure. + + The smbclient library is under the GPL v3, to combine it with FFmpeg, + the options `--enable-gpl` and `--enable-version3` have to be passed to + configure to upgrade FFmpeg to the GPL v3. + + ### Incompatible libraries + + There are certain libraries you can combine with FFmpeg whose licenses are not + compatible with the GPL and/or the LGPL. 
If you wish to enable these + libraries, even in circumstances that their license may be incompatible, pass + `--enable-nonfree` to configure. This will cause the resulting binary to be + unredistributable. + + The Fraunhofer FDK AAC and OpenSSL libraries are under licenses which are + incompatible with the GPLv2 and v3. To the best of our knowledge, they are + compatible with the LGPL. + + ``` + +## Hydra-core [MIT License](https://github.com/facebookresearch/hydra/blob/main/LICENSE) + + ``` + + MIT License + + Copyright (c) Facebook, Inc. and its affiliates. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## Llama-Guard-3-8B [META LLAMA 3 COMMUNITY LICENSE](https://github.com/meta-llama/llama3/blob/main/LICENSE) + + ``` + + META LLAMA 3 COMMUNITY LICENSE AGREEMENT + + Meta Llama 3 Version Release Date: April 18, 2024 + + “Agreement” means the terms and conditions for use, reproduction, distribution, and + modification of the Llama Materials set forth herein. + + “Documentation” means the specifications, manuals, and documentation accompanying Meta + Llama 3 distributed by Meta at https://llama.meta.com/get-started/. + + “Licensee” or “you” means you, or your employer or any other person or entity (if you are + entering into this Agreement on such person or entity’s behalf), of the age required under + applicable laws, rules, or regulations to provide legal consent and that has legal authority + to bind your employer or such other person or entity if you are entering into this Agreement + on their behalf. + + “Meta Llama 3” means the foundational large language models and software and algorithms, + including machine-learning model code, trained model weights, inference-enabling code, + training-enabling code, fine-tuning-enabling code, and other elements of the foregoing + distributed by Meta at https://llama.meta.com/llama-downloads. + + “Llama Materials” means, collectively, Meta’s proprietary Meta Llama 3 and Documentation + (and any portion thereof) made available under this Agreement. + + “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are + an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, + Inc. (if you are located outside of the EEA or Switzerland). + + By clicking “I Accept” below or by using or distributing any portion or element of the Llama + Materials, you agree to be bound by this Agreement. + + 1. License Rights and Redistribution. + + a. Grant of Rights. 
You are granted a non-exclusive, worldwide, non-transferable and + royalty-free limited license under Meta’s intellectual property or other rights owned by + Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative + works of, and make modifications to the Llama Materials. + + b. Redistribution and Use. + i. If you distribute or make available the Llama Materials (or any derivative works + thereof), or a product or service that uses any of them, including another AI model, you + shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) + prominently display “Built with Meta Llama 3” on a related website, user interface, + blogpost, about page, or product documentation. If you use the Llama Materials to create, + train, fine tune, or otherwise improve an AI model, which is distributed or made available, + you shall also include “Llama 3” at the beginning of any such AI model name. + + ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as + part of an integrated end user product, then Section 2 of this Agreement will not apply + to you. + + iii. You must retain in all copies of the Llama Materials that you distribute the + following attribution notice within a “Notice” text file distributed as a part of such + copies: “Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © + Meta Platforms, Inc. All Rights Reserved.” + + iv. Your use of the Llama Materials must comply with applicable laws and regulations + (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy + for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which + is hereby incorporated by reference into this Agreement. + + v. You will not use the Llama Materials or any output or results of the Llama Materials + to improve any other large language model (excluding Meta Llama 3 or derivative works + thereof). + + 2. Additional Commercial Terms. + + If, on the Meta Llama 3 version release date, the monthly active users of the products or + services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 + million monthly active users in the preceding calendar month, you must request a license + from Meta, which Meta may grant to you in its sole discretion, and you are not authorized + to exercise any of the rights under this Agreement unless or until Meta otherwise expressly + grants you such rights. + + 3. Disclaimer of Warranty. + + UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM + ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL + WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY + WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING + THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS + AND ANY OUTPUT AND RESULTS. + + 4. Limitation of Liability. + + IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, + FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR + PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY + OF THE FOREGOING. + + 5. 
Intellectual Property. + + a. No trademark licenses are granted under this Agreement, and in connection with the Llama + Materials, neither Meta nor Licensee may use any name or mark owned by or associated with + the other or any of its affiliates, except as required for reasonable and customary use in + describing and redistributing the Llama Materials or as set forth in this Section 5(a). + Meta hereby grants you a license to use “Llama 3” (the “Mark”) solely as required to comply + with the last sentence of Section 1.b.i. You will comply with Meta’s brand guidelines + (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/). + All goodwill arising out of your use of the Mark will inure to the benefit of Meta. + + b. Subject to Meta’s ownership of Llama Materials and derivatives made by or for Meta, with + respect to any derivative works and modifications of the Llama Materials that are made by + you, as between you and Meta, you are and will be the owner of such derivative works and + modifications. + + c. If you institute litigation or other proceedings against Meta or any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Meta Llama 3 + outputs or results, or any portion of any of the foregoing, constitutes infringement of + intellectual property or other rights owned or licensable by you, then any licenses granted + to you under this Agreement shall terminate as of the date such litigation or claim is filed + or instituted. You will indemnify and hold harmless Meta from and against any claim by any + third party arising out of or related to your use or distribution of the Llama Materials. + + 6. Term and Termination. + + The term of this Agreement will commence upon your acceptance of this Agreement or access + to the Llama Materials and will continue in full force and effect until terminated in + accordance with the terms and conditions herein. Meta may terminate this Agreement if you + are in breach of any term or condition of this Agreement. Upon termination of this Agreement, + you shall delete and cease use of the Llama Materials. Sections 3, 4, and 7 shall survive + the termination of this Agreement. + + 7. Governing Law and Jurisdiction. + + This Agreement will be governed and construed under the laws of the State of California + without regard to choice of law principles, and the UN Convention on Contracts for the + International Sale of Goods does not apply to this Agreement. The courts of California + shall have exclusive jurisdiction of any dispute arising out of this Agreement. + + META LLAMA 3 ACCEPTABLE USE POLICY + + Meta is committed to promoting safe and fair use of its tools and features, including Meta + Llama 3. If you access or use Meta Llama 3, you agree to this Acceptable Use Policy + (“Policy”). The most recent copy of this policy can be found at + https://llama.meta.com/llama3/use-policy. + + Prohibited Uses + + We want everyone to use Meta Llama 3 safely and responsibly. You agree you will not use, or + allow others to use, Meta Llama 3 to: + + 1. Violate the law or others’ rights, including to: + + a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal + or unlawful activity or content, such as: + + i. Violence or terrorism + ii. Exploitation or harm to children, including the solicitation, creation, acquisition, + or dissemination of child exploitative content or failure to report Child Sexual Abuse + Material + iii. 
Human trafficking, exploitation, and sexual violence + iv. The illegal distribution of information or materials to minors, including obscene + materials, or failure to employ legally required age-gating in connection with such + information or materials + v. Sexual solicitation + vi. Any other criminal activity + + b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or + bullying of individuals or groups of individuals + + c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful + conduct in the provision of employment, employment benefits, credit, housing, other economic + benefits, or other essential goods and services + + d. Engage in the unauthorized or unlicensed practice of any profession including, but not + limited to, financial, legal, medical/health, or related professional practices + + e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive + personal or private information about individuals without rights and consents required by + applicable laws + + f. Engage in or facilitate any action or generate any content that infringes, misappropriates, + or otherwise violates any third-party rights, including the outputs or results of any + products or services using the Llama Materials + + g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses + or do anything else that could disable, overburden, interfere with or impair the proper + working, integrity, operation, or appearance of a website or computer system + + 2. Engage in, promote, incite, facilitate, or assist in the planning or development of + activities that present a risk of death or bodily harm to individuals, including use of Meta + Llama 3 related to the following: + + a. Military, warfare, nuclear industries or applications, espionage, use for materials or + activities that are subject to the International Traffic Arms Regulations (ITAR) maintained + by the United States Department of State + b. Guns and illegal weapons (including weapon development) + c. Illegal drugs and regulated/controlled substances + d. Operation of critical infrastructure, transportation technologies, or heavy machinery + e. Self-harm or harm to others, including suicide, cutting, and eating disorders + f. Any content intended to incite or promote violence, abuse, or any infliction of bodily + harm to an individual + + 3. Intentionally deceive or mislead others, including use of Meta Llama 3 related to the + following: + + a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation + b. Generating, promoting, or furthering defamatory content, including the creation of + defamatory statements, images, or other content + c. Generating, promoting, or further distributing spam + d. Impersonating another individual without consent, authorization, or legal right + e. Representing that the use of Meta Llama 3 or outputs are human-generated + f. Generating or facilitating false online engagement, including fake reviews and other + means of fake online engagement + g. 
Fail to appropriately disclose to end users any known dangers of your AI system + + Please report any violation of this Policy, software “bug,” or other problems that could + lead to a violation of this Policy through one of the following means: + + * Reporting issues with the model: https://github.com/meta-llama/llama3 + * Reporting risky content generated by the model: developers.facebook.com/llama_output_feedback + * Reporting bugs and security concerns: facebook.com/whitehat/info + * Reporting violations of the Acceptable Use Policy or unlicensed uses of Meta Llama 3: + LlamaUseReport@meta.com + + ``` + +## ImageIo - [BSD 2-Clause "Simplified" License](https://github.com/imageio/imageio/blob/master/LICENSE) + + ``` + + Copyright (c) 2014-2022, imageio developers + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ``` + +## Iopath - [MIT License](https://github.com/facebookresearch/iopath/blob/main/LICENSE) + + ``` + MIT License + + Copyright (c) Facebook, Inc. and its affiliates. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+ + ``` + +## Loguru - [MIT License](https://github.com/Delgan/loguru/blob/master/LICENSE) + + ``` + + MIT License + + Copyright (c) 2017 + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## Mediapy - [Apache License 2.0](https://github.com/google/mediapy/blob/main/LICENSE) + + ``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ``` + +## Nltk - [Apache License 2.0](https://github.com/nltk/nltk/blob/develop/LICENSE.txt) + + ``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ``` + +## PEFT - [Apache License 2.0](https://github.com/huggingface/peft/blob/main/LICENSE) + + ``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + ``` + +## Pillow - [MIT License](https://github.com/python-pillow/Pillow/blob/main/LICENSE) + + ``` + + The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh and contributors + + Pillow is the friendly PIL fork. It is + + Copyright © 2010 by Jeffrey A. Clark and contributors + + Like PIL, Pillow is licensed under the open source MIT-CMU License: + + By obtaining, using, and/or copying this software and/or its associated + documentation, you agree that you have read, understood, and will comply + with the following terms and conditions: + + Permission to use, copy, modify and distribute this software and its + documentation for any purpose and without fee is hereby granted, + provided that the above copyright notice appears in all copies, and that + both that copyright notice and this permission notice appear in supporting + documentation, and that the name of Secret Labs AB or the author not be + used in advertising or publicity pertaining to distribution of the software + without specific, written prior permission. + + SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS + SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. + IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, + INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + PERFORMANCE OF THIS SOFTWARE. + + ``` + +## PyAV - [BSD 3-Clause "New" or "Revised" License](https://github.com/PyAV-Org/PyAV/blob/main/LICENSE.txt) + + ``` + + Copyright retained by original committers. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the project nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + ``` + +## Pytorch_Retinaface - [MIT License](https://github.com/biubug6/Pytorch_Retinaface/blob/master/LICENSE.MIT) + + ``` + MIT License + + Copyright (c) 2019 + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + ``` + +## Sentencepiece - [Apache License 2.0](https://github.com/google/sentencepiece/blob/master/LICENSE) + + ``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ``` + +## Termcolor - [MIT License](https://github.com/termcolor/termcolor/blob/main/COPYING.txt) + + ``` + Copyright (c) 2008-2011 Volvox Development Team + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + ``` + +## Transformers [Apache License 2.0](https://github.com/huggingface/transformers/blob/main/LICENSE) + + ``` + + Copyright 2018- The Hugging Face team. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ``` + +## MoGe [MIT License](https://github.com/microsoft/MoGe/blob/main/LICENSE) + + ``` + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + ``` + +## Warp [Apache License 2.0](https://github.com/NVIDIA/warp/blob/main/LICENSE.md) + + ``` + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + ``` + +## Args - [MIT License](https://github.com/Taywee/args/blob/master/LICENSE) + + ``` + + Copyright (c) 2016-2024 Taylor C. Richberger and Pavel Belikov + + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + ``` + +## CUDA CMake GitHub Actions - [MIT License](https://github.com/ptheywood/cuda-cmake-github-actions/blob/master/LICENSE) + + ``` + + MIT License + + Copyright (c) 2021 Peter Heywood + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## Dear ImGui - [MIT License](https://github.com/ocornut/imgui/blob/master/LICENSE.txt) + + ``` + + The MIT License (MIT) + + Copyright (c) 2014-2024 Omar Cornut + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## Filesystem - BSD 3-Clause License + + ``` + + Copyright (c) 2016 Wenzel Jakob , + 2021-2023 Thomas Müller + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + You are under no obligation whatsoever to provide any bug fixes, patches, or + upgrades to the features, functionality or performance of the source code + ("Enhancements") to anyone; however, if you choose to make your Enhancements + available either publicly, or directly to the author of this software, without + imposing a separate written license agreement for such Enhancements, then you + hereby grant the following license: a non-exclusive, royalty-free perpetual + license to install, use, modify, prepare derivative works, incorporate into + other computer software, distribute, and sublicense such enhancements or + derivative works thereof, in binary and source code form. + + ``` + +## fmt - [MIT License](https://github.com/fmtlib/fmt/blob/master/LICENSE) + + ``` + + Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + --- Optional exception to the license --- + + As an exception, if, as a result of your compiling your source code, portions + of this Software are embedded into a machine-executable object form of such + source code, you may redistribute such embedded portions in such object form + without including the above copyright and permission notices. + + ``` + +## GLFW - [zlib/libpng License](https://github.com/glfw/glfw/blob/master/LICENSE.md) + + ``` + + Copyright (c) 2002-2006 Marcus Geelnard + Copyright (c) 2006-2019 Camilla Löwy + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would + be appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not + be misrepresented as being the original software. + + 3. This notice may not be removed or altered from any source + distribution. 
+ + ``` + +## JSON for Modern C++ - [MIT License](https://github.com/nlohmann/json/tree/develop/LICENSES) + + ``` + + MIT License + + Copyright (c) 2013-2021 Niels Lohmann + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + + +## PlayNE Equivalence - MIT License + + ``` + + MIT License + + Copyright (c) 2018 - Daniel Peter Playne + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ``` + +## pybind11_json - [BSD 3-Clause License](https://github.com/pybind/pybind11_json/blob/master/LICENSE) + + ``` + + BSD 3-Clause License + + Copyright (c) 2019, + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ``` + + +## stb_image - [MIT OR Public Domain](https://github.com/nothings/stb/blob/master/LICENSE) + + ``` + + This software is available under 2 licenses -- choose whichever you prefer. + ------------------------------------------------------------------------------ + ALTERNATIVE A - MIT License + Copyright (c) 2017 Sean Barrett + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + ------------------------------------------------------------------------------ + ALTERNATIVE B - Public Domain (www.unlicense.org) + This is free and unencumbered software released into the public domain. + Anyone is free to copy, modify, publish, use, compile, sell, or distribute this + software, either in source code form or as a compiled binary, for any purpose, + commercial or non-commercial, and by any means. + In jurisdictions that recognize copyright laws, the author or authors of this + software dedicate any and all copyright interest in the software to the public + domain. We make this dedication for the benefit of the public at large and to + the detriment of our heirs and successors. We intend this dedication to be an + overt act of relinquishment in perpetuity of all present and future rights to + this software under copyright law. + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ + ``` + +## ImGuizmo - [MIT license](https://github.com/CedricGuillemet/ImGuizmo/blob/master/LICENSE) + ``` + The MIT License (MIT) + + Copyright (c) 2016 Cedric Guillemet + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + ``` + +## PCG - [Apache-2.0 License]() + ``` + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ``` + +## OpenXR-SDK - [Apache License 2.0](https://github.com/KhronosGroup/OpenXR-SDK/blob/main/LICENSE) + ``` + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + ``` + +## pybind11 - [BSD 3-Clause License](https://github.com/Tom94/pybind11/blob/master/LICENSE) + ``` + Copyright (c) 2016 Wenzel Jakob , All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + Please also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of + external contributions to this project including patches, pull requests, etc. + + ``` + +## tinylogger - [BSD-3-Clause license](https://github.com/Tom94/tinylogger/blob/master/LICENSE.md) + ``` + Copyright (c) 2018-2024, Thomas Müller + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder, the project, nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + ``` + +## ZLIB DATA COMPRESSION LIBRARY - [zlib License](https://github.com/Tom94/zlib) + ``` + (C) 1995-2017 Jean-loup Gailly and Mark Adler + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + Jean-loup Gailly Mark Adler + jloup@gzip.org madler@alumni.caltech.edu + + If you use the zlib library in a product, we would appreciate *not* receiving + lengthy legal documents to sign. The sources are provided for free but without + warranty of any kind. The library has been entirely written by Jean-loup + Gailly and Mark Adler; it does not include third-party code. + + If you redistribute modified sources, we would appreciate that you include in + the file ChangeLog history information documenting your changes. Please read + the FAQ for more information on the distribution of modified source versions. + + ``` + +## GL3W - [Unlicense License](https://github.com/skeeto/opengl-demo/blob/master/UNLICENSE) + ``` + This is free and unencumbered software released into the public domain. + + Anyone is free to copy, modify, publish, use, compile, sell, or + distribute this software, either in source code form or as a compiled + binary, for any purpose, commercial or non-commercial, and by any + means. + + In jurisdictions that recognize copyright laws, the author or authors + of this software dedicate any and all copyright interest in the + software to the public domain. We make this dedication for the benefit + of the public at large and to the detriment of our heirs and + successors. We intend this dedication to be an overt act of + relinquishment in perpetuity of all present and future rights to this + software under copyright law. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + For more information, please refer to + ``` + + +## DLSS - [NVIDIA RTX SDKs LICENSE](https://github.com/NVIDIA/DLSS/blob/main/LICENSE.txt) + ``` + NVIDIA RTX SDKs LICENSE + + This license is a legal agreement between you and NVIDIA Corporation ("NVIDIA") and governs the use of the NVIDIA RTX software development kits, including the DLSS SDK, NGX SDK, RTXGI SDK, RTXDI SDK, RTX Video SDK, RTX Dynamic Vibrance SDK and/or NRD SDK, if and when made available to you under this license (in each case, the “SDK”). + This license can be accepted only by an adult of legal age of majority in the country in which the SDK is used. 
If you are under the legal age of majority, you must ask your parent or legal guardian to consent to this license. If you are entering this license on behalf of a company or other legal entity, you represent that you have legal authority and “you” will mean the entity you represent. + By using the SDK, you affirm that you have reached the legal age of majority, you accept the terms of this license, and you take legal and financial responsibility for the actions of your permitted users. + + You agree to use the SDK only for purposes that are permitted by (a) this license, and (b) any applicable law, regulation or generally accepted practices or guidelines in the relevant jurisdictions. + + 1. LICENSE. Subject to the terms of this license and the terms in the supplement attached, NVIDIA hereby grants you a non-exclusive, non-transferable license, without the right to sublicense (except as expressly provided in this license) to: + a. Install and use the SDK, + b. Modify and create derivative works of sample source code delivered in the SDK, and + c. Distribute any software and materials within the SDK, other than developer tools provided for your internal use, as incorporated in object code format into a software application subject to the distribution requirements indicated in this license. + + 2. DISTRIBUTION REQUIREMENTS. These are the distribution requirements for you to exercise the grants above: + a. An application must have material additional functionality, beyond the included portions of the SDK. + b. The following notice shall be included in modifications and derivative works of source code distributed: “This software contains source code provided by NVIDIA Corporation.” + c. You agree to distribute the SDK subject to the terms at least as protective as the terms of this license, including (without limitation) terms relating to the license grant, license restrictions and protection of NVIDIA’s intellectual property rights. Additionally, you agree that you will protect the privacy, security and legal rights of your application users. + d. You agree to notify NVIDIA in writing of any known or suspected distribution or use of the SDK not in compliance with the requirements of this license, and to enforce the terms of your agreements with respect to the distributed portions of the SDK. + + 3. AUTHORIZED USERS. You may allow employees and contractors of your entity or of your subsidiary(ies) to access and use the SDK from your secure network to perform work on your behalf. If you are an academic institution you may allow users enrolled or employed by the academic institution to access and use the SDK from your secure network. You are responsible for the compliance with the terms of this license by your authorized users. + + 4. LIMITATIONS. Your license to use the SDK is restricted as follows: + a. You may not reverse engineer, decompile or disassemble, or remove copyright or other proprietary notices from any portion of the SDK or copies of the SDK. + b. Except as expressly provided in this license, you may not copy, sell, rent, sublicense, transfer, distribute, modify, or create derivative works of any portion of the SDK. For clarity, you may not distribute or sublicense the SDK as a stand-alone product. + c. Unless you have an agreement with NVIDIA for this purpose, you may not indicate that an application created with the SDK is sponsored or endorsed by NVIDIA. + d. 
You may not bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the SDK. + e. You may not use the SDK in any manner that would cause it to become subject to an open source software license. As examples, licenses that require as a condition of use, modification, and/or distribution that the SDK be: (i) disclosed or distributed in source code form; (ii) licensed for the purpose of making derivative works; or (iii) redistributable at no charge. + f. Unless you have an agreement with NVIDIA for this purpose, you may not use the SDK with any system or application where the use or failure of the system or application can reasonably be expected to threaten or result in personal injury, death, or catastrophic loss. Examples include use in avionics, navigation, military, medical, life support or other life critical applications. NVIDIA does not design, test or manufacture the SDK for these critical uses and NVIDIA shall not be liable to you or any third party, in whole or in part, for any claims or damages arising from such uses. + g. You agree to defend, indemnify and hold harmless NVIDIA and its affiliates, and their respective employees, contractors, agents, officers and directors, from and against any and all claims, damages, obligations, losses, liabilities, costs or debt, fines, restitutions and expenses (including but not limited to attorney’s fees and costs incident to establishing the right of indemnification) arising out of or related to your use of the SDK outside of the scope of this license, or not in compliance with its terms. + + 5. UPDATES. NVIDIA may, at its option, make available patches, workarounds or other updates to this SDK. Unless the updates are provided with their separate governing terms, they are deemed part of the SDK licensed to you as provided in this license. Further, NVIDIA may, at its option, automatically update the SDK or other software in the system, except for those updates that you may opt-out via the SDK API. You agree that the form and content of the SDK that NVIDIA provides may change without prior notice to you. While NVIDIA generally maintains compatibility between versions, NVIDIA may in some cases make changes that introduce incompatibilities in future versions of the SDK. + + 6. PRE-RELEASE VERSIONS. SDK versions identified as alpha, beta, preview, early access or otherwise as pre-release may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability, and reliability standards relative to commercial versions of NVIDIA software and materials. You may use a pre-release SDK version at your own risk, understanding that these versions are not intended for use in production or business-critical systems. NVIDIA may choose not to make available a commercial version of any pre-release SDK. NVIDIA may also choose to abandon development and terminate the availability of a pre-release SDK at any time without liability. + + 7. THIRD-PARTY COMPONENTS. The SDK may include third-party components with separate legal notices or terms as may be described in proprietary notices accompanying the SDK. If and to the extent there is a conflict between the terms in this license and the third-party license terms, the third-party terms control only to the extent necessary to resolve the conflict. + + 8. OWNERSHIP. 
+ + 8.1 NVIDIA reserves all rights, title and interest in and to the SDK not expressly granted to you under this license. NVIDIA and its suppliers hold all rights, title and interest in and to the SDK, including their respective intellectual property rights. The SDK is copyrighted and protected by the laws of the United States and other countries, and international treaty provisions. + + 8.2 Subject to the rights of NVIDIA and its suppliers in the SDK, you hold all rights, title and interest in and to your applications and your derivative works of the sample source code delivered in the SDK including their respective intellectual property rights. + + 9. FEEDBACK. You may, but are not obligated to, provide Feedback to NVIDIA. “Feedback” means all suggestions, fixes, modifications, feature requests or other feedback regarding the SDK. Feedback, even if designated as confidential by you, shall not create any confidentiality obligation for NVIDIA. If you provide Feedback, you hereby grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA’s discretion. You will not give Feedback (i) that you have reason to believe is subject to any restriction that impairs the exercise of the grant stated in this section, such as third-party intellectual property rights or (ii) subject to license terms which seek to require any product incorporating or developed using such Feedback, or other intellectual property of NVIDIA or its affiliates, to be licensed to or otherwise shared with any third party. + + 10. NO WARRANTIES. THE SDK IS PROVIDED AS-IS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW NVIDIA AND ITS AFFILIATES EXPRESSLY DISCLAIM ALL WARRANTIES OF ANY KIND OR NATURE, WHETHER EXPRESS, IMPLIED OR STATUTORY, INCLUDING, BUT NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT THAT THE SDK WILL MEET YOUR REQUIREMENTS OR THAT THE OPERATION THEREOF WILL BE UNINTERRUPTED OR ERROR-FREE, OR THAT ALL ERRORS WILL BE CORRECTED. + + 11. LIMITATIONS OF LIABILITY. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW NVIDIA AND ITS AFFILIATES SHALL NOT BE LIABLE FOR ANY (I) SPECIAL, INCIDENTAL, PUNITIVE OR CONSEQUENTIAL DAMAGES, OR FOR DAMAGES FOR (A) ANY LOST PROFITS, PROJECT DELAYS, LOSS OF USE, LOSS OF DATA OR LOSS OF GOODWILL, OR (B) THE COSTS OF PROCURING SUBSTITUTE PRODUCTS, ARISING OUT OF OR IN CONNECTION WITH THIS LICENSE OR THE USE OR PERFORMANCE OF THE SDK, WHETHER SUCH LIABILITY ARISES FROM ANY CLAIM BASED UPON BREACH OF CONTRACT, BREACH OF WARRANTY, TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY OR ANY OTHER CAUSE OF ACTION OR THEORY OF LIABILITY, EVEN IF NVIDIA HAS PREVIOUSLY BEEN ADVISED OF, OR COULD REASONABLY HAVE FORESEEN, THE POSSIBILITY OF SUCH DAMAGES. IN NO EVENT WILL NVIDIA’S AND ITS AFFILIATES TOTAL CUMULATIVE LIABILITY UNDER OR ARISING OUT OF THIS LICENSE EXCEED US$10.00. THE NATURE OF THE LIABILITY OR THE NUMBER OF CLAIMS OR SUITS SHALL NOT ENLARGE OR EXTEND THIS LIMIT. + + 12. TERMINATION. 
Your rights under this license will terminate automatically without notice from NVIDIA if you fail to comply with any term and condition of this license or if you commence or participate in any legal proceeding against NVIDIA with respect to the SDK. NVIDIA may terminate this license with advance written notice to you, if NVIDIA decides to no longer provide the SDK in a country or, in NVIDIA’s sole discretion, the continued use of it is no longer commercially viable. Upon any termination of this license, you agree to promptly discontinue use of the SDK and destroy all copies in your possession or control. Your prior distributions in accordance with this license are not affected by the termination of this license. All provisions of this license will survive termination, except for the license granted to you. + + 13. APPLICABLE LAW. This license will be governed in all respects by the laws of the United States and of the State of Delaware, without regard to the conflicts of laws principles. The United Nations Convention on Contracts for the International Sale of Goods is specifically disclaimed. You agree to all terms of this license in the English language. The state or federal courts residing in Santa Clara County, California shall have exclusive jurisdiction over any dispute or claim arising out of this license. Notwithstanding this, you agree that NVIDIA shall still be allowed to apply for injunctive remedies or urgent legal relief in any jurisdiction. + + 14. NO ASSIGNMENT. This license and your rights and obligations thereunder may not be assigned by you by any means or operation of law without NVIDIA’s permission. Any attempted assignment not approved by NVIDIA in writing shall be void and of no effect. NVIDIA may assign, delegate or transfer this license and its rights and obligations, and if to a non-affiliate you will be notified. + + 15. EXPORT. The SDK is subject to United States export laws and regulations. You agree to comply with all applicable U.S. and international export laws, including the Export Administration Regulations (EAR) administered by the U.S. Department of Commerce and economic sanctions administered by the U.S. Department of Treasury’s Office of Foreign Assets Control (OFAC). These laws include restrictions on destinations, end-users and end-use. By accepting this license, you confirm that you are not currently residing in a country or region currently embargoed by the U.S. and that you are not otherwise prohibited from receiving the SDK. + + 16. GOVERNMENT USE. The SDK, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R. 2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R. 12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1). Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense; (ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement; and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense. In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R. 
52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing. + + 17. NOTICES. You agree that any notices that NVIDIA sends you electronically, such as via email, will satisfy any legal communication requirements. Please direct your legal notices or other correspondence to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department. + + 18. ENTIRE AGREEMENT. This license is the final, complete and exclusive agreement between the parties relating to the subject matter of this license and supersedes all prior or contemporaneous understandings and agreements relating to this subject matter, whether oral or written. If any court of competent jurisdiction determines that any provision of this license is illegal, invalid or unenforceable, the remaining provisions will remain in full force and effect. Any amendment or waiver under this license shall be in writing and signed by representatives of both parties. + + 19. LICENSING. If the distribution terms in this license are not suitable for your organization, or for any questions regarding this license, please contact NVIDIA at nvidia-rtx-license-questions@nvidia.com. + (v. March 14, 2024) + + + NVIDIA RTX SUPPLEMENT TO SOFTWARE LICENSE AGREEMENT FOR NVIDIA SOFTWARE DEVELOPMENT KITS + The terms in this supplement govern your use of the NVIDIA RTX SDKs, including the DLSS SDK, NGX SDK, RTXGI SDK, RTXDI SDK, RTX Video SDK, RTX Dynamic Vibrance SDK and/or NRD SDK, if and when made available to you (in each case, the “SDK”) under the terms of your license agreement (“Agreement”) as modified by this supplement. Capitalized terms used but not defined below have the meaning assigned to them in the Agreement. + This supplement is an exhibit to the Agreement and is incorporated as an integral part of the Agreement. In the event of conflict between the terms in this supplement and the terms in the Agreement, the terms in this supplement govern. + + 1. Interoperability. Your applications that incorporate, or are based on, the SDK must be fully interoperable with compatible GPU hardware products designed by NVIDIA or its affiliates. Further, the DLSS SDK, NGX SDK and RTX Dynamic Vibrance SDK are licensed for you to develop applications only for their use in systems with NVIDIA GPUs. + + 2. Game License. You may, but are not obligated to, provide your game or related content (“Game Content”) to NVIDIA. If you provide Game Content, you hereby grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, worldwide, royalty-free, fully paid-up license, to use the Game Content to improve NVIDIA DLSS SDK and DLSS Model Training. + + 3. Limitations for the DLSS SDK, NGX SDK,RTX Video SDK and RTX Dynamic Vibrance SDK. Your applications that incorporate, or are based on, the DLSS SDK, NGX SDK, RTX Video SDK or RTX Dynamic Vibrance SDK may be deployed in a cloud service that runs on systems that consume NVIDIA vGPU software, and any other cloud service use of such SDKs or their functionality is outside of the scope of the Agreement. For the purpose of this section, cloud services include application service providers or service bureaus, operators of hosted/virtual system environments, or hosting, time sharing or providing any other type of service to others. + + 4. Notification for the DLSS SDK, NGX SDK and RTX Dynamic Vibrance SDK. 
You are required to notify NVIDIA prior to commercial release of an application (including a plug-in to a commercial application) that incorporates, or is based on, the DLSS SDK, NGX SDK or RTX Dynamic Vibrance SDK. Please send notifications to: https://developer.nvidia.com/sw-notification and provide the following information in the email: company name, publisher and developer name, NVIDIA SDK used, application name, platform (i.e. PC, Linux), scheduled ship date, and weblink to product/video. + + 5. Audio and Video Encoders and Decoders. You acknowledge and agree that it is your sole responsibility to obtain any additional third-party licenses required to make, have made, use, have used, sell, import, and offer for sale your products or services that include or incorporate any third-party software and content relating to audio and/or video encoders and decoders from, including but not limited to, Microsoft, Thomson, Fraunhofer IIS, Sisvel S.p.A., MPEG-LA, and Coding Technologies. NVIDIA does not grant to you under this Agreement any necessary patent or other rights with respect to any audio and/or video encoders and decoders. + + 6. SDK Terms. + 6.1 Over the Air Updates. By installing or using the SDK you agree that NVIDIA can make over-the-air updates of the SDK in systems that have the SDK installed, including (without limitation) for quality, stability or performance improvements or to support new hardware. + + 6.2 SDK Integration. If you publicly release a DLSS integration in an end user game or application that presents material stability, performance, image quality, or other technical issues impacting the user experience, you will work to quickly address the integration issues. In the case issues are not addressed, NVIDIA reserves the right, as a last resort, to temporarily disable the DLSS integration until the issues can be fixed. + + 7. Marketing. + 7.1 Marketing Activities. Your license to the SDK(s) under the Agreement is subject to your compliance with the following marketing terms: + (a) Identification by You in the DLSS SDK, NGX SDK, RTX Video SDK or Dynamic Vibrance SDK. During the term of the Agreement, NVIDIA agrees that you may identify NVIDIA on your websites, printed collateral, trade-show displays and other retail packaging materials, as the supplier of the DLSS SDK, NGX SDK, RTX Video SDK or Dynamic Vibrance SDK for the applications that were developed with use of such SDKs, provided that all such references to NVIDIA will be subject to NVIDIA's prior review and written approval, which will not be unreasonably withheld or delayed. + (b) NVIDIA Trademark Placement in Applications with the DLSS SDK, NGX SDK, or RTX Video SDK. For applications that incorporate the DLSS SDK or NGX SDK or portions thereof, you must attribute the use of the applicable SDK and include the NVIDIA Marks on splash screens, in the about box of the application (if present), and in credits for game applications. + (c) NVIDIA Trademark Placement in Applications with a licensed SDK, other than the DLSS SDK, RTX Video SDK or NGX SDK. For applications that incorporates and/or makes use of a licensed SDK, other than the DLSS SDK, RTX Video SDK or NGX SDK, you must attribute the use of the applicable SDK and include the NVIDIA Marks on the credit screen for applications that have such credit screen, or where a credit screen is not present prominently in end user documentation for the application. + (d) Identification by NVIDIA in the DLSS SDK, NGX SDK, RTX Video SDK or Dynamic Vibrance SDK. 
You agree that NVIDIA may identify you on NVIDIA's websites, printed collateral, trade-show displays, and other retail packaging materials as an individual or entity that produces products and services which incorporate the DLSS SDK, NGX SDK, RTX Video SDK or Dynamic Vibrance SDK as applicable. To the extent that you provide NVIDIA with input or usage requests with regard to the use of your logo or materials, NVIDIA will use commercially reasonable efforts to comply with such requests. For the avoidance of doubt, NVIDIA’s rights pursuant to this section shall survive any expiration or termination of the Agreement with respect to existing applications which incorporate the DLSS SDK, RTX Video SDK or NGX SDK. + (e) Applications Marketing Material in the DLSS SDK, NGX SDK, RTX Video SDK or Dynamic Vibrance SDK. You may provide NVIDIA with screenshots, imagery, and video footage of applications representative of your use of the NVIDIA DLSS SDK or NGX SDKs in your application (collectively, “Assets”). You hereby grant to NVIDIA the right to create and display self-promotional demo materials using the Assets, and after release of the application to the public to distribute, sub-license, and use the Assets to promote and market the NVIDIA RTX SDKs. To the extent you provide NVIDIA with input or usage requests with regard to the use of your logo or materials, NVIDIA will use commercially reasonable efforts to comply with such requests. For the avoidance of doubt, NVIDIA’s rights pursuant to this section shall survive any termination of the Agreement with respect to applications which incorporate the NVIDIA RTX SDK. + + 7.2 Trademark Ownership and Licenses. Trademarks are owned and licenses as follows: + (a) Ownership of Trademarks. Each party owns the trademarks, logos, and trade names (collectively "Marks") for their respective products or services, including without limitation in applications, and the NVIDIA RTX SDKs. Each party agrees to use the Marks of the other only as permitted in this exhibit. + + (b) Trademark License to NVIDIA. You grant to NVIDIA a non-exclusive, non-sub licensable, non-transferable (except as set forth in the assignment provision of the Agreement), worldwide license to refer to you and your applications, and to use your Marks on NVIDIA's marketing materials and on NVIDIA's website (subject to any reasonable conditions of you) solely for NVIDIA’s marketing activities set forth in this exhibit Sections 7 (d)-(e) above. NVIDIA will follow your specifications for your Marks as to style, color, and typeface as reasonably provided to NVIDIA. + + (c) Trademark License to You. NVIDIA grants to you a non-exclusive, non-sub licensable, non-transferable (except as set forth in the assignment provision of the Agreement), worldwide license, subject to the terms of this exhibit and the Agreement, to use NVIDIA RTX™, NVIDIA GeForce RTX™ in combination with GeForce products, and/or NVIDIA Quadro RTX™ in combination with Quadro products (collectively, the “NVIDIA Marks”) on your marketing materials and on your website (subject to any reasonable conditions of NVIDIA) solely for your marketing activities set forth in this exhibit Sections 7.1 (a)-(c) above. 
For the avoidance of doubt, you will not and will not permit others to use any NVIDIA Mark for any other goods or services, or in a way that tarnishes, degrades, disparages or reflects adversely any of the NVIDIA Marks or NVIDIA’s business or reputation, or that dilutes or otherwise harms the value, reputation or distinctiveness of or NVIDIA’s goodwill in any NVIDIA Mark. In addition to the termination rights set forth in the Agreement, NVIDIA may terminate this trademark license at any time upon written notice to you. You will follow NVIDIA's use guidelines and specifications for NVIDIA's Marks as to style, color and typeface as provided in NVIDIA Marks and submit a sample of each proposed use of NVIDIA's Marks at least two (2) weeks prior to the desired implementation of such use to obtain NVIDIA's prior written approval (which approval will not be unreasonably withheld or delayed). If NVIDIA does not respond within ten (10) business days of your submission of such sample, the sample will be deemed unapproved. All goodwill associated with use of NVIDIA Marks will inure to the sole benefit of NVIDIA. For the RTX Video SDK, contact NVIDIA at nvidia-rtx-video-sdk-license-questions@nvidia.com. + + 7.3 Use Guidelines. Use of the NVIDIA Marks is subject to the following guidelines: + (a) Business Practices. You covenant that you will: (a) conduct business with respect to NVIDIA’s products in a manner that reflects favorably at all times on the good name, goodwill and reputation of such products; (b) avoid deceptive, misleading or unethical practices that are detrimental to NVIDIA, its customers, or end users; (c) make no false or misleading representations with regard to NVIDIA or its products; and (d) not publish or employ or cooperate in the publication or employment of any misleading or deceptive advertising or promotional materials. + + (b) No Combination Marks or Similar Marks. You agree not to (a) combine NVIDIA Marks with any other content without NVIDIA’s prior written approval, or (b) use any other trademark, trade name, or other designation of source which creates a likelihood of confusion with NVIDIA Marks. + + (v. March 14, 2024) + ``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..5b94933b363f4c5edd105df2aec5f785ce01b7e4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,51 @@ +# How to Contribute + +We'd love to receive your patches and contributions. Please keep your PRs as draft until such time that you would like us to review them. + +## Code Reviews + +All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. + +## Signing Your Work + +* We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. + + * Any contribution which contains commits that are not Signed-Off will not be accepted. + +* To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: + ```bash + $ git commit -s -m "Add cool feature." 
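+  # If you forgot to sign off, you can amend the most recent commit to add the sign-off:
+  $ git commit --amend --signoff --no-edit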
+ ``` + This will append the following to your commit message: + ``` + Signed-off-by: Your Name + ``` + +* Full text of the DCO: + + ``` + Developer Certificate of Origin + Version 1.1 + + Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + 1 Letterman Drive + Suite D4700 + San Francisco, CA, 94129 + + Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. + ``` + + ``` + Developer's Certificate of Origin 1.1 + + By making a contribution to this project, I certify that: + + (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or + + (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or + + (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. + + (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. + ``` \ No newline at end of file diff --git a/INSTALL.md b/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..5d2640b115a39cc7387467e0ece46ae24e9f766c --- /dev/null +++ b/INSTALL.md @@ -0,0 +1,48 @@ +## Environment setup + +Cosmos runs only on Linux systems. We have tested the installation with Ubuntu 24.04, 22.04, and 20.04. +Cosmos requires the Python version to be `3.10.x`. Please also make sure you have `conda` installed ([instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)). + +### Inference + +The below commands create the `cosmos-predict1` conda environment and install the dependencies for inference: +```bash +# Create the cosmos-predict1 conda environment. +conda env create --file cosmos-predict1.yaml +# Activate the cosmos-predict1 conda environment. +conda activate cosmos-predict1 +# Install the dependencies. +pip install -r requirements.txt +# Patch Transformer engine linking issues in conda environments. +ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/ +ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10 +# Install Transformer engine. +pip install transformer-engine[pytorch]==1.12.0 +# Install Apex for inference. +git clone https://github.com/NVIDIA/apex +CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex +# Install MoGe for inference. +pip install git+https://github.com/microsoft/MoGe.git +``` + +* Alternatively, if you are more familiar with a containerized environment, you can build the Dockerfile and run it to get an environment with all the packages pre-installed.
+ This requires docker to be already present on your system with the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) installed. + + ```bash + docker build -f Dockerfile . -t nvcr.io/$USER/cosmos-predict1:latest + ``` + + Note: In case you encounter permission issues while mounting local files inside the docker, you can share the folders from your current directory to all users (including docker) using this helpful alias `alias share='sudo chown -R ${USER}:users $PWD && sudo chmod g+w $PWD'` before running the docker. + + +You can test the environment setup for inference with +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py +``` + +### Post-training + + +🛠️ *Under construction* 👷 + +Stay tuned! diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c9ad980682246bd6ab0d2bae82232be6dbdcbd4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b2fa952326280b0755f8448632fd43dadfe55ccf --- /dev/null +++ b/README.md @@ -0,0 +1,248 @@ +--- +title: GEN3C Project (from DGX Station) +emoji: 🫁 +colorFrom: green +colorTo: blue +sdk: docker +image: elungky/gen3c:latest +# app_port: 7860 # Remove or comment this line as the image handles the port +--- + +# GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control + + + +https://github.com/user-attachments/assets/247e1719-9f8f-4504-bfa3-f9706bd8682d + + +**GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control**
+[Xuanchi Ren*](https://xuanchiren.com/), +[Tianchang Shen*](https://www.cs.toronto.edu/~shenti11/), +[Jiahui Huang](https://huangjh-pub.github.io/), +[Huan Ling](https://www.cs.toronto.edu/~linghuan/), +[Yifan Lu](https://yifanlu0227.github.io/), +[Merlin Nimier-David](https://merlin.nimierdavid.fr/), +[Thomas Müller](https://research.nvidia.com/person/thomas-muller), +[Alexander Keller](https://research.nvidia.com/person/alex-keller), +[Sanja Fidler](https://www.cs.toronto.edu/~fidler/), +[Jun Gao](https://www.cs.toronto.edu/~jungao/)
+\* indicates equal contribution
+**[Paper](https://arxiv.org/pdf/2503.03751), [Project Page](https://research.nvidia.com/labs/toronto-ai/GEN3C/), [HuggingFace](https://huggingface.co/collections/nvidia/gen3c-683f3f9540a8f9c98cf46a8d)** + +Abstract: We present GEN3C, a generative video model with precise Camera Control and +temporal 3D Consistency. Prior video models already generate realistic videos, +but they tend to leverage little 3D information, leading to inconsistencies, +such as objects popping in and out of existence. Camera control, if implemented +at all, is imprecise, because camera parameters are mere inputs to the neural +network which must then infer how the video depends on the camera. In contrast, +GEN3C is guided by a 3D cache: point clouds obtained by predicting the +pixel-wise depth of seed images or previously generated frames. When generating +the next frames, GEN3C is conditioned on the 2D renderings of the 3D cache with +the new camera trajectory provided by the user. Crucially, this means that +GEN3C neither has to remember what it previously generated nor does it have to +infer the image structure from the camera pose. The model, instead, can focus +all its generative power on previously unobserved regions, as well as advancing +the scene state to the next frame. Our results demonstrate more precise camera +control than prior work, as well as state-of-the-art results in sparse-view +novel view synthesis, even in challenging settings such as driving scenes and +monocular dynamic video. Results are best viewed in videos. + +For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). +For any other questions related to the model, please contact Xuanchi, Tianchang or Jun. + +## News +- 2025-06-06 Code and model released! In a future update, we plan to include the pipeline for jointly predicting depth and camera pose from video, as well as a driving-finetuned model. Stay tuned! + +## Installation +Please follow the "Inference" section in [INSTALL.md](INSTALL.md) to set up your environment. + +## Inference + +### Download checkpoints +1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token (if you haven't done so already). Set the access token to `Read` permission (default is `Fine-grained`). + +2. Log in to Hugging Face with the access token: + ```bash + huggingface-cli login + ``` + +3. Download the GEN3C model weights from [Hugging Face](https://huggingface.co/nvidia/GEN3C-Cosmos-7B): + ```bash + CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_gen3c_checkpoints.py --checkpoint_dir checkpoints + ``` + +### Interactive GUI usage + +
+ GEN3C interactive GUI +
+ +GEN3C can be used through an interactive GUI, allowing to visualize the inputs in 3D, author arbitrary camera trajectories, and start inference from a single window. +Please see the [dedicated instructions](gui/README.md). + + +### Command-line usage +GEN3C supports both images and videos as input. Below are examples of running GEN3C on single images and videos with predefined camera trajectory patterns. + +### Example 1: Single Image to Video Generation + +#### Single GPU +Generate a 121-frame video from a single image: +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_single_image.py \ + --checkpoint_dir checkpoints \ + --input_image_path assets/diffusion/000000.png \ + --video_save_name test_single_image \ + --guidance 1 \ + --foreground_masking +``` + +#### Multi-GPU (8 GPUs) +```bash +NUM_GPUS=8 +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_single_image.py \ + --checkpoint_dir checkpoints \ + --input_image_path assets/diffusion/000000.png \ + --video_save_name test_single_image_multigpu \ + --num_gpus ${NUM_GPUS} \ + --guidance 1 \ + --foreground_masking +``` + +#### Additional Options +- To generate longer videos autoregressively, specify the number of frames using `--num_video_frames`. The number of frames must follow the pattern: 121 * N - 1 (e.g., 241, 361, etc.) +- To save buffer images alongside the output video, add the `--save_buffer` flag +- You can control camera trajectories using `--trajectory`, `--camera_rotation`, and `--movement_distance` arguments. See the "Camera Movement Options" section below for details. + +#### Camera Movement Options + +##### Trajectory Types +The `--trajectory` argument controls the path the camera takes during video generation. Available options: + +| Option | Description | +|--------|-------------| +| `left` | Camera moves to the left (default) | +| `right` | Camera moves to the right | +| `up` | Camera moves upward | +| `down` | Camera moves downward | +| `zoom_in` | Camera moves closer to the scene | +| `zoom_out` | Camera moves away from the scene | +| `clockwise` | Camera moves in a clockwise circular path | +| `counterclockwise` | Camera moves in a counterclockwise circular path | + +##### Camera Rotation Modes +The `--camera_rotation` argument controls how the camera rotates during movement. Available options: + +| Option | Description | +|--------|-------------| +| `center_facing` | Camera always rotates to look at the (estimated) center of the scene (default) | +| `no_rotation` | Camera maintains its original orientation while moving | +| `trajectory_aligned` | Camera rotates to align with the direction of movement | + +##### Movement Distance +The `--movement_distance` argument controls how far the camera moves from its initial position. The default value is 0.3. A larger value will result in more dramatic camera movement, while a smaller value will create more subtle movement. + +##### GPU Memory Requirements + +We have tested GEN3C only on H100 and A100 GPUs. For GPUs with limited memory, you can fully offload all models by appending the following flags to your command: + +```bash +--offload_diffusion_transformer \ +--offload_tokenizer \ +--offload_text_encoder_model \ +--offload_prompt_upsampler \ +--offload_guardrail_models \ +--disable_guardrail \ +--disable_prompt_encoder +``` +Maximum observed memory during inference with full offloading: ~43GB. 
Note: Memory usage may vary depending on system specifications and is provided for reference only. + + +### Example 2: Video to Video Generation +For video input, GEN3C requires additional depth information, camera intrinsics, and extrinsics. These can be obtained using your choice of SLAM packages. For testing purposes, we provide example data. + +First, you need to download the test samples: +```bash +# Download test samples from Hugging Face +huggingface-cli download nvidia/GEN3C-Testing-Example --repo-type dataset --local-dir assets/diffusion/dynamic_video_samples +``` + +#### Single GPU +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_dynamic.py \ + --checkpoint_dir checkpoints \ + --input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \ + --video_save_name test_dynamic_video \ + --guidance 1 +``` + +#### Multi-GPU (8 GPUs) +```bash +NUM_GPUS=8 +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_dynamic.py \ + --checkpoint_dir checkpoints \ + --input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \ + --video_save_name test_dynamic_video_multigpu \ + --num_gpus ${NUM_GPUS} \ + --guidance 1 +``` + +## Gallery + +- **GEN3C** can be easily applied to video/scene creation from a single image +
+ +
+ +- ... or sparse-view images (we use 5 images here) +
+ +
+ + +- .. and dynamic videos +
+ +
+ +## Acknowledgement +Our model is based on [NVIDIA Cosmos](https://github.com/NVIDIA/Cosmos) and [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid). + +We are also grateful to several other open-source repositories that we drew inspiration from or built upon during the development of our pipeline: +- [MoGe](https://github.com/microsoft/MoGe) +- [TrajectoryCrafter](https://github.com/TrajectoryCrafter/TrajectoryCrafter) +- [DimensionX](https://github.com/wenqsun/DimensionX) +- [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2) +- [Video Depth Anything](https://github.com/DepthAnything/Video-Depth-Anything) + +## Citation +``` + @inproceedings{ren2025gen3c, + title={GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control}, + author={Ren, Xuanchi and Shen, Tianchang and Huang, Jiahui and Ling, Huan and + Lu, Yifan and Nimier-David, Merlin and Müller, Thomas and Keller, Alexander and + Fidler, Sanja and Gao, Jun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2025} +} +``` + +## License and Contact + +This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use. + + +GEN3C source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0). + +GEN3C models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). +======= +title: Gen3c +emoji: 🌍 +colorFrom: indigo +colorTo: blue +sdk: docker +pinned: false +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +>>>>>>> 0453ffbfce197070bb0c254a11ef21f15d1ad986 diff --git a/assets/demo_1.gif b/assets/demo_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..9a54e100aefeb6ed547109385025c22d0322ad29 --- /dev/null +++ b/assets/demo_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6162366c56277d084b05a37c617e2994ba75285d421e203556dcff08128b32b +size 14678966 diff --git a/assets/demo_2.gif b/assets/demo_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..54715e1e65fb6428bded8fe88526f33f22608a62 --- /dev/null +++ b/assets/demo_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e765e71d3016c6e314b6403f82313a1df42f68f6fb0f9416f197d82e0710f27e +size 10573280 diff --git a/assets/demo_3.gif b/assets/demo_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..36cd315a9e756bdd237b6924ff7e0e671bf3d406 --- /dev/null +++ b/assets/demo_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c4cf4a4bf62daf03b25ac66c2c3693adbf7cd459e55d3481a65a9ff4a9d09d9 +size 35276047 diff --git a/assets/demo_dynamic.gif b/assets/demo_dynamic.gif new file mode 100644 index 0000000000000000000000000000000000000000..f96dde75172a618a3c2b2aacd4a276e43b1f4185 --- /dev/null +++ b/assets/demo_dynamic.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:174faba45ae701eaa432dd14de1297c0479b6c0b832adbc211cbb529fbec6c61 +size 24517788 diff --git a/assets/diffusion/000000.png b/assets/diffusion/000000.png new file mode 100644 index 
0000000000000000000000000000000000000000..7d531d6587b9cb68cb9a77d5be1ad709027c025b --- /dev/null +++ b/assets/diffusion/000000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e6eab7548c2ede900f8b504a5cef981e0cd0ec38af90dbea3f0db860e002c3 +size 1326071 diff --git a/assets/diffusion/000001.png b/assets/diffusion/000001.png new file mode 100644 index 0000000000000000000000000000000000000000..d754ec6803d60186de066118da8406dad11af7ef --- /dev/null +++ b/assets/diffusion/000001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abe310078829c9e1375ac30c7c270c84c8f68a09f3857bd35c7a5754f3326151 +size 1131209 diff --git a/assets/diffusion/000002.png b/assets/diffusion/000002.png new file mode 100644 index 0000000000000000000000000000000000000000..1f3f5f0279e10e718a3478d795db735cbece9d5f --- /dev/null +++ b/assets/diffusion/000002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ad89b53e9fafed0d8eefd1cfc7cc4889c5d2f510ed32d5247c5adab4cb0c622 +size 789185 diff --git a/assets/diffusion/000003.png b/assets/diffusion/000003.png new file mode 100644 index 0000000000000000000000000000000000000000..e2999ff690a749007b70b5e5a25ee3a21c04ff35 --- /dev/null +++ b/assets/diffusion/000003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22f39915f1b277e70683befbc18ac5859c65c3d389e4dbb5127a539a411fec54 +size 1105958 diff --git a/assets/diffusion/000004.png b/assets/diffusion/000004.png new file mode 100644 index 0000000000000000000000000000000000000000..20f4fb80c925e51e3c31a597107b3636ea9851c6 --- /dev/null +++ b/assets/diffusion/000004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2f957208849c0f86b89545734bb7b243868b574554cb6aeed248b04e7234ad4 +size 1262412 diff --git a/assets/diffusion/000005.png b/assets/diffusion/000005.png new file mode 100644 index 0000000000000000000000000000000000000000..0aa49c43c45cdb634da5d424fb8c882be31cb354 --- /dev/null +++ b/assets/diffusion/000005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:267f6ae47d0e2aebda89fac5416bc0915855043131d0d8d8a4fc9506cabd4681 +size 1364198 diff --git a/assets/diffusion/000006.png b/assets/diffusion/000006.png new file mode 100644 index 0000000000000000000000000000000000000000..668af465ce603e33788278b460dfda72ed308b1b --- /dev/null +++ b/assets/diffusion/000006.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b6fd098366bcd54bd21a5707ae6d9f78d74c2eefcfbb6919569c0d1741d837f +size 1207409 diff --git a/assets/diffusion/000007.png b/assets/diffusion/000007.png new file mode 100644 index 0000000000000000000000000000000000000000..ac9a6a0a297bcebceeea924b5db0255d167ef141 --- /dev/null +++ b/assets/diffusion/000007.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:334733b7428f9521e625a8b310770fbba3e4616ccbe0af625d07e2b065e6e9ad +size 1150728 diff --git a/assets/diffusion/000008.png b/assets/diffusion/000008.png new file mode 100644 index 0000000000000000000000000000000000000000..677f6afcb6963858a98ebb2070e220bb19ad41af --- /dev/null +++ b/assets/diffusion/000008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7eae1abb3343c1e11f4e42172eba85eeed0fb2a5f7701a42e5003cf84f1696cd +size 1684291 diff --git a/assets/diffusion/000009.png b/assets/diffusion/000009.png new file mode 100644 index 0000000000000000000000000000000000000000..e19b55a92abc9c737fb23a6b493f5a22cfd38e0a --- /dev/null +++ b/assets/diffusion/000009.png @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a5c5711d41f56bb307ef6020d0dffec9ce2297bda9ef9ae465237d8347adb34 +size 603167 diff --git a/assets/diffusion/000010.png b/assets/diffusion/000010.png new file mode 100644 index 0000000000000000000000000000000000000000..341aad51799c0111c65c58b4bb0e07209e0be04a --- /dev/null +++ b/assets/diffusion/000010.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4d32f1d1c6d427e421d6f4478d4c2c697cb0406a18ecc3b8ebeeb2a0cbba7f5 +size 1184019 diff --git a/assets/diffusion/000011.png b/assets/diffusion/000011.png new file mode 100644 index 0000000000000000000000000000000000000000..72d11ac239d063aa53298ec1040fa2f27c7735a7 --- /dev/null +++ b/assets/diffusion/000011.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e352d7435d3b313fcc47efd9bd0dc6e0dd5d5e8af8c50e965c57987bee1c94ec +size 944420 diff --git a/assets/diffusion/000012.png b/assets/diffusion/000012.png new file mode 100644 index 0000000000000000000000000000000000000000..c685fc6bfe8c6730b007ddb762bffd3c51962a70 --- /dev/null +++ b/assets/diffusion/000012.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b672d43521890b2852976a0c12828ad16b9288277efff6c41189dc0c04c9c6e1 +size 1098037 diff --git a/assets/diffusion/000013.png b/assets/diffusion/000013.png new file mode 100644 index 0000000000000000000000000000000000000000..6fd722831d73a54f25f9dd20014e91f72aee68d6 --- /dev/null +++ b/assets/diffusion/000013.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eab3a655213eede094889bab94313e1cef142b811429bee9e0f3420c2b013105 +size 1243979 diff --git a/assets/diffusion/000014.png b/assets/diffusion/000014.png new file mode 100644 index 0000000000000000000000000000000000000000..432386657cc4c969a1fc052ce7c1e3d2109beee8 --- /dev/null +++ b/assets/diffusion/000014.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb014db53082677aca35a3fc27daa1f306452c5cb7130a4ed6468cae144a0b63 +size 1351667 diff --git a/assets/diffusion/000015.png b/assets/diffusion/000015.png new file mode 100644 index 0000000000000000000000000000000000000000..2c76996a58c78c95bec945bb0f0c11777bad0989 --- /dev/null +++ b/assets/diffusion/000015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6ac0d4e7eb6d4dbc3ae997fafc28721b716db092aaa52ede11e4d87b3e9b20d +size 1494431 diff --git a/checkpoints/README.md b/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..726899abdbae8de94885b0c5bc111291fb8dce7a --- /dev/null +++ b/checkpoints/README.md @@ -0,0 +1,4 @@ + +### Checkpoint directory + +Model checkpoints will be downloaded to this directory. diff --git a/cosmos-predict1.yaml b/cosmos-predict1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0722589d77b183cdd7b865227b2e0cb934e27088 --- /dev/null +++ b/cosmos-predict1.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# conda env create --file cosmos-predict1.yaml +name: cosmos-predict1 +channels: + - conda-forge +dependencies: + - python=3.10 + - pip=25.0 + - cmake + - ninja + - gcc=12.4.0 + - gxx=12.4.0 + - cuda=12.4 + - cuda-nvcc=12.4 + - cuda-toolkit=12.4 diff --git a/cosmos_predict1/__init__.py b/cosmos_predict1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/__init__.py b/cosmos_predict1/autoregressive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py b/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py new file mode 100644 index 0000000000000000000000000000000000000000..f0df8f71dfd5d79142685210de71fd2c45e87f5c --- /dev/null +++ b/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py @@ -0,0 +1,352 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
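+
+# Summary of this callback (see the class docstring below for details): every N training
+# steps it runs teacher-forcing inference on the current batch, i.e. it feeds ground-truth
+# video tokens to the model, argmax-decodes the next-token logits back into video frames,
+# and logs the decoded frames together with token accuracy and loss to wandb.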
+ +import glob +import math +import os +from typing import Optional + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as torchvision_F +import wandb +from einops import rearrange +from megatron.core import parallel_state +from torch.distributed import get_process_group_ranks + +from cosmos_predict1.autoregressive.utils.parallel import ( + broadcast_data_batch_in_tp_cp_group, + gather_batch_from_cp_ranks, + get_batch_on_this_cp_rank, +) +from cosmos_predict1.callbacks.every_n import EveryN +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor: + _, _, h, w = image.shape + new_h, new_w = int(resize_factor * h), int(resize_factor * w) + return torchvision_F.resize(image, (new_h, new_w)) + + +class VideoSamplingTeacherForcing(EveryN): + def __init__( + self, + every_n: int, + step_size: int = 1, + video_latent_shape: list = [6, 24, 40], + num_frames_to_display: int = 4, + save_folder: Optional[str] = None, + num_file_to_log: int = 8, + ): + r""" + This callback enables us to perform teacher forcing inference on the training data. + By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model + to predict the next tokens. The predicted next tokens are then visualized. This does not perform + autoregressive sampling. + We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast. + + Args: + every_n (int): Call this callback every_n steps + step_size (int): Number of steps taken for gradient accumulation. Global iteration number is + iteration // self.step_size + video_latent_shape (list): Shape of the video latent + num_frames_to_display (int): Number of frames to subsample for displaying in wandb + save_folder (str): Name of the local folder to save the video + num_file_to_log (int): Number of files to upload to wandb + """ + super().__init__(every_n, step_size) + self.save_folder = save_folder if save_folder else self.__class__.__name__ + self.video_latent_shape = video_latent_shape + self.num_frames_to_display = num_frames_to_display + self.num_file_to_log = num_file_to_log + self.rank = distributed.get_rank() + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + config_job = self.config.job + self.local_dir = f"{config_job.path_local}/{self.save_folder}" + if self.rank == 0: + os.makedirs(self.local_dir, exist_ok=True) + log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}") + + @torch.inference_mode() + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + # Tokenize the data + + broadcast_data_batch_in_tp_cp_group(data_batch) + + input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key] + + dataset_name = data_batch.get("dataset_name", None) + if dataset_name is not None and dataset_name.startswith("image"): + # we disable the callback if the input video is an image batch + log.info(f"dataset_name is {dataset_name}, skip this callback") + return + + # get the caption + captions = data_batch.get("caption", None) + + # get the context embedding and mask + context = data_batch.get("context", None) + context_mask = data_batch.get("context_mask", None) + if context is not None: + context = 
misc.to(context, "cuda").detach().clone() + if context_mask is not None: + context_mask = misc.to(context_mask, "cuda").detach().clone() + # get the action + action = data_batch.get("action", None) + if action is not None: + action = misc.to(action, "cuda").detach().clone() + + # Input tokens + tokens, _ = model.tokenizer.tokenize(data_batch) + tokens = misc.to(tokens, "cuda").detach().clone() + skip_save_file = False + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + if self.rank != min(get_process_group_ranks(cp_group)): + skip_save_file = True + tokens = get_batch_on_this_cp_rank(tokens) + if parallel_state.get_tensor_model_parallel_world_size() > 1: + # Turn on TP + tp_group = parallel_state.get_tensor_model_parallel_group() + if self.rank != min(get_process_group_ranks(tp_group)): + skip_save_file = True + tokens_encoded_in_train = output_batch["encode_tokens"].detach() + percent_token_diff = (tokens != tokens_encoded_in_train).float().mean() + percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff) + + input_tokens = tokens + + num_tokens_to_generate = np.prod(self.video_latent_shape) + + # Do a forward pass + logits = model.model.forward( + tokens, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + ) + if parallel_state.get_context_parallel_world_size() > 1: + logits = gather_batch_from_cp_ranks(logits) + input_tokens = gather_batch_from_cp_ranks(input_tokens) + + # Start position for video tokens in the vocabulary + video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset + video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size + + # Clipping logits only to video tokens. We remove the text vocab predictions. + # This will ensure that the video tokens only correspond to the video part of the vocabulary. + logits = logits[:, :, video_token_start : video_token_start + video_vocab_size] + + # Sample with argmax token. This should be good for teacher forcing experiment. + logits = logits.contiguous() + generations = torch.argmax(logits, dim=-1) + + # For each video in the batch, subsample frames for display + batch_size = input_tokens.shape[0] + out_frames = [] + out_videos_gen = [] + out_videos_rec = [] + out_videos_gt = [] + # log the accuracy of teacher-forcing + acc = [] + loss_list = [] + + for sample_num in range(batch_size): + # Subsample the generations to the video part. + # This corresponds to the part from begin of video to end of video. + bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"] + bov_index = input_tokens[sample_num] == bov_token + use_special_token = sum(bov_index) != 0 + if use_special_token: + bov_index = bov_index.nonzero().item() + # generations: real_token1 real_token2, ... real_token7680; total 7680 + # gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680 + # for vis: real_token1 real_token2, ..., real_token7680; total 7680 + # for accuracy: real_token1 real_token2, ..., real_token7680; total 7680 + gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate] + gen_video_tokens_vis = gen_video_tokens + gen_video_tokens_acc = gen_video_tokens + logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate] + else: + # generations: real_token1 real_token2, ... 
real_token7680 + # gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679 + # We need different tokens for vis and accuracy compute + # for acc: real_token2 real_token3, ..., real_token7680; total 7679 + # for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679 + gen_video_tokens = generations[sample_num][ + : num_tokens_to_generate - 1 + ] # remove the last token since there is no gt + # Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct + gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens]) + gen_video_tokens_acc = gen_video_tokens + logits_loss = logits[sample_num][: num_tokens_to_generate - 1] + + # Rearrange the video to a spatial tensor + gen_video_tokens_vis_BTHW = rearrange( + gen_video_tokens_vis.unsqueeze(0), + "B (T H W) -> B T H W", + T=self.video_latent_shape[0], + H=self.video_latent_shape[1], + W=self.video_latent_shape[2], + ) + + # for real videos, we need to skip the bov and eov tokens for decoding + if use_special_token: + # input_tokens: real_token1 real_token2 ... ... + # real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680 + # for vis: real_token1 real_token2 ... real_token7680; total 7680 + # for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above + real_video_tokens = ( + input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start + ) + real_video_tokens_vis = real_video_tokens + real_video_tokens_acc = real_video_tokens + else: + # input_tokens: real_token1 real_token2,... real_token7680; total 7680 + # real_video_tokens: real_token1 real_token2,... 
real_token7680; total 7680 + # for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted + # for vis: gt start from real_token1, real_token2; total 7680 + real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start + real_video_tokens_vis = real_video_tokens + real_video_tokens_acc = real_video_tokens[1:].flatten() + + real_video_tokens_vis_BTHW = rearrange( + real_video_tokens_vis.unsqueeze(0), + "B (T H W) -> B T H W", + T=self.video_latent_shape[0], + H=self.video_latent_shape[1], + W=self.video_latent_shape[2], + ) + # Calculate accuracy + correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float() + labels = real_video_tokens_acc.clone() + + if model.config.ignore_first_num_tokens > 0: + labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index + select_index = labels != model.tokenizer.ignore_index + correct_predictions = correct_predictions[select_index] + + loss = torch.nn.functional.cross_entropy( + logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none" + ) + acc.append(correct_predictions.mean() * 100.0) + loss_list.append(loss.mean()) + + # Decode the predicted latents + if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0: + vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda()) + else: + vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap( + gen_video_tokens_vis_BTHW.cuda(), + temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap, + ) + # normalize decoded images from [-1, 1] to [0, 1], and clip value + vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1) + vid_decoded = vid_decoded[0] + + # Decode the GT latents + if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0: + vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda()) + else: + vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap( + real_video_tokens_vis_BTHW.cuda(), + temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap, + ) + # normalize decoded image from [-1, 1] to [0, 1], and clip value + vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1) + vid_rec = vid_rec[0] + + vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W] + vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W] + + # Subsample real and generated video frames + input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W] + rec_video_frames = vid_rec.transpose(0, 1) + gen_video_frames = vid_decoded.transpose(0, 1) + out_videos_gen.append(gen_video_frames) + out_videos_rec.append(rec_video_frames) + out_videos_gt.append(input_video_frames) + + stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display) + + input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5) + input_video_frames_subsampled = torchvision.utils.make_grid( + input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0] + ) + + gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5) + gt_video_frames_subsampled = torchvision.utils.make_grid( + gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0] + ) + gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5) + gen_video_frames_subsampled = torchvision.utils.make_grid( + gen_video_frames_subsampled, 
nrow=gen_video_frames_subsampled.shape[0] + ) + + out_frames.append(input_video_frames_subsampled) + out_frames.append(gt_video_frames_subsampled) + out_frames.append(gen_video_frames_subsampled) + + scaled_num_rank_to_log = ( + self.num_file_to_log + * parallel_state.get_context_parallel_world_size() + * parallel_state.get_tensor_model_parallel_world_size() + ) + if self.rank < scaled_num_rank_to_log and not skip_save_file: + local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg" + out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + torchvision.utils.save_image(out_image_grid, local_path) + + # Log to wandb + avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item() + avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item() + log_info = "" + if "acc" in output_batch: + log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%" + if percent_token_diff is not None: + log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%" + log.info( + f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}" + ) + if self.rank == 0 and wandb.run: + local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg") + local_files = sorted(local_files)[: self.num_file_to_log] + if captions is None: + captions = ["vid_frames_teacher_forcing"] * len(local_files) + for local_path, caption in zip(local_files, captions): + wandb.log( + {"frames": [wandb.Image(local_path, caption=caption)]}, + step=iteration, + ) + + wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration) + wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration) + if percent_token_diff is not None: + wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration) diff --git a/cosmos_predict1/autoregressive/configs/__init__.py b/cosmos_predict1/autoregressive/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/configs/base/__init__.py b/cosmos_predict1/autoregressive/configs/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/configs/base/callbacks.py b/cosmos_predict1/autoregressive/configs/base/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..040f326f221febac7af54f7d7a64876a9fc030ef --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/callbacks.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing +from cosmos_predict1.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.lazy_config import LazyCall as L + +BASIC_CALLBACKS = dict( + progress_bar=L(ProgressBarCallback)(), + grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"), +) + +VIDEO_TEACHER_FORCING_CALLBACK = dict( + vid_sampling_tf=L(VideoSamplingTeacherForcing)( + every_n=500, + video_latent_shape="${model.model_config.video_latent_shape}", + num_frames_to_display=4, + save_folder="video_sampling_teacher_forcing", + ) +) diff --git a/cosmos_predict1/autoregressive/configs/base/dataloader.py b/cosmos_predict1/autoregressive/configs/base/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..69458e450a7fb5d85ea25c2514e975d5540a108f --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/dataloader.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
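Editor's note: the accuracy/loss bookkeeping in the teacher-forcing callback above boils down to masking the leading conditioning tokens and comparing predictions against the shifted ground-truth tokens. A minimal PyTorch sketch of that pattern (illustrative shapes and a placeholder ignore index; not the callback's exact code):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100          # placeholder; the real value comes from model.tokenizer.ignore_index
ignore_first_num_tokens = 4  # hypothetical value of model.config.ignore_first_num_tokens

# logits for each predicted position and the shifted ground-truth next tokens
logits = torch.randn(7679, 64000)           # [num_positions, vocab_size]
labels = torch.randint(0, 64000, (7679,))   # ground truth starts from the second real token

# mask out leading conditioning tokens so they contribute to neither accuracy nor loss
labels[:ignore_first_num_tokens] = IGNORE_INDEX
keep = labels != IGNORE_INDEX

pred = logits.argmax(dim=-1)
accuracy = (pred[keep] == labels[keep]).float().mean() * 100.0

# cross-entropy likewise skips the masked positions
loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX, reduction="none")
print(f"teacher-forcing acc: {accuracy:.2f}%, loss: {loss[keep].mean():.4f}")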
+ +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig +from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall as L + +DATALOADER_OPTIONS = {} + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +def dataloader_register(key): + log.info(f"registering dataloader {key}...") + + def decorator(func): + DATALOADER_OPTIONS[key] = func + return func + + return decorator + + +@dataloader_register("tealrobot_video") +def get_tealrobot_video( + batch_size: int = 1, + dataset_dir: str = "datasets/cosmos_nemo_assets/videos/", + sequence_interval: int = 1, + num_frames: int = 33, + video_size: list[int, int] = [640, 848], + start_frame_interval: int = 1, +): + dataset = L(VideoDataset)( + config=VideoDatasetConfig( + dataset_dir=dataset_dir, + sequence_interval=sequence_interval, + num_frames=num_frames, + video_size=video_size, + start_frame_interval=start_frame_interval, + ) + ) + return L(DataLoader)( + dataset=dataset, + sampler=L(get_sampler)(dataset=dataset), + batch_size=batch_size, + drop_last=True, + pin_memory=True, + num_workers=8, + ) diff --git a/cosmos_predict1/autoregressive/configs/base/dataset.py b/cosmos_predict1/autoregressive/configs/base/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8e24fa535a0abc0b7fde86ad302a960bc6bf28 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/dataset.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
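Editor's note: a usage sketch of the dataloader registry defined above (hypothetical call site; the instantiation helper is an assumption, not something shown in this diff):

from cosmos_predict1.autoregressive.configs.base.dataloader import DATALOADER_OPTIONS

# Look up the registered factory and build a lazy dataloader config; the keyword
# arguments mirror get_tealrobot_video's signature.
dataloader_cfg = DATALOADER_OPTIONS["tealrobot_video"](batch_size=2, num_frames=33)

# Materializing the LazyCall objects is left to the training framework; the exact
# helper (e.g. an `instantiate` function in cosmos_predict1.utils.lazy_config) is
# assumed here rather than taken from this diff.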
+ +"""Dataset config class.""" + +import attrs + +from cosmos_predict1.utils.config import make_freezable + + +@make_freezable +@attrs.define(slots=False) +class VideoDatasetConfig: + """ + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + start_frame_interval (int): Interval between starting frames of sequences + """ + + dataset_dir: str = "datasets/cosmos_nemo_assets/videos/" + sequence_interval: int = 1 + num_frames: int = 33 + video_size: list[int, int] = [640, 848] + start_frame_interval: int = 1 diff --git a/cosmos_predict1/autoregressive/configs/base/model.py b/cosmos_predict1/autoregressive/configs/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f72feb258ee3233eeadbd3afd218762648211ea5 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import attrs + +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.utils import config + +_ACTION_DIM = 8 +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define +class ModelConfig: + """ + A class to hold model configuration arguments. + + Args: + dim (int): The dimensionality of the input and output of each transformer block. + n_layers (int): Number of layers in the transformer. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to + `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention. + head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads. + vocab_size (int): Vocabulary size. + ffn_hidden_size (int): Hidden size for feedforward network. + norm_eps (float): Epsilon value for normalization. + rope_theta (float): Theta value for rotary positional embeddings. + apply_abs_pos_emb (bool): Whether to apply absolute position embeddings. + max_batch_size (int): Maximum batch size for inference. + max_seq_len (int): Maximum sequence length for input text. + fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True. + causal_mask (bool): Whether to use causal mask. Defaults to True. + norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm". + precision (str): Data type for the model. + use_qk_normalization (bool): Whether to enable QK normalization. + tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1. + ckpt_dir (str): Checkpoint directory. + ckpt_path (str): Checkpoint path. 
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension). + yarn_scale (Optional[float]): Scale factor for YaRN. + yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code) + yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code) + original_seq_len (Optional[int]): Original sequence length. + vision_encoder (Optional[str]): Vision encoder name. + mm_projector (Optional[str]): Multi-modal projector name. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D". + pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2". + original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + num_video_frames (Optional[int]): Number of video frames. + video_height (Optional[int]): Raw video pixel height dimension. + video_width (Optional[int]): Raw video pixel width dimension. + video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W). 
+ """ + + dim: int = attrs.field(default=4096) + n_layers: int = attrs.field(default=32) + n_heads: int = attrs.field(default=32) + n_kv_heads: Optional[int] = attrs.field(default=8) + head_dim: Optional[int] = attrs.field(default=None) + vocab_size: int = attrs.field(default=128256) + ffn_hidden_size: int = attrs.field(default=14336) + norm_eps: float = attrs.field(default=1e-5) + rope_theta: float = attrs.field(default=500000) + apply_abs_pos_emb: bool = attrs.field(default=False) + max_batch_size: int = attrs.field(default=1) + max_seq_len: int = attrs.field(default=8192) + fuse_qkv: bool = attrs.field(default=False) + causal_mask: bool = attrs.field(default=True) + norm_type: str = attrs.field(default="rmsnorm") + precision: str = attrs.field(default="bfloat16") + use_qk_normalization: bool = False + tokenizer: Optional[TokenizerConfig] = None + tensor_model_parallel_size: int = attrs.field(default=1) + ckpt_dir: Optional[str] = attrs.field(default=None) + ckpt_path: Optional[str] = attrs.field( + default=None + ) # If not None, load the model from this path instead of ckpt_dir + apply_yarn: Optional[bool] = attrs.field(default=False) + yarn_scale: Optional[float] = attrs.field(default=None) + yarn_beta_fast: Optional[int] = attrs.field(default=None) + yarn_beta_slow: Optional[int] = attrs.field(default=None) + original_seq_len: Optional[int] = attrs.field(default=None) + vision_encoder: Optional[str] = attrs.field(default=None) + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + mm_projector: Optional[str] = attrs.field(default=None) + rope_dim: Optional[str] = attrs.field(default="1D") + pytorch_rope_version: Optional[str] = attrs.field(default="v2") + original_latent_shape: Optional[list] = None + pad_to_multiple_of: Optional[int] = None + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + insert_cross_attn: bool = False + insert_cross_attn_every_k_layers: int = 1 + context_dim: Optional[int] = attrs.field(default=1024) + # For video training + num_video_frames: Optional[int] = None + # Raw video pixel dimension + video_height: Optional[int] = None + video_width: Optional[int] = None + # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact + video_latent_shape: Optional[list] = None + + def __getitem__(self, item): + return getattr(self, item) + + +@attrs.define +class TrainingModelConfig: + """ + A class to hold model configuration arguments. + + Args: + dim (int): The dimensionality of the input and output of each transformer block. + n_layers (int): Number of layers in the transformer. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to + `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention. + head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads. + vocab_size (int): Vocabulary size. + multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation. + ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension. + ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it. + norm_eps (float): Epsilon value for normalization. + rope_theta (float): Theta value for rotary positional embeddings. 
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings. + max_batch_size (int): Maximum batch size for inference (determines KV cache size). + max_seq_len (int): Maximum sequence length for input text (determines KV cache size). + fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend. + causal_mask (bool): Whether to use causal mask. Defaults to True. + flash_attn (bool): Whether to use Flash attention. + norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm". + backend (str): Backend for the model. + precision (str): Data type for the model. + ema (config.EMAConfig): Configuration for exponential moving average. + embedding_dropout(float): Dropout rate for the embedding layer. + attention_dropout(float): Dropout rate for attention. + hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's + implementation in its TransformerLayer class). + use_qk_normalization (bool): Whether to enable QK normalization. + inference (bool): Whether the model is used for inference. + act_ckpt_enabled (bool): Whether to enable activation checkpointing. + fsdp_enabled (bool): Whether to enable FSDP. + fsdp (LazyDict): Configuration for FSDP. + ckpt_dir (str): Checkpoint directory. + ckpt_path (str): Checkpoint path. + cache_dir (str): Cache directory. + apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension). + yarn_scale (Optional[float]): Scale factor for YaRN. + yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code) + yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code) + original_seq_len (Optional[int]): Original sequence length. + depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3). + context_parallel_size (int): Context parallel size. Defaults to 1. + tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1. + sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False. + set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism. + Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`. + attention_tp (bool): Whether to use tensor parallelism for attention layers. + mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None. + Choices: "identity", "linear", "mlp", "mlp_downsample". + video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W] + image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W] + num_video_frames (Optional[int]): Number of video frames. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D". + pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation. + original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). 
When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT). + peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0, + it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer. + It is advised to pick n such that the final layer is included. + freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn). + finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn). + use_action_condition (bool): Whether to use the robot action condition. + action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp". + action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]). + action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding. + group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal". + sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True). + Note: this is to ensure all TP-ranks have the same layernorm parameters. + z_loss_coeff (float): The coefficient for the z-loss. + insert_medusa_head (bool): Whether to insert the Medusa head. + ft_medusa_option (str): Options on which layers to finetune, choices like: + "fft": fully fine-tune both medusa heads and all LLM backbone; + "head": fine-tune medusa heads; + "head_out": fine-tune medusa heads, and the output layer; + "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone. + medusa_num_heads (int): Number of heads in the Medusa head. + medusa_num_layers (int): Number of layers in the Medusa head. + medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1. + zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False). + concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False). 
+ """ + + dim: int = attrs.field(default=4096) + n_layers: int = attrs.field(default=32) + n_heads: int = attrs.field(default=32) + n_kv_heads: Optional[int] = attrs.field(default=8) + head_dim: Optional[int] = attrs.field(default=None) + vocab_size: int = attrs.field(default=128256) + multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3) + ffn_hidden_size: Optional[int] = attrs.field(default=None) + norm_eps: float = attrs.field(default=1e-5) + rope_theta: float = attrs.field(default=500000) + apply_abs_pos_emb: bool = attrs.field(default=False) + max_batch_size: int = attrs.field(default=1) + max_seq_len: int = attrs.field(default=8192) + fuse_qkv: bool = attrs.field(default=False) + causal_mask: bool = attrs.field(default=True) + flash_attn: bool = attrs.field(default=True) + norm_type: str = attrs.field(default="rmsnorm") + backend: str = attrs.field(default="pytorch") + precision: str = attrs.field(default="bfloat16") + ema: config.EMAConfig = config.EMAConfig(enabled=False) + embedding_dropout: float = 0.0 + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + use_qk_normalization: bool = False + tokenizer: Optional[TokenizerConfig] = None + inference: bool = False + act_ckpt_enabled: bool = False + fsdp_enabled: bool = False + context_parallel_size: int = attrs.field(default=1) + tensor_model_parallel_size: int = attrs.field(default=1) + sequence_parallel: bool = attrs.field(default=False) + set_parallel_mode: bool = attrs.field(default=False) + fsdp: LazyDict = LazyDict( + dict( + policy="auto", # choices: ["size", "auto"] + min_num_params=1024, # Used as policy == "size" + sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size + sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes. 
+ ) + ) + ckpt_dir: Optional[str] = attrs.field(default="") + ckpt_path: Optional[str] = attrs.field( + default=None + ) # If not None, load the model from this path instead of ckpt_dir + cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache") + apply_yarn: Optional[bool] = attrs.field(default=False) + yarn_scale: Optional[float] = attrs.field(default=None) + yarn_beta_fast: Optional[int] = attrs.field(default=None) + yarn_beta_slow: Optional[int] = attrs.field(default=None) + original_seq_len: Optional[int] = attrs.field(default=None) + depth_init: bool = attrs.field(default=True) + ignore_first_num_tokens: int = 0 + z_loss_coeff: float = 1e-4 + attention_tp: bool = False + vision_encoder: Optional[str] = attrs.field(default=None) + mm_projector: Optional[str] = attrs.field(default=None) + rope_dim: Optional[str] = attrs.field(default="1D") + pytorch_rope_version: Optional[str] = attrs.field(default="v2") + original_latent_shape: Optional[list] = None + pad_to_multiple_of: Optional[int] = None + peft_last_n_layers: Optional[int] = attrs.field(default=0) + peft_every_n_layers: Optional[int] = attrs.field(default=0) + freeze_vision_encoder: bool = False + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + insert_cross_attn: bool = False + insert_cross_attn_every_k_layers: int = 1 + context_dim: Optional[int] = attrs.field(default=1024) + finetune_layers_with_cross_attn: bool = False + finetune_layers_without_cross_attn: bool = False + use_action_condition: bool = False + action_embedding_mode: Optional[str] = attrs.field(default="mlp") + action_dim: Optional[int] = attrs.field(default=_ACTION_DIM) + action_embedding_dim: Optional[int] = attrs.field(default=1024) + group_causal_mask_mode: Optional[str] = attrs.field(default=None) + sync_1d_parameters: bool = True + # hyper-parameters for the medusa head configs + insert_medusa_head: bool = False + ft_medusa_option: str = "fft" + medusa_num_heads: int = 7 + medusa_num_layers: int = 1 + medusa_concat_heads: bool = True + # For video training + num_video_frames: Optional[int] = None + # Raw video pixel dimension + video_height: Optional[int] = None + video_width: Optional[int] = None + # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact + video_latent_shape: Optional[list] = None + # For image training + image_latent_shape: Optional[list] = None + # For robot training (action) + zero_init_cross_attn_proj: bool = False + # For robot training (action) + concat_action_to_context: bool = False + + def __getitem__(self, item): + return getattr(self, item) diff --git a/cosmos_predict1/autoregressive/configs/base/model_config.py b/cosmos_predict1/autoregressive/configs/base/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..53442dc6faea3346b59bb860014cde1373e236bd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model_config.py @@ -0,0 +1,718 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
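Editor's note: both config classes above are plain attrs dataclasses whose __getitem__ enables dict-style access alongside attribute access. A minimal consumption sketch (illustrative values only, not a recommended configuration):

from cosmos_predict1.autoregressive.configs.base.model import ModelConfig

cfg = ModelConfig(dim=2048, n_layers=16, n_heads=32, video_latent_shape=[8, 24, 40])

# Fields are reachable both as attributes and via dict-style indexing.
assert cfg.dim == cfg["dim"] == 2048
head_dim = cfg.head_dim or cfg.dim // cfg.n_heads  # head_dim defaults to None -> dim // n_heads
print(head_dim, cfg["video_latent_shape"])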
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, List, Optional + +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import ( + TextTokenizerConfig, + TokenizerConfig, + VideoTokenizerConfig, + create_discrete_video_fsq_tokenizer_state_dict_config, +) +from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer +from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import EMAConfig +from cosmos_predict1.utils.lazy_config import LazyCall as L + +# Common architecture specifications +BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336} +COSMOS_ARCHITECTURES = { + "1b": { + "n_layers": 16, + "dim": 2048, + "n_heads": 32, + }, + "4b": { + "n_layers": 16, + "dim": 4096, + "n_heads": 32, + }, + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "head_dim": 128, + }, +} + +COSMOS_YARN_CONFIG = { + "original_latent_shape": [3, 40, 64], + "apply_yarn": True, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, + "yarn_scale": 2, +} + +# Llama3 architecture specifications for different model sizes +LLAMA3_ARCHITECTURES = { + "8b": { + "n_layers": 32, + "dim": 4096, + "n_heads": 32, + "ffn_hidden_size": 14336, + }, +} +# Llama3.1 uses YaRN for long context support (context of 128k tokens) +LLAMA_YARN_CONFIG = { + "apply_yarn": True, + "yarn_scale": 8, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, +} + +# Mistral architecture specifications for different model sizes +MISTRAL_ARCHITECTURES = { + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "ffn_hidden_size": 14336, + "head_dim": 128, + }, +} + +PIXTRAL_VISION_ARCHITECTURES = { + "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"}, +} + + +def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict: + """ + Get the model architecture specifications for the given model size, model family and pretrained status. + + Args: + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral" + pretrained (bool): Whether to load pretrained weights. + + Returns: + dict: A dictionary containing the model architecture specifications. 
+ """ + arch_specs = copy.deepcopy(BASE_CONFIG) + model_size = model_size.lower() + if model_family.startswith("cosmos"): + arch_specs.update(COSMOS_ARCHITECTURES[model_size]) + elif model_family.startswith("llama"): + arch_specs.update(LLAMA3_ARCHITECTURES[model_size]) + elif model_family in ["mistral", "pixtral"]: + arch_specs.update(MISTRAL_ARCHITECTURES[model_size]) + if model_family == "pixtral": + arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size]) + else: + raise ValueError(f"Model family {model_family} is not supported.") + + if pretrained: + if model_family == "cosmos": + if model_size == "12b": + arch_specs.update(COSMOS_YARN_CONFIG) + log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}") + else: + pass + elif model_family in ["llama", "llama3"]: + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 8192, + "vocab_size": 128256, + } + arch_specs.update(pretrained_specs) + elif model_family == "llama3.1": + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 131072, + "original_seq_len": 8192, + "vocab_size": 128256, + **LLAMA_YARN_CONFIG, + } + arch_specs.update(pretrained_specs) + elif model_family == "mistral": + assert model_size == "12b", "We only support Mistral-Nemo-12B model." + pretrained_specs = { + "rope_theta": 1000000, + "max_seq_len": 128000, + "vocab_size": 131072, + } + arch_specs.update(pretrained_specs) + elif model_family == "pixtral": + assert model_size == "12b", "We only support Pixtral 12B model." + pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072} + arch_specs.update(pretrained_specs) + else: + raise ValueError(f"Model family {model_family} doesn't have a pretrained config.") + + return arch_specs + + +def create_text_model_config( + model_ckpt_path: str, + tokenizer_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "mistral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_seq_len: int = None, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + pytorch_rope_version: str = None, +) -> dict: + """Create a text model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_path (str): Path to the tokenizer folder. + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc. + is_instruct_model (bool): Whether the model is an instruct model. + inference (bool): Whether to create the model for inference. + max_seq_len (int): Maximum sequence length. + max_batch_size (int): Maximum batch size. + rope_dim (str): RoPE dimension. Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + tensor_model_parallel_size=tensor_model_parallel_size, + rope_dim=rope_dim, + **model_arch_specs, + ) + + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(TextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ), + data_key="text", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="text_only", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_vision_language_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "pixtral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + max_seq_len: int = None, + vision_encoder_in_channels: int = 3, + fuse_qkv: bool = False, + pytorch_rope_version: str = None, +) -> dict: + """Create a vision-language model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_ckpt_path (str): Path to the tokenizer checkpoint. + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "pixtral". + model_size (str): Model size. Choices: "12b". + is_instruct_model (bool): Whether the model is an instruct model. + rope_dim (str): RoPE dimension. Choices: "1D". + add_special_tokens (bool): Whether to add special tokens. + max_seq_len (int): Maximum sequence length. + vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4. + fuse_qkv (bool): Whether to fuse the QKV linear layers. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + tensor_model_parallel_size=tensor_model_parallel_size, + rope_dim=rope_dim, + vision_encoder_in_channels=vision_encoder_in_channels, + fuse_qkv=fuse_qkv, + **model_arch_specs, + ) + # Vision-language tokenizer + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(ImageTextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + image_processor_path=tokenizer_ckpt_path, + tokenizer_path=tokenizer_ckpt_path, + ), + data_key="image_text_interleaved", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="image_text_interleaved", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_video2world_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "cosmos", + model_size: str = "4b", + pixel_chunk_duration: int = 9, + num_video_frames: int = 36, + compression_ratio: List[int] = [8, 16, 16], + original_seq_len: int = 8192, + num_condition_latents_t: int = 1, + num_tokens_to_ignore: int = -1, + batch_size: int = 2, + video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config, + rope_dim: str = "3D", + add_special_tokens: bool = True, + video_height: int = 384, + video_width: int = 640, + use_qk_normalization: bool = True, + insert_cross_attn: bool = False, + insert_cross_attn_every_k_layers: int = 1, + context_dim: int = 1024, + training_type: str = "video_to_video", + pad_to_multiple_of: Optional[int] = 64, + vocab_size: int = 64000, + apply_abs_pos_emb: bool = False, +) -> dict: + """Create a video-to-world model config. + Args: + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "8b", "3b". + pixel_chunk_duration (int): Number of frames in each chunk. + num_video_frames (int): Number of video frames. + compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8]. + original_seq_len (int): Original sequence length. + apply_yarn (bool): Whether to apply YaRN for long context scaling. + yarn_beta_fast (Optional[int]): Fast beta for YaRN. + yarn_beta_slow (Optional[int]): Slow beta for YaRN. + yarn_scale (Optional[int]): Scale factor for ctx extension. + use_qk_normalization (bool): Whether to use Query-Key normalization. + training_type (str): Type of training task. + batch_size (int): Batch size. + video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config + video_tokenizer_version (str): Version of the video tokenizer. + num_condition_latents_t (int): Number of conditioning latent channels + num_tokens_to_ignore (int) = Number of tokens to ignore. 
This takes the precedence + video_height (int): Height of the video frame. Defaults to 384. + video_width (int): Width of the video frame. Defaults to 640. + rope_dim (str): RoPE dimension. Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE. + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + vocab_size (int): Vocabulary size. + apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings. + Returns: + dict: A dictionary containing the model configuration representing the model object, can be instantiated. + """ + assert ( + pixel_chunk_duration % compression_ratio[0] == 1 + ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})" + latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1 + latent_height = video_height // compression_ratio[1] + latent_width = video_width // compression_ratio[2] + # Do some math to compute the video latent shape and sequence length + assert ( + num_video_frames % pixel_chunk_duration == 0 + ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}" + video_latent_shape = [ + num_video_frames // pixel_chunk_duration * latent_chunk_duration, + latent_height, + latent_width, + ] + # product of video_latent_shape + num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2] + if add_special_tokens: + seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3 + seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64 + # for text to video, we need to add token to indicate the start of the video + elif training_type == "text_to_video": + seq_len = num_token_video_latent + 1 + else: + seq_len = num_token_video_latent + + if seq_len % pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + + # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss + # If num_tokens_to_ignore is specified, use it. 
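Editor's note: to make the latent-shape and sequence-length arithmetic above concrete, here is a worked example using the defaults from the function signature (added illustration, not repository code):

pixel_chunk_duration, num_video_frames = 9, 36
compression_ratio, video_height, video_width = [8, 16, 16], 384, 640

latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1          # 2
latent_height = video_height // compression_ratio[1]                                    # 24
latent_width = video_width // compression_ratio[2]                                      # 40
video_latent_shape = [num_video_frames // pixel_chunk_duration * latent_chunk_duration,
                      latent_height, latent_width]                                      # [8, 24, 40]
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]  # 7680
seq_len = (num_token_video_latent + 3 + 63) // 64 * 64   # 7744 when add_special_tokens=True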
+ # Else compute it from num_condition_latents_t + if num_tokens_to_ignore < 0: + num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t + if not add_special_tokens and num_condition_latents_t > 0: + # If there are no special tokens (bov), do a -1 so that you can compute the loss + # from the first token of the next chunk + num_tokens_to_ignore -= 1 + + model_config = ModelConfig( + video_height=video_height, + video_width=video_width, + max_seq_len=seq_len, + max_batch_size=batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=use_qk_normalization, + vocab_size=64000, + original_seq_len=original_seq_len, + tensor_model_parallel_size=tensor_model_parallel_size, + video_latent_shape=video_latent_shape, + num_video_frames=num_video_frames, + rope_dim=rope_dim, + pad_to_multiple_of=pad_to_multiple_of, + insert_cross_attn=insert_cross_attn, + insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers, + context_dim=context_dim, + apply_abs_pos_emb=apply_abs_pos_emb, + **model_arch_specs, + ) + + video_tokenizer_config = video_tokenizer_config_creator( + tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio + ) + tokenizer_config = TokenizerConfig( + text_tokenizer=None, + video_tokenizer=VideoTokenizerConfig( + config=video_tokenizer_config, + data_key="video", + tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token. + tokenize_here=True, + max_seq_len=num_token_video_latent, + vocab_size=vocab_size, + ), + seq_len=seq_len, + training_type=training_type, + add_special_tokens=add_special_tokens, + pad_to_multiple_of=pad_to_multiple_of, + ) + return model_config, tokenizer_config + + +def create_video2world_model( + tensor_model_parallel_size: int = 1, + context_parallel_size: int = 1, + shard_checkpoint: bool = False, + model_family: str = "cosmos", + model_size: str = "1b", + backend: str = "pytorch", + pixel_chunk_duration: int = 9, + num_video_frames: int = 36, + compression_ratio: List[int] = [8, 16, 16], + original_seq_len: int = 8192, + apply_yarn: bool = False, + yarn_beta_fast: Optional[int] = None, + yarn_beta_slow: Optional[int] = None, + yarn_scale: Optional[int] = None, + num_condition_latents_t: int = 1, + num_tokens_to_ignore: int = -1, + batch_size: int = 1, + fsdp_enabled: bool = False, + act_ckpt_enabled: bool = False, + video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config, + rope_dim: str = "3D", + add_special_tokens: bool = False, + video_height: int = 384, + video_width: int = 640, + original_latent_shape: Optional[List[int]] = None, + use_qk_normalization: bool = True, + sequence_parallel: bool = False, + insert_cross_attn: bool = False, + insert_cross_attn_every_k_layers: int = 1, + context_dim: int = 1024, + finetune_layers_with_cross_attn: bool = False, + finetune_layers_without_cross_attn: bool = False, + use_action_condition: bool = False, + action_embedding_mode: Optional[str] = "mlp", + action_dim: int = 8, # ACTION_DIM, + action_embedding_dim: int = 1024, + group_causal_mask_mode: Optional[str] = None, + training_type: str = "video_to_video", + pad_to_multiple_of: Optional[int] = 1, + z_loss_coeff: float = 1e-4, + temporal_overlap: int = 0, + embedding_dropout: float = 0.0, + insert_medusa_head: bool = False, + ft_medusa_option: str = "fft", + medusa_num_heads: int = 7, + medusa_num_layers: int = 1, + 
medusa_concat_heads: bool = True, + fuse_qkv: bool = False, + zero_init_cross_attn_proj: bool = False, + concat_action_to_context: bool = False, + tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit", +) -> dict: + """Create a video-to-video model for training. + Args: + tensor_model_parallel_size (int): Number of tensor model parallel groups. + context_parallel_size (int): Number of context parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "8b", "3b". + backend (str): Backend for the model. Choices: "pytorch", "transformer_engine". + pixel_chunk_duration (int): Number of frames in each chunk. + num_video_frames (int): Number of video frames. + compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8]. + original_seq_len (int): Original sequence length. + apply_yarn (bool): Whether to apply YaRN for long context scaling. + yarn_beta_fast (Optional[int]): Fast beta for YaRN. + yarn_beta_slow (Optional[int]): Slow beta for YaRN. + yarn_scale (Optional[int]): Scale factor for ctx extension. + fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled. + act_ckpt_enabled (bool): Whether activation checkpointing is enabled. + use_qk_normalization (bool): Whether to use Query-Key normalization. + training_type (str): Type of training task. + batch_size (int): Batch size. + video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config + video_tokenizer_version (str): Version of the video tokenizer. + num_condition_latents_t (int): Number of conditioning latent channels + num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence + video_height (int): Height of the video frame. Defaults to 384. + video_width (int): Width of the video frame. Defaults to 640. + rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D". + add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE. + original_latent_shape (list): Original latent shape before RoPE scaling. + sequence_parallel (bool): Whether to enable sequence parallelism. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn). + finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn). + use_action_condition (bool): Whether to use action condition. + action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp". + action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]). + action_embedding_dim (int): Dimension of the action embedding. + group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal". + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + z_loss_coeff (float): Coefficient for the z loss. + temporal_overlap (int): Temporal overlap in the latent space. 
+ embedding_dropout (float): Dropout rate for the embeddings. + insert_medusa_head (bool): Whether to insert the Medusa head. + ft_medusa_option (str): Options on which layers to finetune, choices like: + "fft": fully fine-tune both medusa heads and all LLM backbone; + "head": fine-tune medusa heads; + "head_out": fine-tune medusa heads, and the output layer; + "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone. + medusa_num_heads (int): Number of heads in the Medusa head. + medusa_num_layers (int): Number of layers in the Medusa head. + medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1. + fuse_qkv (bool): Whether to fuse the QKV linear layers. + zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False). + concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False). + Returns: + dict: A dictionary containing the model configuration representing the model object, can be instantiated. + """ + assert ( + pixel_chunk_duration % compression_ratio[0] == 1 + ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})" + latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1 + latent_height = video_height // compression_ratio[1] + latent_width = video_width // compression_ratio[2] + # Compute the video latent shape and sequence length + if temporal_overlap == 0: + assert ( + num_video_frames % pixel_chunk_duration == 0 + ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}" + video_latent_shape = [ + num_video_frames // pixel_chunk_duration * latent_chunk_duration, + latent_height, + latent_width, + ] + + else: + # Calculate temporal overlap in the latent space + temporal_overlap_latent = temporal_overlap // compression_ratio[0] + + # Calculate the effective number of latent chunks for the video + latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap) + + # Compute the total duration of the latent chunks, accounting for overlap + effective_latent_duration = ( + latent_chunk_duration - temporal_overlap_latent + ) * latent_chunks + temporal_overlap_latent + + # Define the shape of the video in the latent space + video_latent_shape = [ + effective_latent_duration, # Temporal dimension + latent_height, # Height in the latent space + latent_width, # Width in the latent space + ] + + # product of video_latent_shape + num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2] + if add_special_tokens: + seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3 + seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64 + # for text to video, we need to add token to indicate the start of the video + elif training_type == "text_to_video": + seq_len = num_token_video_latent + 1 + else: + seq_len = num_token_video_latent + + if seq_len % pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False) + + inference = False # False for training, True for inference + # set_parallel_mode = 
True + set_parallel_mode = tensor_model_parallel_size > 1 + attention_tp = True + + if context_parallel_size > 1: + assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine." + + if tensor_model_parallel_size > 1: + assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode." + + # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss + # If num_tokens_to_ignore is specified, use it. + # Else compute it from num_condition_latents_t + if num_tokens_to_ignore < 0: + num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t + if not add_special_tokens and num_condition_latents_t > 0: + # If there are no special tokens (bov), do a -1 so that you can compute the loss + # from the first token of the next chunk + num_tokens_to_ignore -= 1 + + model_config = TrainingModelConfig( + video_height=video_height, + video_width=video_width, + max_seq_len=seq_len, + max_batch_size=batch_size, + inference=inference, + backend=backend, + precision="bfloat16", + ema=EMAConfig(enabled=False), + act_ckpt_enabled=act_ckpt_enabled, + fsdp_enabled=fsdp_enabled, + cache_dir=None, + ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt", + use_qk_normalization=use_qk_normalization, + vocab_size=64000, + ignore_first_num_tokens=num_tokens_to_ignore, + apply_yarn=apply_yarn, + yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + original_seq_len=original_seq_len, + yarn_scale=yarn_scale, + context_parallel_size=context_parallel_size, + tensor_model_parallel_size=tensor_model_parallel_size, + set_parallel_mode=set_parallel_mode, + attention_tp=attention_tp, + video_latent_shape=video_latent_shape, + num_video_frames=num_video_frames, + rope_dim=rope_dim, + original_latent_shape=original_latent_shape, + pad_to_multiple_of=pad_to_multiple_of, + sequence_parallel=sequence_parallel, + insert_cross_attn=insert_cross_attn, + insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers, + context_dim=context_dim, + finetune_layers_with_cross_attn=finetune_layers_with_cross_attn, + finetune_layers_without_cross_attn=finetune_layers_without_cross_attn, + use_action_condition=use_action_condition, + action_embedding_mode=action_embedding_mode, + action_dim=action_dim, + action_embedding_dim=action_embedding_dim, + group_causal_mask_mode=group_causal_mask_mode, + z_loss_coeff=z_loss_coeff, + embedding_dropout=embedding_dropout, + insert_medusa_head=insert_medusa_head, + ft_medusa_option=ft_medusa_option, + medusa_num_heads=medusa_num_heads, + medusa_num_layers=medusa_num_layers, + medusa_concat_heads=medusa_concat_heads, + fuse_qkv=fuse_qkv, + zero_init_cross_attn_proj=zero_init_cross_attn_proj, + concat_action_to_context=concat_action_to_context, + **model_arch_specs, + ) + + tokenizer_config = TokenizerConfig( + text_tokenizer=None, + video_tokenizer=VideoTokenizerConfig( + config=video_tokenizer_config_creator( + ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration + ), + data_key="video", + tokenizer_offset=0, + vocab_size=64000, + tokenize_here=True, + max_seq_len=num_token_video_latent, + temporal_overlap=temporal_overlap, + ), + seq_len="${model.model_config.max_seq_len}", + training_type=training_type, + add_special_tokens=add_special_tokens, + pad_to_multiple_of=pad_to_multiple_of, + ) + + model_parallel = ModelParallelConfig( + bf16=True, + params_dtype=getattr(torch, "bfloat16"), + ) + model_parallel.tensor_model_parallel_size = 
"${model.model_config.tensor_model_parallel_size}" + model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}" + model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}" + return L(AutoRegressiveTrainingModel.build)( + seed=0, + train_from_scratch=True, + model_config=model_config, + fsdp_checkpointer=None, + tokenizer_config=tokenizer_config, + model_parallel=model_parallel, + shard_checkpoint=shard_checkpoint, + ) diff --git a/cosmos_predict1/autoregressive/configs/base/model_parallel.py b/cosmos_predict1/autoregressive/configs/base/model_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5f93e257bce3af5fff5e79317d07a3199e7650fd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model_parallel.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_model_parallel_config(): + model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16")) + model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}" + model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}" + model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}" + MODEL_PARALLELS = LazyDict( + dict( + model_parallel_bf16=model_parallel, + ), + flags={"allow_objects": True}, + ) + return MODEL_PARALLELS["model_parallel_bf16"] diff --git a/cosmos_predict1/autoregressive/configs/base/optim.py b/cosmos_predict1/autoregressive/configs/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..beed4c4959f86d9ce440c90899bd5fd8c8b32cbd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/optim.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cosmos_predict1.utils.lazy_config import LazyCall as L + + +class LambdaLinearWarmupScheduler: + """ + A learning rate scheduler that implements linear warm-up and cool-down. + + This scheduler provides three phases: + 1. Warm-up: Learning rate linearly increases from 0 to 1. + 2. Constant: Learning rate remains at 1. 
+ 3. Cool-down: Learning rate linearly decreases from 1 to 0. + + Args: + warmup_steps (int): Number of steps for the warm-up phase. + warmup_offset (int): Starts warmup from this offset. + max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided. + cooldown_steps (int, optional): Number of steps for the cool-down phase. + + Raises: + ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given. + """ + + def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None): + self.warmup_steps = warmup_steps + self.warmup_offset = warmup_offset + self.max_iter = max_iter + self.cooldown_steps = cooldown_steps + + if cooldown_steps is not None: + if max_iter is None: + raise ValueError("max_iter must be specified when cooldown_steps is provided") + self.cooldown_start = max_iter - cooldown_steps + else: + self.cooldown_start = None + + def __call__(self, step): + # Warm-up phase + if step < self.warmup_offset: + return 0 + + if step < self.warmup_steps + self.warmup_offset: + return float(step - self.warmup_offset) / float(max(1, self.warmup_steps)) + + # Constant phase (no cool-down) + elif self.cooldown_steps is None: + return 1.0 + + # Constant phase (before cool-down starts) + elif step < self.cooldown_start: + return 1.0 + + # Cool-down phase + elif self.cooldown_start <= step < self.max_iter: + cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps + return 1.0 - cooldown_progress + + # After max_iter + elif step >= self.max_iter: + return 0.0 + + # Unexpected case + else: + raise ValueError(f"Invalid step {step}") + + +LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)( + optimizer=None, + lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000), +) diff --git a/cosmos_predict1/autoregressive/configs/base/tokenizer.py b/cosmos_predict1/autoregressive/configs/base/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9e81eaccfe86d86411a2e0ef194a4d40a0460b --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/tokenizer.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import attrs + +from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer +from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_discrete_video_fsq_tokenizer_state_dict_config( + ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16] +) -> LazyDict: + CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)( + # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime. 
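The three phases described in the LambdaLinearWarmupScheduler docstring above are easiest to see by evaluating the multiplier directly. A minimal sketch, assuming the class from configs/base/optim.py introduced above; the step counts and values are illustrative, not training defaults.

import torch

from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearWarmupScheduler

# Evaluate the LR multiplier at representative steps: before the offset, during warm-up,
# in the constant phase, and during cool-down (cooldown starts at max_iter - cooldown_steps = 800).
sched_fn = LambdaLinearWarmupScheduler(warmup_steps=100, warmup_offset=10, max_iter=1000, cooldown_steps=200)
for step in [0, 10, 60, 110, 500, 800, 900, 1000]:
    print(step, sched_fn(step))  # -> 0, 0.0, 0.5, 1.0, 1.0, 1.0, 0.5, 0.0

# Wired into a standard LambdaLR, mirroring the LambdaLinearLR LazyCall above.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)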
+ # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Strictly causal, with flexible temporal length at inference. + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + name="CausalDiscreteFactorizedVideoTokenizer", + ) + + return L(DiscreteVideoFSQStateDictTokenizer)( + enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"), + dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"), + tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig, + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0], + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=compression_ratio, + ) + + +@attrs.define(slots=False) +class TextTokenizerConfig: + """ + Text tokenizer config + + Args: + config: Config file to define the text tokenizer class. + data_key (str): The input key from data_dict that will be passed to the text tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. + vocab_size (int): Vocabulary size of the tokenizer. + """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = False + tokenizer_offset: int = 0 + vocab_size: int = 0 + + +@attrs.define(slots=False) +class VideoTokenizerConfig: + """ + Video tokenizer config + + Args: + config: Config file to define the video tokenizer class. + data_key (str): The input key from data_dict that will be passed to the video tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we + add an offset to make sure that video tokens and text tokens don't overlap. + vocab_size (int): Vocabulary size of the tokenizer. + max_seq_len (int): Maximum token length for an input video. + temporal_overlap (int): Overlap between consecutive video chunks. + """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = True + tokenizer_offset: int = 0 + vocab_size: int = 0 + max_seq_len: int = -1 + temporal_overlap: int = 0 + + +@attrs.define(slots=False) +class TokenizerConfig: + """ + Joint tokenizer config + + Args: + text_tokenizer (TextTokenizerConfig): Text tokenizer config file + class_tokenizer (ClassTokenizerConfig): Class tokenizer config file + video_tokenizer (VideoTokenizerConfig): Video tokenizer config file + image_tokenizer (ImageTokenizerConfig): Image tokenizer config file + seq_len (int): Final token sequence length + training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"] + add_special_tokens (bool): Whether to add special tokens to the output tokens + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. 
+ """ + + text_tokenizer: Optional[TextTokenizerConfig] = None + video_tokenizer: Optional[VideoTokenizerConfig] = None + seq_len: int = 4096 + training_type: str = None + add_special_tokens: bool = True + pad_to_multiple_of: Optional[int] = 64 diff --git a/cosmos_predict1/autoregressive/configs/config.py b/cosmos_predict1/autoregressive/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..df074434b8128b849e6570d8579e32f121e88ca5 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/config.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default config for cosmos_ar project.""" + +import os +from typing import Any, List + +import attrs + +from cosmos_predict1.autoregressive.configs.registry import register_configs +from cosmos_predict1.autoregressive.trainer import Trainer +from cosmos_predict1.utils import config, log +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"model": None}, + {"data_train": "mock_video"}, + {"data_val": None}, + {"optimizer": "fused_adamw"}, + {"scheduler": "warmup_cosine_lr"}, + {"checkpoint": "local"}, + {"callbacks": "basic"}, + {"global_config": None}, + {"experiment": None}, + ] + ) + + def validate(self) -> None: + """Validate that the config has all required fields.""" + assert self.job.project != "", "job.project is not set" + assert self.job.group != "", "job.group is not set" + assert self.job.name != "", "job.name is not set" + log.info("Validating config for cosmos_autoregressive job") + # FSDP config check + if self.model.model_config.fsdp_enabled: + assert self.trainer.distributed_parallelism == "fsdp" + else: + assert self.trainer.distributed_parallelism == "ddp" + + # Transformer Engine config check + if self.model.model_config.backend == "transformer_engine": + assert ( + "NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1" + ) # Enable Flash attention for transformer engine + + # TP, CP config check + if self.model_parallel is not None: + if self.model_parallel.context_parallel_size > 1: + assert ( + self.model.model_config.backend == "transformer_engine" + ), "Context parallelism is only supported in transformer engine." + + if self.model_parallel.tensor_model_parallel_size > 1: + assert ( + self.model.model_config.set_parallel_mode + ), "Tensor model parallelism is only supported in parallel mode." + + if self.model_parallel.sequence_parallel: + assert ( + self.model_parallel.tensor_model_parallel_size > 1 + ), "Sequence parallelism is only supported in tensor model parallelism." + assert ( + self.model.model_config.backend == "transformer_engine" + ), "Sequence parallelism is only supported in transformer engine." 
+ + +def make_config(): + c = Config( + model=None, + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + checkpoint=None, + ) + + c.job.project = "cosmos_autoregressive" + c.job.group = "debug" + c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + c.trainer.run_validation = True + + c.trainer.seed = 0 + c.trainer.max_iter = 10 + c.trainer.logging_iter = 1 + + c.trainer.callbacks = None + register_configs() + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment") + return c diff --git a/cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py b/cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py b/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..7c427968e19669f9e51a97650f38037feb03efd6 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + This file contains a basic configuration for video2video experiments. +""" + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model +from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() + + +""" + Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33. 
+ Usage: + torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1 +""" +base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict( + dict( + defaults=[ + {"override /data_train": "tealrobot_video_small"}, + { + "override /callbacks": [ + "basic", + "video_teacher_forcing", + ] + }, + {"override /checkpoint": "local"}, + {"override /optimizer": "fused_adamw"}, + {"override /scheduler": "warmup_cosine_lr"}, + "_self_", + ], + job=dict( + project="posttraining", + group="autoregressive_base", + name="base_4b_example_tealrobotsmall_tp1", + ), + model=create_video2world_model( + model_size="4b", + model_family="cosmos", + backend="pytorch", + tensor_model_parallel_size=1, + batch_size=1, + pixel_chunk_duration=33, + num_video_frames=33, + video_height=384, + video_width=640, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + add_special_tokens=False, + ), + trainer=dict( + max_iter=50000, + grad_accum_iter=1, + grad_scaler_args=dict(enabled=False), + run_validation=False, # No need for validation as epoch <= 1 + distributed_parallelism="ddp", + callbacks=dict( + vid_sampling_tf=dict( + every_n=500, + ), + ), + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Predict1-4B/model.pt", + load_training_state=False, + strict_resume=True, + save_iter=1000, + ), + model_parallel=create_model_parallel_config(), + ), +) + + +""" + Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33. + Usage: + torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4 +""" +base_4b_example_tealrobot_tp4: LazyDict = LazyDict( + dict( + defaults=[ + {"override /data_train": "tealrobot_video"}, + { + "override /callbacks": [ + "basic", + "video_teacher_forcing", + ] + }, + {"override /checkpoint": "local"}, + {"override /optimizer": "fused_adamw"}, + {"override /scheduler": "warmup_cosine_lr"}, + "_self_", + ], + job=dict( + project="posttraining", + group="autoregressive_base", + name="base_4b_example_tealrobot_tp4", + ), + model=create_video2world_model( + model_size="4b", + model_family="cosmos", + backend="pytorch", + tensor_model_parallel_size=4, + batch_size=1, + pixel_chunk_duration=33, + num_video_frames=33, + video_height=640, + video_width=848, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + add_special_tokens=False, + ), + trainer=dict( + max_iter=50000, + grad_accum_iter=1, + grad_scaler_args=dict(enabled=False), + run_validation=False, # No need for validation as epoch <= 1 + distributed_parallelism="ddp", + callbacks=dict( + vid_sampling_tf=dict( + every_n=500, + ), + ), + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Predict1-4B/model.pt", + load_training_state=False, + strict_resume=False, + save_iter=1000, + ), + model_parallel=create_model_parallel_config(), + ), +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + base_4b_example_tealrobotsmall_tp1, + base_4b_example_tealrobot_tp4, + ]: + cs.store( + group="experiment", + package="_global_", + name=_item["job"]["name"], + node=_item, + ) diff --git a/cosmos_predict1/autoregressive/configs/inference/inference_config.py b/cosmos_predict1/autoregressive/configs/inference/inference_config.py new file mode 100644 index 
0000000000000000000000000000000000000000..b13ffc382b3fe20d237aa4241411cfac5444c353 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/inference/inference_config.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Union + +import attrs + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig + + +@attrs.define(slots=False) +class DataShapeConfig: + latent_shape: list = [] + num_video_frames: Union[None, int] = None + height: Union[None, int] = None + width: Union[None, int] = None + + +@attrs.define(slots=False) +class SamplingConfig: + """ + Sampling config + Args: + temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + """ + + temperature: float = 0.6 + top_k: int = None + top_p: float = 0.9 + compile_prefill: bool = False + compile_sampling: bool = True + logprobs: bool = False + echo: bool = False + + +@attrs.define(slots=False) +class DiffusionDecoderSamplingConfig: + """ + Diffusion decoder sampling config + Args: + guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8. + sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02. + sigma (float): Initial noise level for the diffusion process. Defaults to 8. + num_steps (int): Number of denoising steps to perform. Defaults to 35. + overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2. + continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16. + continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8. + dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57. 
+ """ + + guidance: float = 1.8 + sigma_min: float = 0.02 + sigma: float = 8 + num_steps: int = 15 + overlap: int = 2 + continuous_tokenizer_channel = 16 + continuous_tokenizer_spatial_compression_ratio = 8 + dd_train_num_video_frames: int = 57 + max_iter: int = 99 + fps: int = 24 + + +@attrs.define(slots=False) +class InferenceConfig: + """ + Inference config + Args: + model_config (ModelConfig): Model config + tokenizer_config (TokenizerConfig): Tokenizer config + ckpt_path (str): Path to the checkpoint + latent_shape (list): Shape of the latent + """ + + model_config: ModelConfig = None + tokenizer_config: TokenizerConfig = None + ckpt_path: str = "" + data_shape_config: DataShapeConfig = None + + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_val": None}, + {"data_shape_config": "video_shape_as_model_config"}, + {"eval_job": None}, + ] + ) diff --git a/cosmos_predict1/autoregressive/configs/registry.py b/cosmos_predict1/autoregressive/configs/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2cdfdcdddc02080d18b1b6d55ac482aac43915 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/registry.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK +from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video +from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR +from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments +from cosmos_predict1.utils import config, log +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.scheduler import WarmupCosineLR + + +def register_checkpoint(cs): + checkpoint_local = config.CheckpointConfig( + save_iter=5000, + broadcast_via_filesystem=True, + ) + cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local) + + +def register_callbacks(cs): + cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) + cs.store( + group="callbacks", + package="trainer.callbacks", + name="video_teacher_forcing", + node=VIDEO_TEACHER_FORCING_CALLBACK, + ) + + +def register_scheduler(cs): + cs.store( + group="scheduler", + package="scheduler", + name="warmup_cosine_lr", + node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8), + ) + cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR) + + +def register_optimizer(cs): + cs.store( + group="optimizer", + package="optimizer", + name="fused_adamw", + node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True), + ) + cs.store( + group="optimizer", + package="optimizer", + name="sgd", + node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9), + ) + + +def register_training_data(cs): + cs.store( + group="data_train", + package="dataloader_train", + name="tealrobot_video_small", + node=get_tealrobot_video(num_frames=33, video_size=[384, 640]), + ) + cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video()) + + +def register_configs(): + log.info("Registering configs for autoregressive_base") + cs = ConfigStore.instance() + register_callbacks(cs) + register_checkpoint(cs) + register_optimizer(cs) + register_scheduler(cs) + register_training_data(cs) + register_experiments(cs) diff --git a/cosmos_predict1/autoregressive/datasets/dataset_utils.py b/cosmos_predict1/autoregressive/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e13360c351f6bc614b346e9c64420af39c088a4 --- /dev/null +++ b/cosmos_predict1/autoregressive/datasets/dataset_utils.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
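Each cs.store() call above makes a node selectable by group name from an experiment's defaults list, which is the same mechanism the video2video experiments use. A sketch of a derived experiment with a hypothetical job name:

from cosmos_predict1.utils.lazy_config import LazyDict

# Group/name pairs must match what register_configs() stored above:
# data_train -> tealrobot_video | tealrobot_video_small, optimizer -> fused_adamw | sgd,
# scheduler -> warmup_cosine_lr | lambdalinear, checkpoint -> local, callbacks -> basic, ...
my_experiment = LazyDict(
    dict(
        defaults=[
            {"override /data_train": "tealrobot_video_small"},
            {"override /optimizer": "sgd"},            # swap fused AdamW for plain SGD
            {"override /scheduler": "lambdalinear"},   # linear warmup instead of warmup-cosine
            {"override /checkpoint": "local"},
            "_self_",
        ],
        job=dict(project="posttraining", group="autoregressive_base", name="my_sgd_experiment"),
    )
)

# Registered the same way register_experiments() does:
# cs.store(group="experiment", package="_global_", name="my_sgd_experiment", node=my_experiment)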
+ +from typing import Any, Optional + +import torch +import torchvision.transforms.functional as transforms_F +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +class Augmentor: + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + r"""Base augmentor class + + Args: + input_keys (list): List of input keys + output_keys (list): List of output keys + args (dict): Arguments associated with the augmentation + """ + self.input_keys = input_keys + self.output_keys = output_keys + self.args = args + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise ValueError("Augmentor not implemented") + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. 
+ """ + assert ( + (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args) + ), "Please specify size in args" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w]) + + # We also add the aug params we use. This will be useful for other transforms + crop_x0 = (orig_w - img_w) // 2 + crop_y0 = (orig_h - img_h) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": img_w, + "crop_h": img_h, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) + return data_dict + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_predict1/autoregressive/datasets/video_dataset.py b/cosmos_predict1/autoregressive/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e129ed4b4a0d1cbae5297a31ba4286a6ae259b8a --- /dev/null +++ b/cosmos_predict1/autoregressive/datasets/video_dataset.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. 
python cosmos_predict1/autoregressive/datasets/video_dataset.py +""" + +import os +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from tqdm import tqdm + +from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig +from cosmos_predict1.autoregressive.datasets.dataset_utils import ( + CenterCrop, + Normalize, + ResizeSmallestSideAspectPreserving, +) + + +class VideoDataset(Dataset): + def __init__(self, config: VideoDatasetConfig): + """Video Dataset class for loading video-to-video generation data.""" + + super().__init__() + self.dataset_dir = config.dataset_dir + self.sequence_interval = config.sequence_interval + self.sequence_length = config.num_frames + self.video_size = config.video_size + self.start_frame_interval = config.start_frame_interval + + self.video_dir = self.dataset_dir + self.video_paths = [os.path.join(self.video_dir, f) for f in os.listdir(self.video_dir) if f.endswith(".mp4")] + print(f"{len(self.video_paths)} videos in total") + + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + + self.resize_transform = ResizeSmallestSideAspectPreserving( + input_keys=["video"], + args={"img_w": self.video_size[1], "img_h": self.video_size[0]}, + ) + self.crop_transform = CenterCrop( + input_keys=["video"], + args={"img_w": self.video_size[1], "img_h": self.video_size[0]}, + ) + self.normalize_transform = Normalize( + input_keys=["video"], + args={"mean": 0.5, "std": 0.5}, + ) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["orig_num_frames"] = n_frames + sample["chunk_index"] = -1 + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + sample["chunk_index"] += 1 + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all(), "Some frame_ids are out of range." + assert (np.array(frame_ids) >= 0).all(), "Some frame_ids are negative." 
+ vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + fps = vr.get_avg_fps() + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames) + frames = frames.permute(0, 3, 1, 2) # Rearrange from [T, H, W, C] to [T, C, H, W] + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video, fps = self._get_frames(video_path, frame_ids) + data["video"] = video + data["fps"] = fps + data["num_frames"] = self.sequence_length + data["orig_num_frames"] = sample["orig_num_frames"] + data["chunk_index"] = sample["chunk_index"] + data["frame_start"] = frame_ids[0] + data["frame_end"] = frame_ids[-1] + + data["video_name"] = { + "video_path": video_path, + "start_frame_id": str(frame_ids[0]), + } + + # resize video to smallest side aspect preserving + data = self.resize_transform(data) + # center crop video + data = self.crop_transform(data) + # normalize video + data = self.normalize_transform(data) + + data["video"] = data["video"].permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + config = VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/") + dataset = VideoDataset(config) + + indices = [0, 1, 2, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data.keys()=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py b/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2e21c745a21193949db41c45b5654c24156c4c --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional + +import torch + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition, GeneralConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + LatentConditionConfig, + LatentConditionSigmaConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, +) +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class VideoDiffusionDecoderConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoLatentDiffusionDecoderCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoLatentDiffusionDecoderCondition(**output) + + +VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + latent_condition=LatentConditionConfig(), + latent_condition_sigma=LatentConditionSigmaConfig(), +) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c675081114d8b0dcc5d27f42321c98ed8decd78f --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
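The comment above notes that classifier-free guidance replaces latent_condition with zero padding. The dropout pattern itself is generic and can be sketched independently of GeneralConditioner (which lives in cosmos_predict1.diffusion.conditioner); this is an illustration, not the library implementation.

import torch

def drop_latent_condition(latent_condition: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    """Zero out the condition for a random subset of samples, as CFG-style dropout does."""
    batch = latent_condition.shape[0]
    keep = (torch.rand(batch, device=latent_condition.device) >= dropout_rate).float()
    return latent_condition * keep.view(batch, *([1] * (latent_condition.dim() - 1)))

latent = torch.randn(4, 16, 8, 11, 20)                    # (B, C, T, H, W); shapes illustrative only
dropped = drop_latent_condition(latent, dropout_rate=0.2)  # 0.2 matches the dropout_rate in the 7B experiment config
print(dropped.flatten(1).abs().sum(dim=1) == 0)            # True for the samples whose condition was zeroed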
+ +from typing import Any, List + +import attrs + +from cosmos_predict1.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs +from cosmos_predict1.diffusion.config.base.model import LatentDiffusionDecoderModelConfig +from cosmos_predict1.diffusion.config.registry import register_configs +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "basic"}, + {"tokenizer": "tokenizer"}, + {"tokenizer_corruptor": None}, + {"latent_corruptor": None}, + {"pixel_corruptor": None}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config(model=LatentDiffusionDecoderModelConfig()) + + # Specifying values through instances of attrs + c.job.project = "cosmos_video4" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + # Call this function to register config groups for advanced overriding. + register_configs() + register_dd_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) + import_all_modules_from_package("cosmos_predict1.autoregressive.diffusion_decoder.config.inference", reload=True) + return c diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..308232872f98f300922374eb030b999866d5b3dc --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
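A quick way to see what the defaults above resolve to is to build the config object directly; a minimal sketch assuming only the functions defined in this file.

from cosmos_predict1.autoregressive.diffusion_decoder.config.config_latent_diffusion_decoder import make_config

c = make_config()
print(c.job.project)  # "cosmos_video4"
print(c.defaults)     # the config groups listed above, with net/experiment still unselected (None)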
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.diffusion_decoder.network import DiffusionDecoderGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +num_frames = 57 +Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"}, + {"override /conditioner": "video_latent_diffusion_decoder_cond"}, + {"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"}, + "_self_", + ], + job=dict( + group="diffusion_deocder_FT_7Bv1_001", + name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token", + ), + model=dict( + diffusion_decoder_cond_sigma_low=0.0, + diffusion_decoder_cond_sigma_high=0.0, + diffusion_decoder_corrupt_prob=0.0, + condition_on_tokenizer_corruptor_token=True, + latent_shape=[ + 16, + num_frames, + 88, + 160, + ], + tokenizer_corruptor=dict( + pixel_chunk_duration=num_frames, + latent_chunk_duration=1 + (num_frames - 1) // 8, + ), + net=L(DiffusionDecoderGeneralDIT)( + diffusion_decoder_condition_on_sigma=False, + max_img_h=240, + max_img_w=240, + rope_h_extrapolation_ratio=1.5, + rope_w_extrapolation_ratio=1.5, + rope_t_extrapolation_ratio=1, + block_x_format="THWBD", + is_diffusion_decoder=True, + patch_spatial=2, + diffusion_decoder_condition_on_token=True, + diffusion_decoder_token_condition_voc_size=64000, + diffusion_decoder_token_condition_dim=32, + ), + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=num_frames, + ) + ), + conditioner=dict( + latent_condition=dict( + dropout_rate=0.2, + ) + ), + ), + ) +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"], + node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY, +) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..fbcf6e3310394eb25d352fe453d23e0a4dcc2bdc --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
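The temporal sizes above follow the same causal compression rule used throughout these configs (latent frames = 1 + (pixel frames - 1) // temporal_compression). A quick arithmetic check with the chunk durations that appear in this experiment and in the tokenizer registry below:

def latent_chunk_duration(pixel_chunk_duration: int, temporal_compression: int = 8) -> int:
    # The first frame is encoded on its own; every further temporal_compression frames add one latent frame.
    return 1 + (pixel_chunk_duration - 1) // temporal_compression

assert latent_chunk_duration(57) == 8    # num_frames / dd_train_num_video_frames for the diffusion decoder
assert latent_chunk_duration(121) == 16  # 8x8x8 continuous tokenizer registered with chunk_duration=121
assert latent_chunk_duration(49) == 7    # 8x16x16 discrete corruptor registered with chunk_duration=49
assert latent_chunk_duration(33) == 5    # pixel_chunk_duration in the video2world finetuning experiments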
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.diffusion_decoder.config.base.conditioner import ( + VideoLatentDiffusionDecoderConditionerConfig, +) +from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer +from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + + +def get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution: str, + chunk_duration: int, + checkpoint_path: str, +): + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 16 + + return L(DiscreteVideoFSQJITTokenizer)( + enc_fp=checkpoint_path.replace(".jit", "encoder.jit"), + dec_fp=checkpoint_path.replace(".jit", "decoder.jit"), + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor, + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor], + ) + + +def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None): + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + latent_ch=16, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_video_tokenizer_comp8x8x8( + resolution="720", + chunk_duration=121, + checkpoint_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/.jit", + ), + ) + + +def register_corruptor(cs): + cs.store( + group="tokenizer_corruptor", + package="model.tokenizer_corruptor", + name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224", + node=get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution="720", + chunk_duration=49, + checkpoint_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/.jit", + ), + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="video_latent_diffusion_decoder_cond", + node=VideoLatentDiffusionDecoderConditionerConfig, + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_conditioner(cs) + register_corruptor(cs) + register_tokenizer(cs) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/inference.py b/cosmos_predict1/autoregressive/diffusion_decoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7dbdb256dabbbe383fd31eee9c92e1b936d8ce --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/inference.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +from typing import List + +import torch + +from cosmos_predict1.autoregressive.configs.inference.inference_config import DiffusionDecoderSamplingConfig +from cosmos_predict1.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from cosmos_predict1.autoregressive.diffusion_decoder.utils import linear_blend_video_list, split_with_overlap +from cosmos_predict1.utils import log + + +def diffusion_decoder_process_tokens( + model: LatentDiffusionDecoderModel, + indices_tensor: List[torch.Tensor], + dd_sampling_config: DiffusionDecoderSamplingConfig = None, + original_video_example: torch.Tensor = None, + t5_emb_batch: List[torch.Tensor] = None, +): + _, T, H, W = original_video_example.shape + if dd_sampling_config is None: + dd_sampling_config = DiffusionDecoderSamplingConfig() + # indices_tensor is assumed to be a list of tensors with shape 1LHW + data_batch_list = [] + for sample_num, token_CTHW in enumerate(indices_tensor): + token_BCTHW = token_CTHW.unsqueeze(0).unsqueeze(1) + token_BCTHW = split_with_overlap( + token_BCTHW, + (dd_sampling_config.dd_train_num_video_frames - 1) // 8 + 1, + overlap=dd_sampling_config.overlap, + tobf16=False, + ) + data_batch_list.append( + { + "token_chunks": token_BCTHW, + "t5_text_embeddings": t5_emb_batch[sample_num].to(torch.bfloat16), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + # other conditions + "image_size": torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([dd_sampling_config.fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor( + [dd_sampling_config.dd_train_num_video_frames] * 1, dtype=torch.bfloat16 + ).cuda(), + "padding_mask": torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda(), + } + ) + + out_videos_batch = [] + + for idx, data_batch_template in enumerate(data_batch_list): + full_length_sample = [] + iterations = min(len(data_batch_template["token_chunks"]), dd_sampling_config.max_iter) + for iter in range(iterations): + gc.collect() + torch.cuda.empty_cache() + + data_batch = copy.deepcopy(data_batch_template) + data_batch["video"] = data_batch_template["token_chunks"][iter].cuda().to("cuda") + + log.debug(f"Run iter {iter} for video # {idx} at length {data_batch['video'].shape[2]}") + # org_video, + with torch.no_grad(): + samples_latent = model.generate_samples_from_batch( + data_batch, + guidance=dd_sampling_config.guidance, + state_shape=[ + dd_sampling_config.continuous_tokenizer_channel, + dd_sampling_config.continuous_tokenizer_spatial_compression_ratio, + H // 8, + W // 8, + ], + apply_corruptor=False, + preencode_condition=True, # We are using discrete model, so the input is already pre-encoded + num_steps=dd_sampling_config.num_steps, + ) + log.debug(f"Current sample shape {samples_latent.shape} for video # {idx} ") + full_length_sample.append(samples_latent.detach()) + + # Turn off because we remove CP + # 
distributed.barrier() + del data_batch + + torch.cuda.empty_cache() + + gc.collect() + torch.cuda.empty_cache() + + # Decode full-length samples and free GPU memory + full_length_sample_pixs = [model.decode(item).clamp(-1, 1).cpu() for item in full_length_sample] + torch.cuda.empty_cache() + + # Blend pixel samples + if len(full_length_sample_pixs) > 1: + full_length_sample_pixel_blend = linear_blend_video_list( + full_length_sample_pixs, dd_sampling_config.overlap + )[:, :, :T] + else: + full_length_sample_pixel_blend = full_length_sample_pixs[0][:, :, :T] + + # Batch size of full_length_sample_pixel_blend is always 1 + out_videos_batch.append((1 + full_length_sample_pixel_blend[0].cpu()) / 2) + return out_videos_batch diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/model.py b/cosmos_predict1/autoregressive/diffusion_decoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..93c474bfee40441c1bbed7feedcbe4c3199e6e4f --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/model.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from diffusers import EDMEulerScheduler +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.module import parallel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class LatentDiffusionDecoderModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + """ + latent_corruptor: the corruption module is used to corrupt the latents. It add gaussian noise to the latents. + pixel_corruptor: the corruption module is used to corrupt the pixels. It apply gaussian blur kernel to pixels in a temporal consistent way. + tokenizer_corruptor: the corruption module is used to simulate tokenizer reconstruction errors. + + diffusion decoder noise augmentation pipeline for continuous token condition model: + condition: GT_video [T, H, W] + -> tokenizer_corruptor~(8x8x8) encode -> latent_corruptor -> tokenizer_corruptor~(8x8x8) decode + -> pixel corruptor + -> tokenizer~(1x8x8) encode -> condition [T, H/8, W/8] + GT: GT_video [T, H, W] -> tokenizer~(1x8x8) -> x_t [T, H/8, W/8]. 
+ + diffusion decoder noise augmentation pipeline for discrete token condition model: + condition: GT_video [T, H, W] + -> pixel corruptor + -> discrete tokenizer encode -> condition [T, T/8, H/16, W/16] + GT: GT_video [T, H, W] -> tokenizer~(8x8x8) -> x_t [T, T/8, H/8, W/8]. + + """ + self.latent_corruptor = lazy_instantiate(config.latent_corruptor) + self.pixel_corruptor = lazy_instantiate(config.pixel_corruptor) + self.tokenizer_corruptor = lazy_instantiate(config.tokenizer_corruptor) + + if self.latent_corruptor: + self.latent_corruptor.to(**self.tensor_kwargs) + if self.pixel_corruptor: + self.pixel_corruptor.to(**self.tensor_kwargs) + + if self.tokenizer_corruptor: + if hasattr(self.tokenizer_corruptor, "reset_dtype"): + self.tokenizer_corruptor.reset_dtype() + else: + assert self.pixel_corruptor is not None + + self.diffusion_decoder_cond_sigma_low = config.diffusion_decoder_cond_sigma_low + self.diffusion_decoder_cond_sigma_high = config.diffusion_decoder_cond_sigma_high + self.diffusion_decoder_corrupt_prob = config.diffusion_decoder_corrupt_prob + if hasattr(config, "condition_on_tokenizer_corruptor_token"): + self.condition_on_tokenizer_corruptor_token = config.condition_on_tokenizer_corruptor_token + else: + self.condition_on_tokenizer_corruptor_token = False + + self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.02, sigma_data=self.sigma_data) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + apply_corruptor: bool = False, + corrupt_sigma: float = 0.01, + preencode_condition: bool = False, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+ guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + preencode_condition (bool): use pre-computed condition if true, save tokenizer's inference time memory/ + """ + if not preencode_condition: + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + if n_sample is None: + n_sample = data_batch[self.input_data_key].shape[0] + + condition, uncondition = self._get_conditions( + data_batch, + is_negative_prompt=is_negative_prompt, + apply_corruptor=apply_corruptor, + corrupt_sigma=corrupt_sigma, + preencode_condition=preencode_condition, + ) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + apply_corruptor: bool = True, + corrupt_sigma: float = 1.5, + preencode_condition: bool = False, + ): + """Get the conditions for the model. 
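+        Builds the (optionally noise-corrupted) latent condition for the batch, then
+        returns the conditioned and unconditioned inputs used for classifier-free guidance.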
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + self._add_latent_conditions_to_data_batch( + data_batch, + apply_corruptor=apply_corruptor, + corrupt_sigma=corrupt_sigma, + preencode_condition=preencode_condition, + ) + + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition.latent_condition = split_inputs_cp(condition.latent_condition, seq_dim=2, cp_group=cp_group) + condition.latent_condition_sigma = split_inputs_cp( + condition.latent_condition_sigma, seq_dim=2, cp_group=cp_group + ) + uncondition.latent_condition = split_inputs_cp(uncondition.latent_condition, seq_dim=2, cp_group=cp_group) + uncondition.latent_condition_sigma = split_inputs_cp( + uncondition.latent_condition_sigma, seq_dim=2, cp_group=cp_group + ) + return condition, uncondition + + def _add_latent_conditions_to_data_batch( + self, + data_batch: dict, + apply_corruptor: bool = True, + corrupt_sigma: float = 1.5, + preencode_condition: bool = False, + ): + # Latent state + raw_state = data_batch[self.input_data_key] + + if self.condition_on_tokenizer_corruptor_token: + if preencode_condition: + latent_condition = raw_state.to(torch.int32).contiguous() + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition[:, 0]) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.tokenizer_corruptor.encode(corrupted_pixel) + latent_condition = latent_condition[1] if isinstance(latent_condition, tuple) else latent_condition + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition) + latent_condition = latent_condition.unsqueeze(1) + else: + if preencode_condition: + latent_condition = raw_state + corrupted_pixel = self.decode(latent_condition) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.encode(corrupted_pixel).contiguous() + + sigma = ( + torch.rand((latent_condition.shape[0],)).to(**self.tensor_kwargs) * corrupt_sigma + ) # small value to indicate clean video + c_noise_cond = self.scheduler.precondition_noise(sigma=sigma) + if corrupt_sigma != self.diffusion_decoder_cond_sigma_low and self.diffusion_decoder_corrupt_prob > 0: + sigma_expand = sigma.view((-1,) + (1,) * (latent_condition.dim() - 1)) + noise = sigma_expand * torch.randn_like(latent_condition) + latent_condition = latent_condition + noise + data_batch["latent_condition_sigma"] = torch.ones_like(latent_condition[:, 0:1, ::]) * c_noise_cond + data_batch["latent_condition"] = latent_condition + + +def broadcast_condition(condition: 
BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" + condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/network.py b/cosmos_predict1/autoregressive/diffusion_decoder/network.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca7b372b0d8f340fae6017d4e99b15c96a4d874 --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/network.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.module.blocks import PatchEmbed +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class DiffusionDecoderGeneralDIT(GeneralDIT): + def __init__( + self, + *args, + is_diffusion_decoder: bool = True, + diffusion_decoder_condition_on_sigma: bool = False, + diffusion_decoder_condition_on_token: bool = False, + diffusion_decoder_token_condition_voc_size: int = 64000, + diffusion_decoder_token_condition_dim: int = 32, + **kwargs, + ): + # diffusion decoder setting + self.is_diffusion_decoder = is_diffusion_decoder + self.diffusion_decoder_condition_on_sigma = diffusion_decoder_condition_on_sigma + self.diffusion_decoder_condition_on_token = diffusion_decoder_condition_on_token + self.diffusion_decoder_token_condition_voc_size = diffusion_decoder_token_condition_voc_size + self.diffusion_decoder_token_condition_dim = diffusion_decoder_token_condition_dim + super().__init__(*args, **kwargs) + + def initialize_weights(self): + # Initialize transformer layers: + super().initialize_weights() + if self.diffusion_decoder_condition_on_token: + nn.init.constant_(self.token_embedder.weight, 0) + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + # self.pos_embedder.enable_context_parallel(cp_group) + self.pos_embedder.cp_group = cp_group + + if self.extra_per_block_abs_pos_emb: + # self.extra_pos_embedder.enable_context_parallel(cp_group) + self.extra_pos_embedder.cp_group = cp_group + + # Loop through the model to set up context parallel. 
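+        # Only the self-attention layers receive the CP group: MLP/feed-forward and
+        # cross-attention sub-blocks are skipped, and the group (plus a dedicated
+        # CUDA stream) is handed to the transformer_engine attention op.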
+ for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff", "cross_attn", "ca"]: + continue + elif layer.block.attn.backend == "transformer_engine": + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + + log.debug("[CP] Disable context parallelism.") + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + is_diffusion_decoder, + diffusion_decoder_token_condition_dim, + diffusion_decoder_condition_on_sigma, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + self.is_diffusion_decoder, + self.diffusion_decoder_token_condition_dim, + self.diffusion_decoder_condition_on_sigma, + ) + in_channels = ( + in_channels + in_channels + if (is_diffusion_decoder and not self.diffusion_decoder_condition_on_token) + else in_channels + ) + in_channels = in_channels + 1 if diffusion_decoder_condition_on_sigma else in_channels + in_channels = ( + in_channels + self.diffusion_decoder_token_condition_dim + if self.diffusion_decoder_condition_on_token + else in_channels + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + if self.diffusion_decoder_condition_on_token: + self.token_embedder = nn.Embedding( + self.diffusion_decoder_token_condition_voc_size, self.diffusion_decoder_token_condition_dim + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. 
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.diffusion_decoder_condition_on_token: + latent_condition = self.token_embedder(latent_condition) + B, _, T, H, W, _ = latent_condition.shape + latent_condition = rearrange(latent_condition, "B 1 T H W D -> (B T) (1 D) H W") + + latent_condition = transforms.functional.resize( + latent_condition, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.BILINEAR + ) + latent_condition = rearrange(latent_condition, "(B T) D H W -> B D T H W ", B=B, T=T) + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition], dim=1) + if self.diffusion_decoder_condition_on_sigma: + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition_sigma], dim=1) + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/utils.py b/cosmos_predict1/autoregressive/diffusion_decoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c584c7c9a5e03bcb3b808d053f89e7c2aeaf9cf --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/utils.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True): + """ + Splits the video tensor into chunks of num_video_frames with a specified overlap. + + Args: + - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width]. + - num_video_frames (int): Number of frames per chunk. + - overlap (int): Number of overlapping frames between chunks. + + Returns: + - List of torch.Tensors: List of video chunks with overlap. 
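+    Example (illustrative; sizes chosen only to show the chunking behavior):
+        >>> video = torch.randn(1, 3, 10, 64, 64)  # [B, C, T, H, W]
+        >>> chunks = split_with_overlap(video, num_video_frames=4, overlap=2, tobf16=False)
+        >>> [c.shape[2] for c in chunks]  # windows start at t = 0, 2, 4, 6
+        [4, 4, 4, 4]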
+ """ + # Get the dimensions of the input tensor + B, C, T, H, W = video_BCTHW.shape + + # Ensure overlap is less than num_video_frames + assert overlap < num_video_frames, "Overlap should be less than num_video_frames." + + # List to store the chunks + chunks = [] + + # Step size for the sliding window + step = num_video_frames - overlap + + # Loop through the time dimension (T) with the sliding window + for start in range(0, T - overlap, step): + end = start + num_video_frames + # Handle the case when the last chunk might go out of bounds + if end > T: + # Get the last available frame + num_padding_frames = end - T + chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect") + else: + # Regular case: no padding needed + chunk = video_BCTHW[:, :, start:end, :, :] + if tobf16: + chunks.append(chunk.to(torch.bfloat16)) + else: + chunks.append(chunk) + return chunks + + +def linear_blend_video_list(videos, D): + """ + Linearly blends a list of videos along the time dimension with overlap length D. + + Parameters: + - videos: list of video tensors, each of shape [b, c, t, h, w] + - D: int, overlap length + + Returns: + - output_video: blended video tensor of shape [b, c, L, h, w] + """ + assert len(videos) >= 2, "At least two videos are required." + b, c, t, h, w = videos[0].shape + N = len(videos) + + # Ensure all videos have the same shape + for video in videos: + assert video.shape == (b, c, t, h, w), "All videos must have the same shape." + + # Calculate total output length + L = N * t - D * (N - 1) + output_video = torch.zeros((b, c, L, h, w), device=videos[0].device) + + output_index = 0 # Current index in the output video + + for i in range(N): + if i == 0: + # Copy frames from the first video up to t - D + output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :] + output_index += t - D + else: + # Blend overlapping frames between videos[i-1] and videos[i] + blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device) + + for j in range(D): + w1 = 1 - blend_weights[j] + w2 = blend_weights[j] + frame_from_prev = videos[i - 1][:, :, t - D + j, :, :] + frame_from_curr = videos[i][:, :, j, :, :] + output_frame = w1 * frame_from_prev + w2 * frame_from_curr + output_video[:, :, output_index, :, :] = output_frame + output_index += 1 + + if i < N - 1: + # Copy non-overlapping frames from current video up to t - D + frames_to_copy = t - 2 * D + if frames_to_copy > 0: + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][ + :, :, D : t - D, :, : + ] + output_index += frames_to_copy + else: + # For the last video, copy frames from D to t + frames_to_copy = t - D + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :] + output_index += frames_to_copy + + return output_video diff --git a/cosmos_predict1/autoregressive/inference/__init__.py b/cosmos_predict1/autoregressive/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/inference/base.py b/cosmos_predict1/autoregressive/inference/base.py new file mode 100644 index 0000000000000000000000000000000000000000..214836c6ada2eaf8dd9034177b15400fb5eb893a --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/base.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import imageio +import torch + +from cosmos_predict1.autoregressive.inference.world_generation_pipeline import ARBaseGenerationPipeline +from cosmos_predict1.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from cosmos_predict1.utils import log + + +def parse_args(): + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-Predict1-4B", + ) + parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"]) + args = parser.parse_args() + return args + + +def main(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple images/videos from input + - Generating videos from images/videos + - Saving the generated videos to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
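+    Example:
+        A minimal single-GPU invocation might look like the following (the common
+        flags come from `add_common_arguments` and are assumed to mirror the
+        attribute names used below; paths are placeholders):
+
+            python -m cosmos_predict1.autoregressive.inference.base \
+                --checkpoint_dir checkpoints \
+                --ar_model_dir Cosmos-Predict1-4B \
+                --input_type video \
+                --input_image_or_video_path assets/example_input.mp4 \
+                --video_save_folder outputs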
+ """ + inference_type = "base" # When the inference_type is "base", AR model does not take text as input, the world generation is purely based on the input video + sampling_config = validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + + # Initialize base generation model pipeline + pipeline = ARBaseGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + disable_guardrail=args.disable_guardrail, + parallel_size=args.num_gpus, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + + for idx, input_filename in enumerate(input_videos): + inp_vid = input_videos[input_filename] + # Generate video + log.info(f"Run with image or video path: {input_filename}") + out_vid = pipeline.generate( + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked base generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + + imageio.mimsave(out_vid_path, out_vid, fps=25) + log.info(f"Saved video to {out_vid_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos_predict1/autoregressive/inference/video2world.py b/cosmos_predict1/autoregressive/inference/video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..532919570a433ac21e76333db87822487c7e40b3 --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/video2world.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os + +import imageio +import torch + +from cosmos_predict1.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline +from cosmos_predict1.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from cosmos_predict1.utils import log +from cosmos_predict1.utils.io import read_prompts_from_file + + +def parse_args(): + parser = argparse.ArgumentParser(description="Prompted video to world generation demo script") + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-Predict1-5B-Video2World", + ) + parser.add_argument( + "--input_type", + type=str, + default="text_and_video", + choices=["text_and_image", "text_and_video"], + help="Input types", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload T5 model after inference", + ) + args = parser.parse_args() + return args + + +def main(args): + """Run prompted video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
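+    Example:
+        A minimal single-GPU invocation might look like the following (the common
+        flags come from `add_common_arguments` and are assumed to mirror the
+        attribute names used below; paths and the prompt are placeholders):
+
+            python -m cosmos_predict1.autoregressive.inference.video2world \
+                --checkpoint_dir checkpoints \
+                --ar_model_dir Cosmos-Predict1-5B-Video2World \
+                --input_type text_and_video \
+                --input_image_or_video_path assets/example_input.mp4 \
+                --prompt "A robot arm picks up a red cube." \
+                --video_save_folder outputs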
+ """ + inference_type = "video2world" # When the inference_type is "video2world", AR model takes both text and video as input, the world generation is based on the input text prompt and video + sampling_config = validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + + # Initialize prompted base generation model pipeline + pipeline = ARVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + disable_guardrail=args.disable_guardrail, + parallel_size=args.num_gpus, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + # Load input prompt(s) + if args.batch_input_path: + prompts_list = read_prompts_from_file(args.batch_input_path) + else: + prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}] + + # Iterate through prompts + for idx, prompt_entry in enumerate(prompts_list): + video_path = prompt_entry["visual_input"] + input_filename = os.path.basename(video_path) + + # Check if video exists in loaded videos + if input_filename not in input_videos: + log.critical(f"Input file {input_filename} not found, skipping prompt.") + continue + + inp_vid = input_videos[input_filename] + inp_prompt = prompt_entry["prompt"] + + # Generate video + log.info(f"Run with input: {prompt_entry}") + out_vid = pipeline.generate( + inp_prompt=inp_prompt, + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked video2world generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + imageio.mimsave(out_vid_path, out_vid, fps=25) + + log.info(f"Saved video to {out_vid_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py b/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fd5b392395d777c120528241dbf89aef050efa --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py @@ -0,0 +1,1031 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from einops import rearrange +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.configs.inference.inference_config import ( + DataShapeConfig, + DiffusionDecoderSamplingConfig, + InferenceConfig, + SamplingConfig, +) +from cosmos_predict1.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens +from cosmos_predict1.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from cosmos_predict1.autoregressive.model import AutoRegressiveModel, update_model_config +from cosmos_predict1.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving +from cosmos_predict1.autoregressive.utils.parallel import broadcast_data_batch_in_tp_cp_group, get_batch_on_this_cp_rank +from cosmos_predict1.diffusion.inference.inference_utils import ( + load_model_by_config, + load_network_model, + load_tokenizer_model, +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline + + +def detect_model_size_from_ckpt_path(ckpt_path: str) -> str: + """Detect model size from checkpoint path. + + Args: + ckpt_path: Path to model checkpoint file + + Returns: + str: Model size ('4b', '5b', '12b', or '13b') + + Examples: + >>> detect_model_size_from_ckpt_path("model_4B.pt") + '4b' + """ + model_size = "4b" + if "4B" in ckpt_path: + model_size = "4b" + elif "5B" in ckpt_path: + model_size = "5b" + elif "12B" in ckpt_path: + model_size = "12b" + elif "13B" in ckpt_path: + model_size = "13b" + else: + log.warning(f"Could not detect model size from checkpoint path: {ckpt_path}") + return model_size + + +def create_inference_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + model_size: str = "4b", + parallel_size: int = 4, + batch_size: int = 1, + inference_type: str = "base", +) -> InferenceConfig: + """Create inference configuration for model. 
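+    Builds the video2world model and tokenizer configs for the requested model size
+    and wraps them, together with the data-shape settings, in an InferenceConfig.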
+ + Args: + model_ckpt_path: Path to model checkpoint + tokenizer_ckpt_path: Path to tokenizer checkpoint + model_size: Size of model ('4b', '5b', '12b', '13b') + parallel_size: Number of GPUs for parallelism + batch_size: Batch size for inference + inference_type: Type of inference ('base' or 'video2world') + + Returns: + InferenceConfig: Configuration object for inference + """ + model_size = model_size.lower() + # For inference config + kwargs = {} + if inference_type == "video2world": + kwargs.update( + dict( + insert_cross_attn=True, + insert_cross_attn_every_k_layers=1, + context_dim=1024, + training_type="text_to_video", + apply_abs_pos_emb=True, + ) + ) + if model_size == "5b": + model_size = "4b" # The base model (excluding the cross attention layers) is the 4B model + elif model_size == "13b": + model_size = "12b" # The base model (excluding the cross attention layers) is the 12B model + else: + raise ValueError(f"Unsupported model size for video2world inference_type: {model_size}") + else: + assert inference_type == "base", f"Unsupported inference_type: {inference_type}" + + model_config, tokenizer_config = create_video2world_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + tensor_model_parallel_size=parallel_size, + rope_dim="3D", + add_special_tokens=False, + pixel_chunk_duration=33, + num_video_frames=33, + num_condition_latents_t=1, + batch_size=batch_size, + video_height=640, + video_width=1024, + **kwargs, + ) + + inference_config = InferenceConfig() + + inference_config.model_config = model_config + inference_config.tokenizer_config = tokenizer_config + + inference_config.data_shape_config = DataShapeConfig( + num_video_frames=model_config.num_video_frames, + height=model_config.video_height, + width=model_config.video_width, + latent_shape=model_config.video_latent_shape, + ) + inference_config.model_config.fuse_qkv = False + return inference_config + + +class ARBaseGenerationPipeline(BaseWorldGenerationPipeline): + """Base class for autoregressive world generation models. + + Handles the core functionality for generating videos using autoregressive models. + Provides configurable GPU memory management through model offloading and supports + different inference types for video generation. + + Attributes: + inference_config (InferenceConfig): Configuration for model inference + tokenizer_config (TokenizerConfig): Configuration for tokenizer + disable_diffusion_decoder (bool): Whether diffusion decoder is disabled + parallel_size (int): Number of GPUs for parallelism + latent_shape (List[int]): Shape of video latents [T, H, W] + _supported_context_len (int): Supported context window length + latent_chunk_duration (int): Duration of latent chunks + pixel_chunk_duration (int): Duration of pixel chunks + diffusion_decoder_model (Optional[nn.Module]): The diffusion decoder model + """ + + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + has_text_input: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + disable_guardrail: bool = False, + parallel_size: int = 1, + ): + """Initialize the autoregressive world generation pipeline. 
+ + Args: + inference_type: Type of world generation ('base' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the AR checkpoint to load + has_text_input: Whether the pipeline takes text input for world generation + disable_diffusion_decoder: Whether to disable the diffusion decoder stage + offload_network: Whether to offload AR model from GPU after use + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + disable_guardrail: Whether to disable guardrail + parallel_size: Number of GPUs for parallelism + + Raises: + AssertionError: If inference_type is not 'base' or 'video2world' + """ + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + + # Create inference config + model_size = detect_model_size_from_ckpt_path(checkpoint_name) + model_ckpt_path = os.path.join(checkpoint_dir, checkpoint_name, "model.pt") + tokenizer_ckpt_path = os.path.join(checkpoint_dir, "Cosmos-Tokenize1-DV8x16x16-720p/ema.jit") + + inference_config: InferenceConfig = create_inference_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + parallel_size=parallel_size, + inference_type=inference_type, + ) + + self.inference_config = inference_config + self.parallel_size = parallel_size + self.disable_diffusion_decoder = disable_diffusion_decoder + + if not disable_diffusion_decoder: + self.diffusion_decoder_ckpt_path = os.path.join( + checkpoint_dir, "Cosmos-Predict1-7B-Decoder-DV8x16x16ToCV8x8x8-720p/model.pt" + ) + self.diffusion_decoder_config = "DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token" + self.diffusion_decoder_tokenizer_path = os.path.join(checkpoint_dir, "Cosmos-Tokenize1-CV8x8x8-720p") + self.dd_sampling_config = DiffusionDecoderSamplingConfig() + aux_vars_path = os.path.join(os.path.dirname(self.diffusion_decoder_ckpt_path), "aux_vars.pt") + # We use a generic prompt when no text prompts are available for diffusion decoder. + # Generic prompt used - "high quality, 4k, high definition, smooth video" + aux_vars = torch.load(aux_vars_path, weights_only=True) + self.generic_prompt = dict() + self.generic_prompt["context"] = aux_vars["context"].cuda() + self.generic_prompt["context_mask"] = aux_vars["context_mask"].cuda() + + self.latent_shape = inference_config.data_shape_config.latent_shape # [L, 40, 64] + self._supported_context_len = _SUPPORTED_CONTEXT_LEN + self.tokenizer_config = inference_config.tokenizer_config + + self.offload_diffusion_decoder = offload_diffusion_decoder + self.diffusion_decoder_model = None + if not self.offload_diffusion_decoder and not disable_diffusion_decoder: + self._load_diffusion_decoder() + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + has_text_input=has_text_input, + offload_guardrail_models=offload_guardrail_models, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + disable_guardrail=disable_guardrail, + offload_text_encoder_model=True, + ) + + def _load_model(self): + """Load and initialize the autoregressive model. + + Sets up parallelism if enabled (parallel_size > 1). + Initializes model parallel state and seeds for reproducibility. + Creates and configures the autoregressive model with appropriate settings. 
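+        Note:
+            When parallel_size > 1, a temporary tensor-model-parallel state is
+            initialized here only to seed the model-parallel RNG and is destroyed
+            again; it is re-created when the network weights are loaded.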
+ """ + if self.parallel_size > 1: + model_parallel = ModelParallelConfig( + tensor_model_parallel_size=self.parallel_size, + context_parallel_size=1, + bf16=True, + params_dtype=getattr(torch, "bfloat16"), + ) + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + model_parallel_cuda_manual_seed(0) + parallel_state.destroy_model_parallel() + else: + model_parallel = None + self.model_config = self.inference_config.model_config + self.model_config = update_model_config( + self.model_config, + inference_tensor_parallel_size=self.parallel_size, + ) + self.model = AutoRegressiveModel( + config=self.inference_config.model_config, + model_parallel=model_parallel, + ) + + def _load_network(self): + """Load network weights for the autoregressive model. + + Sets up distributed training if available and handles checkpoint loading. + Supports tensor parallel model sharding when enabled. Coordinates across + distributed process groups if needed. + """ + if dist.is_available() and dist.is_initialized(): + # ddp_group = parallel_state.get_data_parallel_group() + # tp_group = parallel_state.get_tensor_model_parallel_group() + # dist.barrier(group=ddp_group) + # dist.barrier(group=tp_group) + pass + if "{rank}" in self.model_config.ckpt_path: + shard_checkpoint = False + else: + shard_checkpoint = ( + dist.is_available() and dist.is_initialized() + ) # Take the TP-rank specific checkpoint when initializing the model + + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + self.model.load_ar_model( + shard_checkpoint=shard_checkpoint, tokenizer_config=self.inference_config.tokenizer_config + ) + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + def _load_tokenizer(self): + """Load and initialize the tokenizer model. + + Configures the tokenizer using settings from inference_config and + attaches it to the autoregressive model. + """ + self.model.load_tokenizer(tokenizer_config=self.inference_config.tokenizer_config) + + def _load_diffusion_decoder(self): + """Load and initialize the diffusion decoder model. + + Sets up context parallelism if enabled. Loads model weights, + and configures parallel processing groups as needed. + Handles model parallel state initialization and management. + """ + self.diffusion_decoder_model = load_model_by_config( + config_job_name=self.diffusion_decoder_config, + config_file="cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py", + model_class=LatentDiffusionDecoderModel, + ) + load_network_model(self.diffusion_decoder_model, self.diffusion_decoder_ckpt_path) + load_tokenizer_model(self.diffusion_decoder_model, self.diffusion_decoder_tokenizer_path) + + def _offload_diffusion_decoder(self): + """Offload diffusion decoder model from GPU memory.""" + if self.diffusion_decoder_model is not None: + del self.diffusion_decoder_model + self.diffusion_decoder_model = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model_with_offload( + self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run the autoregressive model to generate video tokens. + + Takes input video frames and generates new video tokens using the autoregressive model. + Handles context frame selection and token generation. 
+ + Args: + inp_vid (torch.Tensor): Input video tensor of shape + num_input_frames (int): Number of context frames to use from input. The tensor shape should be (B x T x 3 x H x W). + seed (int): Random seed for generation + sampling_config (SamplingConfig): Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors, + List of token index tensors, + List of prompt embedding tensors + ) + """ + # Choosing the context length from list of available contexts + out_videos_cur_batch, indices_tensor_cur_batch = self._run_model( + inp_vid, num_input_frames, seed, sampling_config + ) + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos_cur_batch, indices_tensor_cur_batch + + def _run_model( + self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run the autoregressive model to generate video tokens. + + Takes input video frames and generates new video tokens using the autoregressive model. + Handles context frame selection and token generation. + + Args: + inp_vid (torch.Tensor): Input video tensor of shape + num_input_frames (int): Number of context frames to use from input. The tensor shape should be (B x T x 3 x H x W). + seed (int): Random seed for generation + sampling_config (SamplingConfig): Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors, + List of token index tensors, + List of prompt embedding tensors + ) + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + + # Choosing the context length from list of available contexts + latent_context_t_size = 0 + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using input size of {context_used} frames") + + data_batch = {"video": inp_vid} + data_batch = misc.to(data_batch, "cuda") + + T, H, W = self.latent_shape + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="video", + num_chunks_to_generate=1, + seed=seed, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch, indices_tensor_cur_batch + + def _run_diffusion_decoder( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Process generated tokens through the diffusion decoder. + + Enhances video quality through diffusion-based decoding. 
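+        When parallel_size > 1, context parallelism is enabled on the decoder network
+        for the duration of the call.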
+ + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(context_parallel_size=self.parallel_size) + process_group = parallel_state.get_context_parallel_group() + self.diffusion_decoder_model.net.enable_context_parallel(process_group) + + out_videos_cur_batch_dd = diffusion_decoder_process_tokens( + model=self.diffusion_decoder_model, + indices_tensor=indices_tensor_cur_batch, + dd_sampling_config=self.dd_sampling_config, + original_video_example=out_videos_cur_batch[0], + t5_emb_batch=t5_emb_batch, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch_dd + + def _run_diffusion_decoder_with_offload( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Run diffusion decoder with memory management. + + Loads decoder if needed, processes videos, and offloads decoder afterward + if configured in offload_diffusion_decoder. + + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + if self.offload_diffusion_decoder: + self._load_diffusion_decoder() + out_videos_cur_batch = self._run_diffusion_decoder(out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch) + if self.offload_diffusion_decoder: + self._offload_diffusion_decoder() + return out_videos_cur_batch + + def generate( + self, + inp_vid: torch.Tensor, + sampling_config: SamplingConfig, + num_input_frames: int = 9, + seed: int = 0, + ) -> np.ndarray | None: + """Generate a video continuation from input frames. + + Pipeline steps: + 1. Generates video tokens using autoregressive model + 2. Optionally enhances quality via diffusion decoder + 3. 
Applies safety checks if enabled + + Args: + inp_vid: Input video tensor of shape (batch_size, time, channels=3, height, width) + sampling_config: Parameters controlling the generation process + num_input_frames: Number of input frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch = self._run_model_with_offload( + inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch=[self.generic_prompt["context"]] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video + + @torch.inference_mode() + def generate_partial_tokens_from_data_batch( + self, + data_batch: dict, + num_tokens_to_generate: int, + sampling_config: SamplingConfig, + tokenizer_config: TokenizerConfig, + latent_shape: list[int], + task_condition: str, + num_chunks_to_generate: int = 1, + seed: int = 0, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Generate video tokens from partial input tokens with conditioning. + + Handles token generation and decoding process: + 1. Processes input batch and applies conditioning + 2. Generates specified number of new tokens + 3. 
Decodes tokens to video frames + + Args: + data_batch: Dictionary containing input data including video and optional context + num_tokens_to_generate: Number of tokens to generate + sampling_config: Configuration for sampling parameters + tokenizer_config: Configuration for tokenizer, including video tokenizer settings + latent_shape: Shape of video latents [T, H, W] + task_condition: Type of generation task ('video' or 'text_and_video') + num_chunks_to_generate: Number of chunks to generate (default: 1) + seed: Random seed for generation (default: 0) + + Returns: + tuple containing: + - List[torch.Tensor]: Generated videos + - List[torch.Tensor]: Input videos + - List[torch.Tensor]: Generated tokens + - List[torch.Tensor]: Token index tensors + """ + log.debug(f"Starting generate_partial_tokens_from_data_batch with seed {seed}") + log.debug(f"Number of tokens to generate: {num_tokens_to_generate}") + log.debug(f"Latent shape: {latent_shape}") + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + broadcast_data_batch_in_tp_cp_group(data_batch) + + video_token_start = tokenizer_config.video_tokenizer.tokenizer_offset + video_vocab_size = tokenizer_config.video_tokenizer.vocab_size + video_token_end = video_token_start + video_vocab_size + + logit_clipping_range = [video_token_start, video_token_end] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + assert logit_clipping_range == [ + 0, + self.model.tokenizer.video_vocab_size, + ], f"logit_clipping_range {logit_clipping_range} is not supported for fast generate. Expected [0, {self.model.tokenizer.video_vocab_size}]" + + out_videos = {} + out_indices_tensors = {} + + # for text2world, we only add a token at the beginning of the video tokens, this applies to 5B and 13B models + if self.model.tokenizer.tokenizer_config.training_type == "text_to_video": + num_bov_tokens = 1 + num_eov_tokens = 0 + else: + num_eov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + num_bov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + + chunk_idx = 0 + out_videos[chunk_idx] = [] + out_indices_tensors[chunk_idx] = [] + + # get the context embedding and mask + context = data_batch.get("context", None) if task_condition != "video" else None + context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None + if context is not None: + context = misc.to(context, "cuda").detach().clone() + if context_mask is not None: + context_mask = misc.to(context_mask, "cuda").detach().clone() + + # get the video tokens + data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch) + data_tokens = misc.to(data_tokens, "cuda").detach().clone() + if parallel_state.get_context_parallel_world_size() > 1: + data_tokens = get_batch_on_this_cp_rank(data_tokens) + batch_size = data_tokens.shape[0] + + for sample_num in range(batch_size): + input_tokens = data_tokens[sample_num][0 : token_boundaries["video"][sample_num][1]] # [B, L] + input_tokens = [ + input_tokens[0 : -num_tokens_to_generate - num_eov_tokens].tolist() + ] # -1 is to exclude eov token + log.debug( + f"Run sampling. 
# input condition tokens: {len(input_tokens[0])}; # generate tokens: {num_tokens_to_generate + num_eov_tokens}; " + f"full length of the data tokens: {len(data_tokens[sample_num])}: {data_tokens[sample_num]}" + ) + video_start_boundary = token_boundaries["video"][sample_num][0] + num_bov_tokens + + video_decoded, indices_tensor = self.generate_video_from_tokens( + prompt_tokens=input_tokens, + latent_shape=latent_shape, + video_start_boundary=video_start_boundary, + max_gen_len=num_tokens_to_generate, + sampling_config=sampling_config, + logit_clipping_range=logit_clipping_range, + seed=seed, + context=context, + context_mask=context_mask, + ) # BCLHW, range [0, 1] + + # For the first chunk, we store the entire generated video + out_videos[chunk_idx].append(video_decoded[sample_num].detach().clone()) + out_indices_tensors[chunk_idx].append(indices_tensor[sample_num].detach().clone()) + + output_videos = [] + output_indice_tensors = [] + for sample_num in range(len(out_videos[0])): + tensors_to_concat = [out_videos[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate)] + concatenated = torch.cat(tensors_to_concat, dim=1) + output_videos.append(concatenated) + + indices_tensor_to_concat = [ + out_indices_tensors[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate) + ] + concatenated_indices_tensor = torch.cat(indices_tensor_to_concat, dim=1) # BLHW + output_indice_tensors.append(concatenated_indices_tensor) + + return output_videos, output_indice_tensors + + def generate_video_from_tokens( + self, + prompt_tokens: list[torch.Tensor], + latent_shape: list[int], + video_start_boundary: int, + max_gen_len: int, + sampling_config: SamplingConfig, + logit_clipping_range: list[int], + seed: int = 0, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Function to generate video from input tokens. These input tokens can be initial text tokens (in case of text to video), + or partial ground truth tokens. + + Handles the core token-to-video generation process: + 1. Generates new tokens using the autoregressive model + 2. Handles padding and token sequence completion + 3. Reshapes and processes generated tokens + 4. Decodes final tokens into video frames + + Args: + model (AutoRegressiveModel): LLama model instance + prompt_tokens (list): Prompt tokens used by the model + latent_shape (list): Shape of the video latents + video_start_boundary (int): Index where the video tokens start + max_gen_len (int): Maximum length of the tokens that needs to be generated + sampling_config (SamplingConfig): Config used by sampler during inference + logit_clipping_range (list): Range of indices in the logits to be clipped, e.g. [video_token_start, video_token_end] + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. 
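+            seed (int): Random seed used for token sampling (default: 0).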
+ Returns: + tuple containing: + - List[torch.Tensor]: Generated videos + - List[torch.Tensor]: Generated tokens + - List[torch.Tensor]: Token index tensors + """ + # Combine the tokens and do padding, sometimes the generated tokens end before the max_gen_len + total_seq_len = np.prod(latent_shape) + + assert not sampling_config.logprobs + + stop_tokens = self.model.tokenizer.stop_tokens + if self.offload_tokenizer: + self._offload_tokenizer() + if self.offload_network: + self._load_network() + + generation_tokens, _ = self.model.generate( + prompt_tokens=prompt_tokens, + temperature=sampling_config.temperature, + top_p=sampling_config.top_p, + echo=sampling_config.echo, + seed=seed, + context=context, + context_mask=context_mask, + max_gen_len=max_gen_len, + compile_sampling=sampling_config.compile_sampling, + compile_prefill=sampling_config.compile_prefill, + stop_tokens=stop_tokens, + verbose=True, + ) + generation_tokens = generation_tokens[:, video_start_boundary:] + # Combine the tokens and do padding, sometimes the generated tokens end before the max_gen_len + if generation_tokens.shape[1] < total_seq_len: + log.warning( + f"Generated video tokens (shape:{generation_tokens.shape}) shorted than expected {total_seq_len}. Could be the model produce end token early. Repeat the last token to fill the sequence in order for decoding." + ) + padding_len = total_seq_len - generation_tokens.shape[1] + padding_tokens = generation_tokens[:, [-1]].repeat(1, padding_len) + generation_tokens = torch.cat([generation_tokens, padding_tokens], dim=1) + # Cast to LongTensor + indices_tensor = generation_tokens.long() + # First, we reshape the generated tokens into batch x time x height x width + indices_tensor = rearrange( + indices_tensor, + "B (T H W) -> B T H W", + T=latent_shape[0], + H=latent_shape[1], + W=latent_shape[2], + ) + log.debug(f"generated video tokens {len(generation_tokens[0])} -> reshape: {indices_tensor.shape}") + # If logit clipping range is specified, offset the generated indices by the logit_clipping_range[0] + # Video decoder always takes tokens in the range (0, N-1). So, this offset is needed. + if len(logit_clipping_range) > 0: + indices_tensor = indices_tensor - logit_clipping_range[0] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + # Now decode the video using tokenizer. + video_decoded = self.model.tokenizer.video_tokenizer.decode(indices_tensor.cuda()) + # Normalize decoded video from [-1, 1] to [0, 1], and clip value + video_decoded = (video_decoded * 0.5 + 0.5).clamp_(0, 1) + return video_decoded, indices_tensor + + +class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline): + """Video-to-world generation pipeline with text conditioning capabilities. + + Extends the base autoregressive generation pipeline by adding: + - Text prompt processing and embedding + - Text-conditioned video generation + - Additional safety checks for text input + - Memory management for text encoder model + + Enables generating video continuations that are guided by both + input video frames and text descriptions. 
+ + Additional attributes compared to ARBaseGenerationPipeline: + offload_text_encoder_model (bool): Whether to offload text encoder from GPU after use + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_name: str, + inference_type: str = None, + has_text_input: bool = True, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + disable_guardrail: bool = False, + parallel_size: int = 1, + ): + """Initialize text-conditioned video generation pipeline. + + Args: + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the checkpoint to load + inference_type: Type of world generation workflow + has_text_input: Whether the pipeline takes text input for world generation + disable_diffusion_decoder: Whether to disable diffusion decoder stage + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + offload_network: Whether to offload AR model from GPU + offload_tokenizer: Whether to offload tokenizer from GPU + disable_guardrail: Whether to disable guardrail + offload_text_encoder_model: Whether to offload text encoder + parallel_size: Number of GPUs for parallelism + """ + super().__init__( + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + inference_type=inference_type, + has_text_input=has_text_input, + disable_diffusion_decoder=disable_diffusion_decoder, + offload_guardrail_models=offload_guardrail_models, + offload_diffusion_decoder=offload_diffusion_decoder, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + disable_guardrail=disable_guardrail, + parallel_size=parallel_size, + ) + self.offload_text_encoder_model = offload_text_encoder_model + if not self.offload_text_encoder_model: + self._load_text_encoder_model() + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Run model generation with memory management. + + Executes generation process and handles model offloading to manage GPU memory. + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + List of prompt embedding tensors + ) + """ + out_videos, indices_tensor, prompt_embedding = self._run_model( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos, indices_tensor, prompt_embedding + + def _run_model( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + """Run core model generation process. + + Handles text-conditioned video generation: + 1. 
Prepares data batch with text embeddings and video + 2. Determines appropriate context length + 3. Generates video tokens with text conditioning + 4. Processes output tensors + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + Text context tensor + ) + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + + data_batch = {} + data_batch["context"], data_batch["context_mask"] = prompt_embedding, prompt_mask + T, H, W = self.latent_shape + + if sampling_config is None: + sampling_config = self.sampling_config + if type(inp_vid) is list: + batch_size = len(inp_vid) + elif type(inp_vid) is torch.Tensor: + batch_size = 1 + data_batch["context"] = data_batch["context"].repeat(batch_size, 1, 1) + data_batch["context_mask"] = data_batch["context_mask"].repeat(batch_size, 1) + data_batch["context_mask"] = torch.ones_like(data_batch["context_mask"]).bool() + + latent_context_t_size = 0 + + # Choosing the context length from list of available contexts + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using context of {context_used} frames") + + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + data_batch["video"] = inp_vid + data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1) + + data_batch = misc.to(data_batch, "cuda") + + log.debug(f" num_tokens_to_generate: {num_gen_tokens}") + log.debug(f" sampling_config: {sampling_config}") + log.debug(f" tokenizer_config: {self.tokenizer_config}") + log.debug(f" latent_shape: {self.latent_shape}") + log.debug(f" latent_context_t_size: {latent_context_t_size}") + log.debug(f" seed: {seed}") + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="text_and_video", + seed=seed, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch, indices_tensor_cur_batch, data_batch["context"] + + def generate( + self, + inp_prompt: str, + inp_vid: torch.Tensor, + num_input_frames: int = 9, + seed: int = 0, + sampling_config: SamplingConfig = None, + ) -> np.ndarray | None: + """Generate a video guided by text prompt and input frames. + + Pipeline steps: + 1. Validates text prompt safety if enabled + 2. Converts text to embeddings + 3. Generates video with text conditioning + 4. Enhances quality via diffusion decoder + 5. 
Applies video safety checks if enabled + + Args: + inp_prompt: Text prompt to guide the generation + inp_vid: Input video tensor with shape (batch_size, time, channels=3, height, width) + num_input_frames: Number of frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + if not self.disable_guardrail: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(inp_prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + log.info("Run text embedding on prompt") + prompt_embeddings, prompt_masks = self._run_text_embedding_on_prompt_with_offload([inp_prompt]) + prompt_embedding = prompt_embeddings[0] + prompt_mask = prompt_masks[0] + log.info("Finish text embedding on prompt") + + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch, prompt_embedding = self._run_model_with_offload( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, [prompt_embedding] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video diff --git a/cosmos_predict1/autoregressive/model.py b/cosmos_predict1/autoregressive/model.py new file mode 100644 index 0000000000000000000000000000000000000000..38179ea5a952600f733cbdc7c7a679ed9a737f8f --- /dev/null +++ b/cosmos_predict1/autoregressive/model.py @@ -0,0 +1,660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
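The end-to-end flow above (prompt guardrail, text embedding, AR generation, diffusion decoding, video guardrail) is driven entirely through ARVideo2WorldGenerationPipeline.generate. A minimal usage sketch follows; the checkpoint directory, checkpoint name, and input resolution are illustrative assumptions, not values taken from this repository:

import torch

# Hypothetical paths/names for illustration only.
pipeline = ARVideo2WorldGenerationPipeline(
    checkpoint_dir="checkpoints",
    checkpoint_name="video2world-ar-5b",
    offload_text_encoder_model=True,
)

# (batch, time, channels, height, width); 9 frames are used as context by default.
input_clip = torch.zeros(1, 9, 3, 640, 1024)

video = pipeline.generate(
    inp_prompt="A robot arm stacks two boxes on a table.",
    inp_vid=input_clip,
    num_input_frames=9,
    seed=0,
)
if video is not None:      # None means a guardrail check rejected the prompt or the output
    print(video.shape)     # (time, height, width, channels) numpy array, per the docstring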
+ +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from safetensors.torch import load_file +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector +from cosmos_predict1.autoregressive.networks.transformer import Transformer +from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config +from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size +from cosmos_predict1.autoregressive.utils.checkpoint import ( + get_partial_state_dict, + obtain_tensor_parallel_state_dict, + process_state_dict, + substrings_to_ignore, +) +from cosmos_predict1.autoregressive.utils.sampling import decode_n_tokens, decode_one_token, prefill +from cosmos_predict1.utils import log, misc + + +def update_model_config(model_config, inference_tensor_parallel_size): + if inference_tensor_parallel_size > 1: + log.warning(f"Setting tensor parallel size to {inference_tensor_parallel_size}") + setattr( + model_config, + "tensor_model_parallel_size", + inference_tensor_parallel_size, + ) + + if "{rank}" in model_config.ckpt_path: + tp_rank = parallel_state.get_tensor_model_parallel_rank() + model_config.ckpt_path = model_config.ckpt_path.format(rank=tp_rank) + return model_config + + +class AutoRegressiveModel(torch.nn.Module): + """ + A class to build and use a AutoRegressiveModel model for text generation. + + Methods: + build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. + generate: Generate text sequences based on provided prompts using the language generation model. + """ + + def __init__( + self, + model: Transformer = None, + tokenizer: DiscreteMultimodalTokenizer = None, + config: ModelConfig = None, + model_parallel: ModelParallelConfig = None, + vision_encoder: VisionTransformer = None, + mm_projector: MultimodalProjector = None, + ): + """ + Initialize the AutoRegressiveModel instance with a model and tokenizer. + + Args: + model (Transformer): The Transformer model for text generation. + tokenizer (Tokenizer): The tokenizer for encoding and decoding text. + config (Config): The configuration for the AutoRegressiveModel model. + model_parallel (ModelParallelConfig): The model parallel configuration for the AutoRegressiveModel model. + vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model. + mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model. + """ + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + + self.vision_encoder = vision_encoder + self.mm_projector = mm_projector + self.model_parallel = model_parallel + + @property + def precision(self): + return self.model.precision + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def load_ar_model( + self, + shard_checkpoint, + tokenizer_config, + ): + """ + Load the AR model. 
+ """ + model_config = self.config + tensor_parallel_size = 1 if self.model_parallel is None else self.model_parallel.tensor_model_parallel_size + assert tensor_parallel_size == model_config["tensor_model_parallel_size"] + ckpt_path = model_config.ckpt_path + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + model_parallel=self.model_parallel, + tokenizer_config=tokenizer_config, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}" + ) + vocab_size = update_vocab_size( + existing_vocab_size=0, + to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size, + training_type=tokenizer_config.training_type, + add_special_tokens=False, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}" + ) + # Perform vocab expansion + if vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer_config.training_type == "text_to_video") + model.expand_vocab( + vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if self.model_parallel is not None: + assert self.model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + self.model = model.to(precision).to("cuda") + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + def load_tokenizer(self, tokenizer_config): + """ + Load the tokenizer. + """ + self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + + @staticmethod + def build( + model_config: ModelConfig = ModelConfig(), + tokenizer_config: TokenizerConfig = None, + model_parallel: ModelParallelConfig = None, + shard_checkpoint: bool = False, + ) -> "AutoRegressiveModel": + """ + Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. 
+ + Args: + model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig(). + tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None. + shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to False. + download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. + Returns: + AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory. + + Note: + This method sets the device to CUDA and loads the pre-trained model and tokenizer. + """ + tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size + assert tensor_parallel_size == model_config["tensor_model_parallel_size"] + + # Initialize model configuration parameters + config_params = {} + + # Load checkpoint and model parameters + + if model_config.ckpt_path is None: + # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir + ckpt_dir = model_config.ckpt_dir + + # We prioritize safetensors version over the pytorch version, since the former is + # much faster for checkpoint loading. + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + if len(checkpoints) == 0: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert ( + len(checkpoints) == 1 + ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" + ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case + + if os.path.exists(Path(ckpt_dir) / "config.json"): + with open(Path(ckpt_dir) / "config.json", "r") as f: + config_params = json.loads(f.read()) + else: + log.info( + f"No params.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config." 
+ ) + + else: + # If ckpt_path is provided, we load the model from the specified path, + # and use the default model configuration + ckpt_path = model_config.ckpt_path + + for key, value in config_params.items(): + if hasattr(model_config, key): + # Override the default model configuration with the parameters from the checkpoint + setattr(model_config, key, value) + + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + + if model_config.vision_encoder is not None: + # Take the LLM weights (starting with "model.") from the VLM checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + if model_config.vision_encoder is not None: + # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` + # and `checkpoint['mm_projector']` are both for those weights + # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights + if "vision_encoder" in checkpoint: + log.debug("Using pretrained vision_encoder") + vit_checkpoint = checkpoint["vision_encoder"] + else: + log.debug("Using fine-tuned vision_encoder") + vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") + vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") + if "mm_projector" in checkpoint: + log.debug("Using pretrained mm_projector") + projector_checkpoint = checkpoint["mm_projector"] + else: + log.debug("Using fine-tuned mm_projector") + projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") + projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") + assert ( + len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 + ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." + + tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + model_parallel=model_parallel, + tokenizer_config=tokenizer_config, + ) + model_kwargs = {} + + if model_config.vision_encoder is not None: + assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." 
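+            # The vision stack is created in two pieces: a VisionTransformer sized by
+            # get_vit_config(model_config.vision_encoder), and a MultimodalProjector that maps
+            # ViT features (vit_config["dim"]) into the LLM embedding width (model_config["dim"]).
+            # Their weights come from the vit_checkpoint / projector_checkpoint extracted above
+            # and are loaded after the LLM state_dict further below.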
+ vit_config = get_vit_config(model_config.vision_encoder) + vit_config["tensor_model_parallel_size"] = tensor_parallel_size + vision_encoder = VisionTransformer.build( + vit_config, + ) + + mm_projector = MultimodalProjector( + mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] + ) + model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) + + # Perform vocab expansion + if tokenizer.vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {tokenizer.vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer.training_type == "text_to_video") + model.expand_vocab( + tokenizer.vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if model_parallel is not None: + assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + if model_config.vision_encoder is not None: + # Shard vision encoder and multimodal projector weights + vit_checkpoint = obtain_tensor_parallel_state_dict( + vit_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=vit_config, + ) + + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + if model_config.vision_encoder is not None: + vision_encoder.load_state_dict(vit_checkpoint) + mm_projector.load_state_dict(projector_checkpoint) + if model_config.vision_encoder_in_channels != 3: + vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) + + model = model.to(precision) # ensure model parameters are in the correct precision + log.debug(f"Model config: {model_config}") + + model_class = AutoRegressiveModel + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return model_class(model, tokenizer, model_config, **model_kwargs) + + @torch.no_grad() + def generate( + self, + prompt_tokens: List[List[int]] | torch.Tensor, + max_gen_len: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + echo: bool = False, + seed: int = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + compile_sampling: bool = True, + compile_prefill: bool = False, + verbose: bool = True, + stop_tokens: Optional[Set[int]] = None, + images: Optional[torch.Tensor] = None, + ): + """ + Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast). 
+ + Args: + prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len). + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. + num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + logit_clipping_range (list, optional): Range of logits to clip. Defaults to []. + seed (int, optional): Random seed for reproducibility. Defaults to None. + compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + verbose (bool, optional): Flag indicating whether to print the the time. Defaults to False. + """ + assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified." + if temperature == 0: + top_p, top_k = None, None + log.debug("Setting top_p and top_k to None because temperature is 0") + if top_p is not None: + log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + orig_precision = torch.get_default_dtype() + torch.set_default_dtype(self.precision) + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + # Experimental features to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + if seed is not None: + misc.set_random_seed(seed) + + assert not logprobs, "logprobs are not supported for fast_generate yet" + # Examine if the function prefil and decode_one_token functions are compiled yet. If not, compile them based on the flags + if compile_sampling and not getattr(self, "inference_decode_compiled", False): + log.info("Compiling AR sampling function. Note: the first run will be slower due to compilation") + self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + self.inference_decode_compiled = True + log.info("Compiled AR sampling function.") + if compile_prefill and not getattr(self, "inference_prefill_compiled", False): + log.info("Compiling prefill function. Note: the first run will be slower due to compilation") + self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + self.inference_prefill_compiled = True + log.info("Compiled prefill function.") + + if not hasattr(self, "decode_one_token"): + self.decode_one_token = decode_one_token + if not hasattr(self, "prefill"): + self.prefill = prefill + + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.debug( + f"Find self.model.params is a list, use self.config instead. 
Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + if isinstance(prompt_tokens, list): + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") + if prompt_tokens.ndim == 1: + prompt_tokens = prompt_tokens.view(1, -1) + else: + assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" + batch_size, prompt_len = prompt_tokens.shape + total_len = min(params.max_seq_len, max_gen_len + prompt_len) + if max_gen_len + prompt_len > params.max_seq_len: + log.warning( + f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" + ) + max_gen_len = params.max_seq_len - prompt_len + + if context_mask is not None: + context_mask = context_mask.to(dtype=torch.bool) + if context_mask.ndim == 2: + assert ( + context_mask.shape[0] == batch_size + ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" + # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] + context_mask = context_mask.view(batch_size, 1, 1, -1) + + if num_gen_seq > 1: + assert ( + batch_size == 1 + ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" + log.debug(f"Generating {num_gen_seq} sequences with the same prompt") + assert ( + num_gen_seq <= params.max_batch_size + ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" + # repeat the prompt tokens for num_gen_seq times + prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) + assert prompt_tokens.shape == ( + num_gen_seq, + prompt_len, + ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" + batch_size = len(prompt_tokens) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) + empty[:, :prompt_len] = prompt_tokens + seq = empty + input_pos = torch.arange(0, prompt_len, device="cuda") + + if verbose: + prefill_start = time.time() + + if images is not None: + images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16) + prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images) + else: + prompt_token_embeddings = None + + if context is not None: + context = context.to(device=prompt_tokens.device, dtype=self.precision) + + # Prefill stage + next_token = self.prefill( + self.model, + input_pos=input_pos, + tokens=prompt_tokens if prompt_token_embeddings is None else None, + token_embeddings=prompt_token_embeddings, + temperature=temperature, + top_k=top_k, + top_p=top_p, + context=context, + context_mask=context_mask, + ) + if verbose: + prefill_time = time.time() - prefill_start + + seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) + input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") + stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens + stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") + + if verbose: + decode_start = time.time() + # Decode stage + generated_tokens = decode_n_tokens( + self.model, + next_token.view(batch_size, -1), + input_pos, + max_gen_len - 1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, + decode_one_token_function=self.decode_one_token, + context=context, + context_mask=context_mask, + ) + gen_len = 
len(generated_tokens) + if verbose: + decode_time = time.time() - decode_start + prefill_throughput = prompt_len / prefill_time + decode_throughput = gen_len / decode_time + log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") + log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") + + generated_tokens = torch.cat(generated_tokens, dim=1) + + log.debug(f"generated_tokens: {generated_tokens.shape}") + seq = seq[:, : prompt_len + 1 + gen_len] + seq[:, prompt_len + 1 :] = generated_tokens + if not echo: + seq = seq[:, prompt_len:] + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return seq, None + + def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: + """ + Embed vision and language features into a combined representation. + + Args: + input_ids (torch.Tensor): Input token IDs. + images (torch.tensor): Input images. + + Returns: + torch.Tensor: Combined vision-language features. + + Raises: + AssertionError: If vision encoder or mm projector is not initialized, + or if dimensions mismatch. + """ + # Ensure vision encoder and mm projector are initialized + assert self.vision_encoder is not None + assert self.mm_projector is not None + + # Get image token ID and validate it + image_token_id = self.vision_encoder.image_token_id + assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" + + # Identify text and image locations in the input + text_locations = input_ids != image_token_id + image_locations = input_ids == image_token_id + + # Process text features + text_features = self.model.tok_embeddings(input_ids[text_locations]) + + # Process image features + images = images.to(device=text_features.device, dtype=text_features.dtype) + vit_outputs = self.vision_encoder(images) + image_features = self.mm_projector(vit_outputs) + + # Get dimensions + B, seq_len = input_ids.shape + N_total = B * seq_len + N_txt, D_txt = text_features.shape + N_img, N_patch, D_img = image_features.shape + + # Reshape image features + image_features = image_features.reshape(N_img * N_patch, D_img) + + # Validate dimensions + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + N_total == N_txt + N_img * N_patch + ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" + + # Combine text and image features + combined_features = torch.empty( + (B, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + combined_features[image_locations, :] = image_features + + return combined_features + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). 
+ """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if strict: + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + return _IncompatibleKeys(actual_missing_keys, unexpected_keys) diff --git a/cosmos_predict1/autoregressive/modules/__init__.py b/cosmos_predict1/autoregressive/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/modules/attention.py b/cosmos_predict1/autoregressive/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea6af4c39530aee882a58f5f956d1206a069562 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/attention.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import torch +from megatron.core import parallel_state +from torch import nn +from torch.distributed._functional_collectives import all_reduce + +from cosmos_predict1.autoregressive.modules.embedding import RotaryPositionEmbedding +from cosmos_predict1.autoregressive.modules.normalization import create_norm + + +class Attention(nn.Module): + """ + Attenion layer with KV cache. + """ + + def __init__( + self, + n_heads: int, + n_kv_heads: Union[int, None], + dim: int, + max_batch_size: int, + max_seq_len: int, + context_dim: Optional[int] = None, + use_qk_normalization: bool = False, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + causal_mask: Optional[bool] = True, + head_dim: Optional[int] = None, + fuse_qkv: bool = False, + precision: str = "bfloat16", + tensor_parallel_size: int = 1, + attn_type: str = "self", + ): + """ + Initializes the GQA module. + + Args: + n_heads (int): The number of attention heads. 
+ n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. + dim (int): The dimensionality of the input and output. + max_batch_size (int): The maximum batch size. + max_seq_len (int): The maximum sequence length. + context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. + use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False. + norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm". + norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5. + tp_group (int, optional): The tensor parallel group. + causal_mask (bool, optional): Whether to use causal mask. Defaults to True. + head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads. + fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False. + precision (str, optional): The precision of the module. Defaults to "bfloat16". + tensor_parallel_size (int, optional): The tensor parallel size. Defaults to 1. + attn_type (str, optional): The type of attention. Defaults to "self". + """ + super().__init__() + assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}" + self.attn_type = attn_type + self.tp_size = tensor_parallel_size + context_dim = dim if context_dim is None else context_dim + + self.dim = dim + self.context_dim = context_dim + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_kv_heads = self.n_kv_heads // self.tp_size + self.n_local_heads = n_heads // self.tp_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads if head_dim is None else head_dim + self.causal_mask = causal_mask + self.fuse_qkv = fuse_qkv + self.precision = precision + + if fuse_qkv: + assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" + self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim + self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) + # Register hook to load fused QKV weights + self._register_load_state_dict_pre_hook(self.load_hook) + else: + self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) + + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + + if self.attn_type == "self": + # Cache for key and value tensors + self.init_kv_cache() + + # QK normalization layers + if use_qk_normalization: + assert n_heads % self.tp_size == 0, "n_heads must be divisible by tensor_model_parallel_size" + assert self.n_kv_heads % self.tp_size == 0, "n_kv_heads must be divisible by tensor_model_parallel_size" + self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + + self.use_qk_normalization = use_qk_normalization + + self.to(dtype=getattr(torch, self.precision)) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def init_kv_cache(self, 
dtype=None): + cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) + if dtype is None: + dtype = getattr(torch, self.precision) + if self.attn_type == "self": + self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() + self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbedding, + input_pos: torch.Tensor, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ): + """ + Forward pass of GQA. + + Args: + x: The input tensor of shape (batch_size, seq_len, dim). + rope: The rotary positional embedding module. + input_pos: The starting position of the current sequence. + mask: The attention mask tensor. + context: The context tensor of shape (batch_size, context_len, dim). + + Returns: + The output tensor after applying GQA. + """ + bsz, seqlen, _ = x.shape + + # Use one single module to handle both self-attn and cross-attn + context = x if context is None else context + context_len = seqlen if context is None else context.shape[1] + + if self.fuse_qkv: + q_size = self.n_local_heads * self.head_dim + kv_size = self.n_local_kv_heads * self.head_dim + xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + else: + # Compute query, key, and value projections + xq, xk, xv = self.wq(x), self.wk(context), self.wv(context) + + # Reshape projections + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + + # QK normalization + if self.use_qk_normalization: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Apply rotary positional embeddings to queries and keys + # Only apply RoPE to self-attention! + if self.attn_type in ["self", "full"]: + xq, xk = rope(xq, xk, input_pos, seqlen) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + # xq: (bs, n_local_heads, seqlen, head_dim) + # xk: (bs, n_kv_heads, cache_len + context_len, head_dim) + # xv: (bs, n_kv_heads, cache_len + context_len, head_dim) + if self.attn_type == "self": + # Update cache with current key and value tensors + assert input_pos is not None + self.cache_k[:bsz, :, input_pos] = xk + self.cache_v[:bsz, :, input_pos] = xv + keys, values = ( + self.cache_k[:bsz, :, :], + self.cache_v[:bsz, :, :], + ) + else: + keys, values = xk, xv + + # Repeat keys and values if necessary + keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + + # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used, + # since the masking is handled outside this attention module. 
+ # For cross-attention, it's always full-attn without causal mask + is_causal = False + output = scaled_dot_product_attention( + xq, + keys, + values, + head_dim=self.head_dim, + mask=mask, + is_causal=is_causal, + dropout_p=0.0, + ) + output = output.view(bsz, seqlen, -1) + output = self.wo(output) + if self.tp_size > 1: + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_dim: int, + mask: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + dropout_p: float = 0.0, +) -> torch.Tensor: + """ + PyTorch's native implementation of Flash Attention 2. + + If `is_causal` is given, then the causal attention mask is applied accordingly: + - If `is_causal` is True, the standard upper-left causal attention masking is applied. + - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is + provided (i.e., `mask is not None`). + + If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied + based on the provided mask tensor: + - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, + leading to the standard upper-left causal attention masking. + - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, + and `is_causal` is set to False. + + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + head_dim (int): Dimension of each attention head + mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. + is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. + dropout_p (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + torch.Tensor: Output tensor after applying scaled dot-product attention + """ + scale = 1.0 / math.sqrt(head_dim) + if is_causal is None: + is_causal = mask is None + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + ) + return y.transpose(1, 2).contiguous() diff --git a/cosmos_predict1/autoregressive/modules/embedding.py b/cosmos_predict1/autoregressive/modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b167b331ccf4b79edd1d95fecd20bb161c5115 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/embedding.py @@ -0,0 +1,649 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
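The masking policy documented for scaled_dot_product_attention above can be restated as a short, self-contained sketch; resolve_is_causal is a hypothetical helper used only for illustration, not part of the module:

import torch

def resolve_is_causal(mask, is_causal):
    # Explicit is_causal takes precedence; otherwise attention is causal exactly when no mask is given.
    return (mask is None) if is_causal is None else is_causal

assert resolve_is_causal(mask=None, is_causal=None) is True                     # default: causal self-attention
assert resolve_is_causal(mask=torch.ones(1, 1, 4, 4), is_causal=None) is False  # explicit mask is used instead
assert resolve_is_causal(mask=None, is_causal=False) is False                   # e.g. cross-attn or cached decode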
+ +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +from einops import rearrange, repeat +from megatron.core import parallel_state + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def _rotate_half_te(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even]. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb_te( + t: torch.Tensor, + cos_freqs: torch.Tensor, + sin_freqs: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[b, s, h, d]`, on which + rotary positional embedding will be applied. + cos_freqs: torch.Tensor + Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + sin_freqs: torch.Tensor + Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + """ + rot_dim = cos_freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs) + output = torch.cat((t, t_pass), dim=-1) + return output + + +def get_pos_emb_on_this_cp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Get the position embedding for the current context parallel rank. + + Args: + pos_emb (torch.Tensor): The position embedding tensor. + seq_dim (int): The sequence dimension to slice. + + Returns: + torch.Tensor: The position embedding tensor for the current rank. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(non_blocking=True) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb + + +def get_pos_emb_on_this_sptp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Get the position embedding for the current tensor parallel rank (only used when sequence parallel is turned on) + + Args: + pos_emb (torch.Tensor): The position embedding tensor. + seq_dim (int): The sequence dimension to slice. + + Returns: + torch.Tensor: The position embedding tensor for the current rank. 
+ """ + tp_size = parallel_state.get_tensor_model_parallel_world_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pos_emb_chunks = torch.chunk(pos_emb, tp_size, dim=seq_dim) + pos_emb = pos_emb_chunks[tp_rank] + return pos_emb + + +class RotaryPositionEmbedding(torch.nn.Module): + """ + Rotary Position Embedding module as described in the paper: + https://arxiv.org/abs/2104.09864 + + This module implements rotary positional embeddings, which are used to + enhance the performance of transformer models. + + Args: + dim (int): Dimensionality of the input tensor. + max_position_embeddings (Optional[int]): Maximum position embeddings. + original_max_position_embeddings (Optional[int]): Original maximum position embeddings. + rope_theta (Optional[float]): Base for the frequency calculation. + apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another Rotary). + scale (Optional[int]): Scaling factor for the frequency calculation. + extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension. + attn_factor (Optional[int]): Attention factor for the frequency calculation. + beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation. + beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D". + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: Optional[int] = None, + original_max_position_embeddings: Optional[int] = None, + rope_theta: Optional[float] = 10000.0, + apply_yarn: Optional[bool] = False, + scale: Optional[int] = None, + extrapolation_factor: Optional[int] = 1, + attn_factor: Optional[int] = 1, + beta_fast: Optional[int] = 32, + beta_slow: Optional[int] = 1, + rope_dim: Optional[str] = "1D", + latent_shape: Optional[List[int]] = None, + original_latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.rope_theta = rope_theta + self.apply_yarn = apply_yarn + self.scale = scale + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = 1.0 + self.rope_dim = rope_dim + self.latent_shape = latent_shape + self.original_latent_shape = original_latent_shape + self.pad_to_multiple_of = pad_to_multiple_of + self.get_inv_freq(torch.cuda.current_device()) + + def get_mscale(self, scale: float = 1.0) -> float: + """Get the magnitude scaling factor for YaRN.""" + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + def forward(self, seq_len: Optional[int] = None) -> torch.Tensor: + """ + Forward pass for the rotary position embedding. + + Args: + seq_len (Optional[int]): Length of the sequence. + + Returns: + torch.Tensor: The computed frequencies for positional embedding. 
+ """ + + if self.apply_yarn and seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.freqs = self.compute_freqs() + + return self.freqs + + def compute_freqs( + self, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the spatial frequencies for the latent tensor.""" + self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda() + if self.rope_dim == "1D": + emb = torch.einsum("i,j->ij", self.seq, self.inv_freq) + + elif self.rope_dim == "2D": + H, W = self.latent_shape + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_h, "h d -> h w d", w=W), + repeat(half_emb_w, "w d -> h w d", h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "h w d -> (h w) 1 1 d").float() + + elif self.rope_dim == "3D": + T, H, W = self.latent_shape + half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq) + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + return emb + + def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor: + """Get the scale factors for YaRN.""" + # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called + # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code. + high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len + low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len + # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear + # interpolation in between. + smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1) + # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency. + scale_factors = (1 - smooth_mask) / self.scale + smooth_mask + return scale_factors + + def get_inv_freq(self, device: torch.device) -> None: + """Get the inverse frequency.""" + if self.rope_dim == "1D": + assert self.max_position_embeddings is not None, "Max position embeddings required." + inv_freq = 1.0 / ( + self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + if self.apply_yarn: + assert self.original_max_position_embeddings is not None, "Original max position embeddings required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + + scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings) + # Apply the scaling factors to inv_freq. + inv_freq = inv_freq * scale_factors + # Set the magnitude scaling factor. + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.max_seq_len_cached = self.max_position_embeddings + self.inv_freq = inv_freq + + elif self.rope_dim == "2D": + assert self.latent_shape is not None, "Latent shape required." 
+            dim_h = self.dim // 2
+            spatial_inv_freq = 1.0 / (
+                self.rope_theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h)
+            )
+            if self.apply_yarn:
+                assert self.original_latent_shape is not None, "Original latent shape required."
+                assert self.beta_slow is not None, "Beta slow value required."
+                assert self.beta_fast is not None, "Beta fast value required."
+
+                scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0])
+                spatial_inv_freq = spatial_inv_freq * scale_factors
+                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
+            self.spatial_inv_freq = spatial_inv_freq
+            self.max_seq_len_cached = max(self.latent_shape)
+
+        elif self.rope_dim == "3D":
+            assert self.latent_shape is not None, "Latent shape required."
+            dim_h = self.dim // 6 * 2
+            dim_t = self.dim - 2 * dim_h
+            self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h
+            spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
+            self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t
+            temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
+            if self.apply_yarn:
+                assert self.original_latent_shape is not None, "Original latent shape required."
+                assert self.beta_slow is not None, "Beta slow value required."
+                assert self.beta_fast is not None, "Beta fast value required."
+                scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
+                spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
+                scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
+                temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
+                self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
+            self.spatial_inv_freq = spatial_inv_freq
+            self.temporal_inv_freq = temporal_inv_freq
+            self.max_seq_len_cached = max(self.latent_shape)
+        else:
+            raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
+
+        self.freqs = self.compute_freqs()
+
+
+class RotaryPositionEmbeddingTE(RotaryPositionEmbedding):
+    """
+    Rotary Position Embedding with context parallelism support.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+
+    def forward(self, seq_len: int, training_type: Optional[str] = None) -> torch.Tensor:
+        """
+        Create rotary position embedding frequencies.
+
+        Args:
+            seq_len (int): Sequence length of a sample.
+            training_type (Optional[str]): Training task type; "text_to_video" prepends an extra
+                position for the beginning-of-video (BOV) token.
+
+        Returns:
+            torch.Tensor: The computed positional embeddings.
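+
+        Example:
+            An illustrative sketch for a 3D (video) RoPE on a CUDA device; the latent
+            shape below is an arbitrary example value:
+
+            >>> rope = RotaryPositionEmbeddingTE(dim=128, rope_dim="3D", latent_shape=[4, 8, 8])
+            >>> emb = rope(seq_len=4 * 8 * 8)   # frequency tensor of shape [256, 1, 1, 128]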
+ """ + if self.rope_dim == "1D": + freqs = super().forward(seq_len=seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + elif self.rope_dim in ["2D", "3D"]: + emb = super().forward(seq_len=seq_len) + if training_type == "text_to_video": + # since we added token at the beginning of the video for text2video, we also extend the position embedding by one token in the beginning + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device) + emb = torch.cat((bov_pe, emb), dim=0) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0) + + return emb + + +class RotaryPositionEmbeddingPytorch(RotaryPositionEmbedding): + """ + Rotary Position Embedding with PyTorch specific adjustments. + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + if self.rope_dim == "1D": + emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(self.freqs, "s 1 1 d -> s d").float() + self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input tensor.""" + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) + x1 = x_reshaped[..., 0] + x2 = x_reshaped[..., 1] + output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape) + return output + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the rotary position embedding. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + input_pos (Optional[torch.Tensor]): Starting position for the sequence. + seq_len (Optional[int]): Length of the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + if self.apply_yarn and seq_len > self.max_seq_len_cached: + freqs = super().forward(seq_len) + if self.rope_dim == "1D": + emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(freqs, "s 1 1 d -> s d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + + if input_pos is not None: + cos_cached = self.cos_cached[:, input_pos] + sin_cached = self.sin_cached[:, input_pos] + else: + assert ( + self.cos_cached.shape[1] >= seq_len + ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}." + cos_cached = self.cos_cached[:, :seq_len, ...] + sin_cached = self.sin_cached[:, :seq_len, ...] 
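+        # Standard RoPE rotation applied to queries and keys: treat consecutive channel
+        # pairs as complex numbers and rotate them by the cached angles,
+        # i.e. q_rot = q * cos + rotate_half(q) * sin.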
+ xq = q * cos_cached + self.rotate_half(q) * sin_cached + xk = k * cos_cached + self.rotate_half(k) * sin_cached + + return xq.type_as(q), xk.type_as(k) + + +class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as the TransformerEngine RoPE + (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) + + """ + + def __init__( + self, + seq_len: int, + training_type: str = None, + **kwargs, + ): + super().__init__( + **kwargs, + ) + emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type) + emb = emb.transpose(0, 1).contiguous() # [seq, 1, 1, dim] -> [1, seq, 1, dim] + assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}" + # cos/sin first then dtype conversion for better precision + self.register_buffer("cos_cached", torch.cos(emb), persistent=False) + self.register_buffer("sin_cached", torch.sin(emb), persistent=False) + + def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor: + """ + Create rotary position embedding frequencies. + + Args: + seq_len (int): Sequence length of a sample. + + Returns: + torch.Tensor: The computed positional embeddings. + """ + if self.rope_dim == "1D": + freqs = super().forward(seq_len=seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + elif self.rope_dim in ["2D", "3D"]: + emb = super().forward(seq_len=seq_len) + if training_type == "text_to_video": + # since we added token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device) + emb = torch.cat((bov_pe, emb), dim=0) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0) + + return emb + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if q.dtype != self.cos_cached.dtype: + self.cos_cached = self.cos_cached.to(q.dtype) + self.sin_cached = self.sin_cached.to(q.dtype) + + cos_emb = self.cos_cached + sin_emb = self.sin_cached + if input_pos is not None: + cos_emb = cos_emb[:, input_pos, :, :] + sin_emb = sin_emb[:, input_pos, :, :] + elif seq_len is not None: + cos_emb = cos_emb[:, :seq_len, :, :] + sin_emb = sin_emb[:, :seq_len, :, :] + q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb) + k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb) + return q, k + + +class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as + mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py) + or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py) + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + if self.rope_dim == "1D": + emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(self.freqs, "s 1 1 d -> s d").float() + self.register_buffer("cos_cached", (emb.cos() * 
self.mscale)[None, :, None, :], persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input tensor.""" + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) + x1 = x_reshaped[..., 0] + x2 = x_reshaped[..., 1] + output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape) + return output + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the rotary position embedding. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + input_pos (Optional[torch.Tensor]): Starting position for the sequence. + seq_len (Optional[int]): Length of the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + if self.apply_yarn and seq_len > self.max_seq_len_cached: + freqs = super().forward(seq_len) + if self.rope_dim == "1D": + emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(freqs, "s 1 1 d -> s d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + + if input_pos is not None: + cos_cached = self.cos_cached[:, input_pos] + sin_cached = self.sin_cached[:, input_pos] + else: + assert ( + self.cos_cached.shape[1] >= seq_len + ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}." + cos_cached = self.cos_cached[:, :seq_len, ...] + sin_cached = self.sin_cached[:, :seq_len, ...] + xq = q * cos_cached + self.rotate_half(q) * sin_cached + xk = k * cos_cached + self.rotate_half(k) * sin_cached + + return xq.type_as(q), xk.type_as(k) + + +class SinCosPosEmbAxisTE(torch.nn.Module): + def __init__( + self, + dim: int, + latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, + device="cuda", + **kwargs, + ): + """ + Args: + dim (int): Dimensionality of the input tensor. + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + dtype (torch.dtype): Data type of the position embedding tensor. 
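+            device (str): Device on which the position embedding buffers are created.
+
+        Example:
+            An illustrative sketch (the latent shape is an arbitrary example value; the
+            default ``device="cuda"`` assumes a GPU is available):
+
+            >>> pos_emb = SinCosPosEmbAxisTE(dim=256, latent_shape=[4, 8, 8])
+            >>> emb = pos_emb.forward()   # shape: [1, 4 * 8 * 8, 256]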
+ """ + super().__init__() + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.latent_shape = latent_shape + T, H, W = latent_shape + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H)) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W)) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T)) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device=device), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device=device), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device=device), persistent=False) + self.pad_to_multiple_of = pad_to_multiple_of + + def forward( + self, + training_type: str | None = None, + ) -> torch.Tensor: + T, H, W = self.latent_shape + emb = torch.cat( + [ + repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W), + repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W), + repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H), + ], + dim=-1, + ) + # Flatten the T,H,W dimensions + emb = rearrange(emb, "t h w d -> (t h w) d") + + if training_type == "text_to_video": + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype) + emb = torch.cat((bov_pe, emb), dim=0) + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0) + seq_len, dim = emb.shape + emb = emb.reshape(1, seq_len, dim) + return emb diff --git a/cosmos_predict1/autoregressive/modules/linear.py b/cosmos_predict1/autoregressive/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..cce025a8a3c67037865791202f4ad05ec16673c5 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/linear.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel import ColumnParallelLinear as McoreColumnParallelLinear +from megatron.core.tensor_parallel import RowParallelLinear as McoreRowParallelLinear +from megatron.core.tensor_parallel import VocabParallelEmbedding as McoreVocabParallelEmbedding +from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from megatron.core.tensor_parallel.utils import VocabUtility +from torch.distributed import _functional_collectives as funcol +from torch.distributed._functional_collectives import all_reduce + + +class VocabParallelEmbedding(torch.nn.Module): + """ + Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + + Args: + num_embeddings (int): vocabulary size. + embedding_dim (int): size of hidden state. + precision (str): precision of the embedding. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + precision: str = "bfloat16", + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + (self.vocab_start_index, self.vocab_end_index) = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, + parallel_state.get_tensor_model_parallel_rank(), + self.tensor_model_parallel_size, + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + self.weight = torch.nn.Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=getattr(torch, precision), + ) + ) + + def forward(self, input_): + """Forward. + + Args: + input_ (torch.Tensor): Input tensor. + """ + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output = self.weight[masked_input] + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output[input_mask, :] = 0.0 + + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +class ColumnParallelLinear(McoreColumnParallelLinear): + """ + A modified version of Mcore's ColumnParallelLinear that only returns the output tensor. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input_: torch.Tensor): + """ + Performs the forward pass of the column parallel linear layer. + + Args: + input_ (torch.Tensor): The input tensor. + weight (Optional[torch.Tensor], optional): The weight tensor. If None, uses the layer's own weight. + + Returns: + torch.Tensor: The output tensor after the linear transformation. + """ + output, _ = super().forward(input_) + return output + + +class RowParallelLinear(McoreRowParallelLinear): + """ + A modified version of Mcore's RowParallelLinear that only returns the output tensor. 
+ + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input_: torch.Tensor): + """ + Performs the forward pass of the Row Parallel linear layer. + + Args: + input_ (torch.Tensor): The input tensor. + weight (Optional[torch.Tensor], optional): The weight tensor. If None, uses the layer's own weight. + + Returns: + torch.Tensor: The output tensor after the linear transformation. + """ + output, _ = super().forward(input_) + return output + + +class TrainingVocabParallelEmbedding(McoreVocabParallelEmbedding): + """ + Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + + Args: + num_embeddings (int): vocabulary size. + embedding_dim (int): size of hidden state. + + Keyword Args: + sequence_parallel (bool): Decides whether to perform ReduceScatter after embedding lookup + batch_first (bool): If True, then output tensor shape is [batch, seq, feature]. If False, then shape becomes + [seq, batch, feature]. Note: We assume the input tensor is always in the shape of [seq, batch]. + config: A megatron.core.ModelParallelConfig object + use_inference_allreduce (bool): If True, then Megatron's allreduce in the forward pass is disabled, and the pytorch's + allreduce is used instead (inference mode only). + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method: Callable, + sequence_parallel: bool = False, + batch_first: bool = False, + config: ModelParallelConfig, + use_inference_allreduce: bool = False, + ): + super(TrainingVocabParallelEmbedding, self).__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + init_method=init_method, + config=config, + ) + self.sequence_parallel = sequence_parallel + if sequence_parallel: + # If sequence parallel, then the output tensor should be in the shape of [seq, batch, feature] + batch_first = False + self.batch_first = batch_first + self.use_inference_allreduce = use_inference_allreduce + + def forward(self, input_): + """Forward. + + Args: + input_ (torch.Tensor): Input tensor. + """ + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output = self.weight[masked_input] + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output[input_mask, :] = 0.0 + + if self.sequence_parallel: + assert not self.batch_first + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + output = output.transpose(0, 1).contiguous() + if not self.use_inference_allreduce: + output = reduce_scatter_to_sequence_parallel_region(output) + else: + # Reduce across all the model parallel GPUs. 
+ if not self.use_inference_allreduce: + output = reduce_from_tensor_model_parallel_region(output) + if not self.batch_first: + # Shape: [b, s, h] --> [s, b, h] + output = output.transpose(0, 1).contiguous() + + if self.use_inference_allreduce: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output diff --git a/cosmos_predict1/autoregressive/modules/mlp.py b/cosmos_predict1/autoregressive/modules/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..61ef18a8049d65760f3c2fdaaa8a21845d706cfa --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/mlp.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core import ModelParallelConfig, parallel_state +from torch.distributed import _functional_collectives as funcol +from torch.distributed._functional_collectives import all_reduce + +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, RowParallelLinear + + +def compute_llama3_ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: float) -> int: + """ + Computes the feedforward network dimensionality. + + Args: + dim (int): The embedding dimensionality. + multiple_of (int): The multiple to round up the hidden dimensionality. + ffn_dim_multiplier (float): The multiplier for the hidden dimensionality. + + Returns: + The feedforward network dimensionality. + """ + hidden_dim = 4 * dim + hidden_dim = int(2 * hidden_dim / 3) # custom dim factor + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + # Round up hidden dimensionality to the nearest multiple + return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + tensor_parallel_size: int = 1, + ): + """ + Initializes the multilayer perceptron (MLP) module. + + Args: + dim: The input and output dimensionality. + hidden_dim: The dimensionality of the hidden layer. + """ + super().__init__() + self.tp_size = tensor_parallel_size + self.w1 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + self.w2 = nn.Linear(hidden_dim // self.tp_size, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the MLP module. + + Args: + x: The input tensor of shape (batch_size, dim). + + Returns: + The output tensor of shape (batch_size, dim). 
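+
+        Example:
+            An illustrative single-GPU sketch (``tensor_parallel_size`` left at its default
+            of 1, so no all-reduce is issued):
+
+            >>> hidden = compute_llama3_ffn_hidden_dim(dim=512, multiple_of=256, ffn_dim_multiplier=1.3)
+            >>> mlp = MLP(dim=512, hidden_dim=hidden)
+            >>> out = mlp(torch.randn(2, 512))   # shape: [2, 512]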
+ """ + output = self.w2(F.silu(self.w1(x)) * self.w3(x)) + if self.tp_size > 1: + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +class TrainingMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + hidden_dropout: float = 0.0, + set_parallel_mode: bool = False, + model_parallel: Optional[ModelParallelConfig] = None, + inference: bool = False, + ): + """ + Initializes the multilayer perceptron (MLP) module. + + Args: + dim: The input and output dimensionality. + hidden_dim: The dimensionality of the hidden layer. + hidden_dropout: Dropout after the attention and feed-forward layers (following TransformerEngine's + implementation in its TransformerLayer class). + set_parallel_mode: Whether to use column and row parallel linear layers. + model_parallel: The model parallel configuration. + inference: Whether the model is used for inference. + """ + super().__init__() + self.hidden_dropout = hidden_dropout + if model_parallel and model_parallel.tensor_model_parallel_size > 1: + self.tp_size = model_parallel.tensor_model_parallel_size + else: + self.tp_size = 1 + if set_parallel_mode and not inference: + kwargs = {"bias": False, "init_method": lambda x: x, "config": model_parallel} + # Using column and row parallel linear layers + self.w1 = ColumnParallelLinear(dim, hidden_dim, gather_output=False, **kwargs) + self.w2 = RowParallelLinear(hidden_dim, dim, input_is_parallel=True, skip_bias_add=True, **kwargs) + self.w3 = ColumnParallelLinear(dim, hidden_dim, gather_output=False, **kwargs) + else: + self.w1 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + self.w2 = nn.Linear(hidden_dim // self.tp_size, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + + self.inference = inference + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the MLP module. + + Args: + x: The input tensor of shape (batch_size, dim). + + Returns: + The output tensor of shape (batch_size, dim). + """ + x = F.dropout(x, p=self.hidden_dropout, training=self.training) + output = self.w2(F.silu(self.w1(x)) * self.w3(x)) + output = F.dropout(output, p=self.hidden_dropout, training=self.training) + + if self.inference and self.tp_size > 1: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + def init_weights(self, init_std: float): + """ + Initializes the weights of the MLP module. + """ + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) diff --git a/cosmos_predict1/autoregressive/modules/mm_projector.py b/cosmos_predict1/autoregressive/modules/mm_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..ee54c961498ff108a92fe621e9322649f7ad891b --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/mm_projector.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multimodal projector to connect vision encoder / tokenizer with the LLM.""" + +from typing import Any, Optional + +import torch +import torch.nn as nn + + +class DownSampleBlock(nn.Module): + """Downsample block.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + """ + Performs the forward pass of the downsample block. + + Args: + x (torch.Tensor): The input tensor from ViT's output of a sequence of embeddings. + Shape: (b, seq_len, c). + + Returns: + torch.Tensor: The output tensor. Shape: (b, seq_len/4, c*4). + """ + vit_embeds = x + # Get h and w as the sqrt of seq length. This assumes that the input is square-shaped. + h = w = int(vit_embeds.shape[1] ** 0.5) + b = vit_embeds.shape[0] + vit_embeds = vit_embeds.reshape(b, h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(b, -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs spatial downsampling while increasing the number of channels. + + Args: + x (torch.Tensor): The input tensor reshaped to a 2D grid. + Shape: (b, h, w, c) + + Returns: + torch.Tensor: The output tensor after the spatial downsampling. + Shape: (b, h/2, w/2, c*4) + """ + b, h, w, c = x.size() + # If w or h is odd, pad a column or a row of zeros. + if h % 2 == 1: + x = torch.concat([x, torch.zeros((b, 1, w, c), dtype=x.dtype).to(x.device)], dim=1).contiguous() + b, h, w, c = x.size() + if w % 2 == 1: + x = torch.concat([x, torch.zeros((b, h, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous() + b, h, w, c = x.size() + # 2x spatial downsampling, 4x channel increasing. 
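+        # The two view/permute passes below fold each 2x2 spatial neighborhood into the
+        # channel dimension (c -> 4c); as noted in the caller, the grid is assumed square.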
+ x = x.view(b, h, int(w / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(b, int(h / 2), int(w / 2), int(c * 4)) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + +class MultimodalProjector(nn.Module): + """Multimodal projector.""" + + def __init__( + self, + mm_projector_type: str, + in_dim: int, + out_dim: Optional[int] = None, + **kwargs: Any, + ): + super().__init__() + if out_dim is None: + out_dim = in_dim + if mm_projector_type == "identity": + self.projector = nn.Identity() + elif mm_projector_type == "linear": + self.projector = nn.Linear(in_dim, out_dim) + elif mm_projector_type == "mlp": + self.projector = nn.Sequential(nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)) + elif mm_projector_type == "mlp_downsample": + self.projector = nn.Sequential( + DownSampleBlock(), + nn.LayerNorm(in_dim * 4), + nn.Linear(in_dim * 4, out_dim), + nn.GELU(), + nn.Linear(out_dim, out_dim), + ) + else: + raise ValueError(f"Unknown projector type: {mm_projector_type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.projector(x) diff --git a/cosmos_predict1/autoregressive/modules/normalization.py b/cosmos_predict1/autoregressive/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..37af6f2f63ae7aa1bbc37a5c815226ceebf4ccbb --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/normalization.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +def create_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Creates the specified normalization layer based on the norm_type. + Adopted from TorchTriton: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py + + Args: + norm_type (str): The type of normalization layer to create. + Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The created normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps, compile=False) + elif norm_type == "compiled_rmsnorm": + return RMSNorm(dim, eps=eps, compile=True) + elif norm_type == "fused_rmsnorm": + raise NotImplementedError("Fused RMSNorm is not supported yet.") + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. 
+ Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + compile (bool, optional): Whether to compile the forward function. Default is False. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm + + @staticmethod + def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float): + def _norm(x, eps): + # Computes the root-mean-square norm of the input tensor. + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + output = _norm(x.float(), eps).type_as(x) + return output * weight + + def forward(self, x: torch.Tensor): + return self.rmsnorm_fn(x, self.weight, self.eps) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) diff --git a/cosmos_predict1/autoregressive/networks/transformer.py b/cosmos_predict1/autoregressive/networks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f573636e2ce27f4f337fd02c0c60de208808664 --- /dev/null +++ b/cosmos_predict1/autoregressive/networks/transformer.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from megatron.core import parallel_state +from torch.distributed import broadcast, get_process_group_ranks +from torch.distributed._functional_collectives import all_gather_tensor +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.modules.attention import Attention +from cosmos_predict1.autoregressive.modules.embedding import ( + RotaryPositionEmbeddingPytorchV1, + RotaryPositionEmbeddingPytorchV2, + SinCosPosEmbAxisTE, +) +from cosmos_predict1.autoregressive.modules.linear import VocabParallelEmbedding +from cosmos_predict1.autoregressive.modules.mlp import MLP +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore +from cosmos_predict1.autoregressive.utils.misc import maybe_convert_to_namespace +from cosmos_predict1.utils import log + + +class TransformerBlock(nn.Module): + """ + A single transformer block consisting of an attention layer and a feed-forward layer. + """ + + def __init__(self, layer_id: int, args=None): + """ + Initializes the TransformerBlock module. 
+
+        Args:
+            layer_id: The ID of the transformer block.
+            args: The model arguments containing hyperparameters.
+        """
+        super().__init__()
+        args = maybe_convert_to_namespace(args)
+        attention_args = {
+            "n_heads": args["n_heads"],
+            "n_kv_heads": args["n_kv_heads"],
+            "dim": args["dim"],
+            "context_dim": None,
+            "max_batch_size": args["max_batch_size"],
+            "max_seq_len": args["max_seq_len"],
+            "use_qk_normalization": args["use_qk_normalization"],
+            "causal_mask": args["causal_mask"],
+            "head_dim": args["head_dim"],
+            "fuse_qkv": getattr(args, "fuse_qkv", False),
+            "precision": getattr(args, "precision", "bfloat16"),
+            "tensor_parallel_size": args["tensor_model_parallel_size"],
+            "attn_type": getattr(args, "attn_type", "self"),
+        }
+        self.attention = Attention(**attention_args)
+
+        self.has_cross_attention = False
+        self.cross_attention, self.cross_attention_norm = None, None
+
+        if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0:
+            self.has_cross_attention = True
+            cross_attention_args = attention_args.copy()
+            cross_attention_args.update({"context_dim": args["context_dim"], "fuse_qkv": False, "attn_type": "cross"})
+            self.cross_attention = Attention(**cross_attention_args)
+            self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
+
+        self.feed_forward = MLP(
+            dim=args["dim"],
+            hidden_dim=args["ffn_hidden_size"],
+            tensor_parallel_size=args["tensor_model_parallel_size"],
+        )
+        self.layer_id = layer_id
+        self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
+        self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        rope: RotaryPositionEmbeddingPytorchV2,
+        input_pos: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        context_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Performs the forward pass of the TransformerBlock module.
+
+        Args:
+            x: The input tensor.
+            rope: The rotary position embedding module used inside the attention layers.
+            input_pos: The position of the current sequence. Used in inference (with KV cache) only.
+            mask: The attention mask tensor.
+            context (Optional[torch.Tensor]): The context tensor added via cross-attn.
+            context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
+
+        Returns:
+            The output tensor after applying the transformer block.
+        """
+        # Apply attention and residual connection
+        h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask)
+
+        # If cross-attention is enabled, apply it with a residual connection
+        if self.has_cross_attention:
+            h = h + self.cross_attention(
+                self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context
+            )
+
+        # Apply feed-forward network and residual connection
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+    def init_weights(self):
+        """
+        Initializes the weights of the transformer block.
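+
+        Note:
+            This relies on ``self.weight_init_std`` being assigned by the owning module
+            before the call; it is not set in ``__init__``.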
+ """ + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + if self.has_cross_attention: + self.cross_attention_norm.reset_parameters() + self.cross_attention.init_weights(self.weight_init_std) + # zero-init the final output layer of cross-attention + # nn.init.zeros_(self.cross_attention.wo.weight) + + +class Transformer(nn.Module): + """ + The Transformer network consisting of transformer blocks. + """ + + def __init__(self, params, model_parallel=None, tokenizer_config=None, init_weights: bool = True): + """ + Initializes the Transformer module. + + Args: + params: The model parameters containing hyperparameters. + model_parallel: The model parallel configuration. + tokenizer_config: The model tokenizer configuration. + init_weights (bool): Whether to initialize the weights of the transformer following + TorchTitan's Llama3 initialization scheme. + """ + super().__init__() + # Check if self.params is an OmegaConf DictConfig instance + self.params = maybe_convert_to_namespace(params) + self.vocab_size = params["vocab_size"] + self.n_layers = params["n_layers"] + self.precision = getattr(torch, params["precision"]) + self.tokenizer_config = tokenizer_config + self.model_parallel = model_parallel + self.num_video_frames = params["num_video_frames"] + tp_group = self._get_tp_group() + + # Token embeddings + self.tok_embeddings = self._create_token_embeddings(self.model_parallel) + self.rope_config = self._create_rope_config() + + # Transformer layers + self.layers = nn.ModuleList( + [TransformerBlock(layer_id, self.params).to(self.precision) for layer_id in range(self.n_layers)] + ) + + # Final layer normalization + self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to( + self.precision + ) + if self.params["pytorch_rope_version"] == "v1": + self.rope = RotaryPositionEmbeddingPytorchV1(**self.rope_config) + elif self.params["pytorch_rope_version"] == "v2": + # Rotary position embeddings + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + self.rope = RotaryPositionEmbeddingPytorchV2( + seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config + ) + self._broadcast_pos_emb(self.rope.cos_cached, tp_group=self._get_tp_group()) + self._broadcast_pos_emb(self.rope.sin_cached, tp_group=self._get_tp_group()) + else: + raise ValueError(f"Invalid PyTorch RoPE version: {self.params['pytorch_rope_version']}") + # Causal mask + self.causal_mask = torch.tril( + torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool) + ).cuda() + + # Output projection + self.output = self._create_output_projection() + + # Freeze network parameters for finetuning w/ cross-attention + self.has_cross_attention = getattr(params, "insert_cross_attn", False) + + # Absolute position embeddings + if self.params["apply_abs_pos_emb"]: + self.pos_emb_config = self._create_abs_pos_emb_config() + self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb() + self._broadcast_pos_emb(self.abs_pos_emb, tp_group) + + def _create_rope_config(self) -> Dict: + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + head_dim = self.params["head_dim"] + if head_dim is None: + head_dim = self.params["dim"] // self.params["n_heads"] + return { + "dim": head_dim, + 
"max_position_embeddings": self.params["max_seq_len"], + "original_max_position_embeddings": self.params["original_seq_len"], + "rope_theta": self.params["rope_theta"], + "apply_yarn": self.params["apply_yarn"], + "scale": self.params["yarn_scale"], + "beta_fast": self.params["yarn_beta_fast"], + "beta_slow": self.params["yarn_beta_slow"], + "rope_dim": self.params["rope_dim"], + "latent_shape": latent_shape, + "original_latent_shape": self.params["original_latent_shape"], + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_abs_pos_emb_config(self): + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + return { + "dim": self.params["dim"], + "latent_shape": latent_shape, + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_token_embeddings(self, model_parallel=None, vocab_size: int = None): + """ + Create token embeddings. + + Args: + model_parallel: The model parallel configuration. + + Returns: + nn.Module: Token embeddings module. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + if tp_size > 1: + emb = VocabParallelEmbedding( + vocab_size, + self.params["dim"], + ).to(self.precision) + return emb + else: + return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision) + + def _get_tp_group( + self, + ): + """ + Get tensor parallel process group if applicable. + + Returns: + torch.distributed.ProcessGroup or None: Tensor parallel process group if tensor parallelism is enabled, else None. + """ + if self.params["tensor_model_parallel_size"] > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + log.debug(f"Using tensor model parallel group: {tp_group}") + return tp_group + + return None + + def _create_output_projection(self, vocab_size: int = None): + """ + Create the output projection layer. + + Args: + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + vocab_size (int): Vocabulary size (to override the default vocab size). + Returns: + LinearTE: Output projection layer. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + return nn.Linear(self.params["dim"], vocab_size // tp_size, bias=False).to(self.precision) + + def _initialize_abs_pos_emb(self): + pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + abs_pos_emb = pos_emb.forward(training_type=training_type) + return pos_emb, abs_pos_emb + + def _broadcast_pos_emb(self, pos_emb, tp_group): + """ + Broadcast the position embeddings across the tensor parallel group. + + Args: + pos_emb (torch.Tensor): Position embeddings to broadcast. + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + """ + if self.params["tensor_model_parallel_size"] > 1: + broadcast(pos_emb, min(get_process_group_ranks(tp_group)), group=tp_group) + + def forward( + self, + tokens: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the Transformer module. + + Args: + tokens (torch.Tensor, optional): The input tensor of token IDs. 
+ input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. + token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + Returns: + The output tensor after applying the transformer layers. + """ + # Token embeddings + assert ( + tokens is None or token_embeddings is None + ), "Either tokens or token_embeddings should be provided, not both." + + if token_embeddings is None: + seq_len = tokens.shape[1] + h = self.tok_embeddings(tokens) + else: + seq_len = token_embeddings.shape[1] + h = token_embeddings + + # Create attention mask + mask = self._create_attention_mask(input_pos=input_pos) + + # Prepare layer arguments + layer_kwargs = self._prepare_layer_kwargs( + input_pos=input_pos, + mask=mask, + context=context, + context_mask=context_mask, + ) + + # Apply transformer layers + for layer in self.layers: + if self.params["apply_abs_pos_emb"]: + h = self.apply_abs_pos_emb(h, input_pos=input_pos) + h = layer(h, **layer_kwargs) + + # Apply final layer normalization + h = self.norm(h) + + # Output linear projection + output = self.output(h) + if self.params["tensor_model_parallel_size"] > 1: + # Use PyTorch all gather + output = all_gather_tensor(output, gather_dim=-1, group=parallel_state.get_tensor_model_parallel_group()) + return output + + def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """ + Creates an attention mask for the transformer layers. + + Args: + input_pos[torch.Tensor]: The position of input sequence (used for inference only). + + Returns: + Optional[torch.Tensor]: The attention mask, or None for causal mask. + """ + + assert input_pos is not None, "input_pos must be provided for inference" + mask = self.causal_mask[input_pos] + return mask + + def _prepare_layer_kwargs( + self, + input_pos: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + context: Optional[torch.Tensor], + context_mask: Optional[torch.Tensor], + ) -> Dict[str, Any]: + """ + Prepares the keyword arguments for transformer layers. + + Args: + input_pos (Optional[torch.Tensor]): The position of the current sequence. + mask (Optional[torch.Tensor]): The attention mask. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. + """ + if context is not None: + context = context.to(self.precision) + + if isinstance(mask, torch.Tensor) and mask.ndim == 2: + mask = mask[None, None, :, :] + if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2: + context_mask = context_mask[None, None, :, :] + + layer_kwargs = { + "mask": mask, + "context": context, + "context_mask": context_mask, + } + + layer_kwargs["input_pos"] = input_pos + layer_kwargs["rope"] = self.rope + + return layer_kwargs + + def apply_abs_pos_emb(self, x: torch.Tensor, input_pos: int = None) -> torch.Tensor: + """ + Applies the absolute position embeddings to the input tensor. 
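+
+        When ``input_pos`` is provided (KV-cache decoding), only the embeddings at those
+        positions are added; otherwise the full cached table is broadcast over the batch.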
+ """ + abs_pos_emb = self.abs_pos_emb + abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb + return x + abs_pos_emb + + @torch.no_grad() + def expand_vocab( + self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True + ): + """ + Expands the vocabulary of the model to the new size. + + Args: + new_vocab_size (int): The new vocabulary size. + init_method (str): The initialization method for new embeddings. + Can be "zero" or "gaussian". Default is "gaussian". + multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully + leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377, + source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc) + expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. + + Returns: + None + """ + tp_size = self.params["tensor_model_parallel_size"] + if new_vocab_size <= self.vocab_size: + raise ValueError( + f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})" + ) + if new_vocab_size % multiple_of != 0: + log.debug(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.") + new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of + log.debug(f"Rounded vocabulary size to {new_vocab_size}.") + # Resize token embeddings + old_embeddings = self.tok_embeddings + tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype} + self.tok_embeddings = self._create_token_embeddings( + model_parallel=self.model_parallel, vocab_size=new_vocab_size + ).to(**tensor_kwargs) + # Initialize new embeddings + if init_method not in ["zero", "gaussian"]: + raise ValueError(f"Unknown initialization method: {init_method}") + # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything + # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings. + if init_method == "zero": + self.tok_embeddings.weight.data[self.vocab_size // tp_size :].zero_() + + # Copy old embeddings + log.debug( + f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}" + ) + self.tok_embeddings.weight.data[: self.vocab_size // tp_size] = old_embeddings.weight.data + # Resize output layer + old_output = self.output + self.output = self._create_output_projection(vocab_size=new_vocab_size if expand_output_layer else None) + + # Initialize new output weights + self.output.weight.data[self.vocab_size // tp_size :].zero_() + # Copy old output weights + self.output.weight.data[: self.vocab_size // tp_size] = old_output.weight.data + + # Update vocab size + self.vocab_size = new_vocab_size + log.debug(f"Expanded vocabulary size to {new_vocab_size}") + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). 
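+
+        Args:
+            state_dict (Dict[str, Any]): The state dict to load.
+            strict (bool): Whether to raise an error on missing or unexpected keys
+                (after filtering out the ignored substrings).
+            assign (bool): Passed through to ``nn.Module.load_state_dict``.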
+ """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + if strict: + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + missing_keys = actual_missing_keys + return _IncompatibleKeys(missing_keys, unexpected_keys) diff --git a/cosmos_predict1/autoregressive/networks/vit.py b/cosmos_predict1/autoregressive/networks/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..46405ac37fe90c3e212e37493d68774702f02882 --- /dev/null +++ b/cosmos_predict1/autoregressive/networks/vit.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, +designed for processing image inputs in vision-language models. + +This module follows Mistral's vision encoder implementation (for their Pistral-12B VLM): +https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py +""" +from functools import partial +from typing import Any, Callable, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.networks.transformer import TransformerBlock +from cosmos_predict1.utils import log + + +def get_vit_config(model_name: str) -> Mapping[str, Any]: + """ + Get the ViT configuration for a given model name. + """ + if model_name == "pixtral-12b-vit": + # The 400M ViT of Pixtral 12B VLM + return dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + else: + raise ValueError(f"Unknown model name: {model_name}") + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + Precompute 2D complex tensor for rotary position embedding. + + This function generates a 2D complex tensor used for rotary position embeddings, + which helps the model understand spatial relationships in the input image. + + Args: + dim (int): Dimension of the model (typically the hidden size divided by number of heads). + height (int): Height of the image in patches. + width (int): Width of the image in patches. + theta (float): Base value for the angle calculation, controls the frequency range. + + Returns: + torch.Tensor: 2D complex tensor of shape (height, width, dim // 2). 
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting with input tensor. + + This function ensures that the frequency tensor can be properly broadcast + with the input tensor during the rotary embedding process. + + Args: + freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d. + x (torch.Tensor): Input tensor to be embedded. + + Returns: + torch.Tensor: Reshaped frequency tensor ready for broadcasting. + """ + ndim = x.ndim + assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}" + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + *args, + freqs_cis: torch.Tensor, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary positional embeddings to input tensors. + + This function applies the rotary positional embeddings to the query and key tensors, + which helps the model understand spatial relationships in the input. + + Args: + xq (torch.Tensor): Query tensor. + xk (torch.Tensor): Key tensor. + freqs_cis (torch.Tensor): Precomputed frequencies from precompute_freqs_cis_2d. + *args: Variable length argument list (unused). + **kwargs: Arbitrary keyword arguments (unused). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class VisionTransformer(nn.Module): + """ + Vision Transformer model for image processing. + + This class implements a Vision Transformer that processes images using a patch-based approach + and applies transformer layers with rotary position embeddings. + + Args: + dim (int): Dimension of the model (hidden size). + num_channels (int): Number of input image channels (e.g., 3 for RGB). + patch_size (int): Size of each image patch (e.g., 16x16 pixels). + n_layers (int): Number of transformer layers. + n_heads (int): Number of attention heads. + ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks. + norm_type (str): Type of normalization to use (e.g., "rmsnorm"). + norm_eps (float): Epsilon value for normalization layers. + image_size (int): Size of the input image (assumed square). + rope_theta (float): Base value for rotary position embedding calculation. + attention_dropout (float): Dropout rate for attention layers. + hidden_dropout (float): Dropout rate for hidden layers. + image_token_id (int): Token ID for the image token (if present). 
+ """ + + def __init__( + self, + dim: int = 1024, + num_channels: int = 3, + patch_size: int = 16, + n_layers: int = 24, + n_heads: int = 16, + n_kv_heads: int = None, + ffn_hidden_size: int = 4096, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + image_size: int = 1024, + rope_theta: float = 1000000.0, + image_token_id: int = None, + tensor_model_parallel_size: int = 1, + ): + super().__init__() + self.patch_conv = nn.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps) + if n_kv_heads is None: + n_kv_heads = n_heads + layer_args = dict( + n_layers=n_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dim=dim, + use_qk_normalization=False, + max_seq_len=None, + max_batch_size=None, + ffn_hidden_size=ffn_hidden_size, + norm_type=norm_type, + norm_eps=norm_eps, + causal_mask=False, # Full attention in ViT + head_dim=None, + insert_cross_attn=False, + tensor_model_parallel_size=tensor_model_parallel_size, + attn_type="full", + ) + + self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args) + + head_dim = dim // n_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + + self.dim = dim + self.n_heads = n_heads + self.max_patches_per_side = image_size // patch_size + self.image_size = image_size + self.patch_size = patch_size + self.rope_theta = rope_theta + self._freqs_cis: Optional[torch.Tensor] = None + self.image_token_id = image_token_id + + num_params = self.get_num_params() + log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M") + + @classmethod + def build( + cls, + config: Mapping[str, Any], + ) -> "VisionTransformer": + """ + Create a Vision Transformer from a configuration dictionary. + + This class method creates a Vision Transformer from a configuration dictionary, + which is typically loaded from a JSON file or other configuration source. + + Args: + config (Mapping[str, Any]): Configuration dictionary for the Vision Transformer. + + Returns: + VisionTransformer: Vision Transformer model instance. + """ + necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"] + missing_keys = [k for k in necessary_keys if k not in config] + assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}" + return cls( + **config, + ) + + def expand_in_channels(self, new_in_channels: int): + """ + Expand the input channels of the patch convolution layer. + This is useful when the input is non-standard, e.g. a 4-channel image with the last channel as the alpha channel. + Note that you should only call this method after the weight is loaded. + """ + assert ( + new_in_channels > self.patch_conv.in_channels + ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels." + log.debug( + f"Vision encoder in_channels is {self.patch_conv.in_channels}. But you have specified to be {new_in_channels}. We will change it to {new_in_channels} channels with {new_in_channels - self.patch_conv.in_channels} channels of 0s." 
+ ) + new_conv = nn.Conv2d( + in_channels=new_in_channels, + out_channels=self.patch_conv.out_channels, + kernel_size=self.patch_conv.kernel_size, + stride=self.patch_conv.stride, + bias=False, + ) + new_conv.weight.data[:, : self.patch_conv.in_channels].copy_(self.patch_conv.weight.data) + new_conv.weight.data[ + :, self.patch_conv.in_channels : + ].zero_() # zeroize, such that initially it has no effect to output + self.patch_conv = new_conv + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + """ + Get or compute the frequency tensor for rotary position embedding. + + This property lazily initializes and caches the frequency tensor used for + rotary position embeddings, ensuring it's on the correct device. + + Returns: + torch.Tensor: The frequency tensor for rotary position embeddings. + """ + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.dim // self.n_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the Vision Transformer. + + This method processes the input image through the Vision Transformer, + including patch embedding, position embedding, and transformer layers. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size, + C is number of channels, and H, W are height and width. + + Returns: + torch.Tensor: Output features of shape (B, N, D), where N is the number of patches + and D is the embedding dimension. + """ + + patch_embeds = self.patch_conv(x) # (B, D, Hp, Wp) + _, _, Hp, Wp = patch_embeds.shape # Patch embeds dim + patch_embeds = patch_embeds.flatten(2) # (B, D, Hp*Wp) + patch_embeds = patch_embeds.transpose(1, 2) # (B, Hp*Wp, D) + patch_embeds = self.ln_pre(patch_embeds) # (B, Hp*Wp, D) + positions = torch.stack( + torch.meshgrid( + torch.arange(Hp), + torch.arange(Wp), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + rope = partial(apply_rotary_emb, freqs_cis=freqs_cis) + out = self.transformer(patch_embeds, rope=rope) + + return out + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + +class VisionTransformerBlocks(nn.Module): + """ + Vision Transformer Blocks. + + This class implements a stack of Transformer blocks used in the Vision Transformer. + + Args: + n_layers (int): Number of transformer layers. + args (Mapping[str, Any]): Arguments for each transformer block, including dimensions, + """ + + def __init__( + self, + n_layers: int, + args: Mapping[str, Any], + ): + super().__init__() + self.layers = torch.nn.ModuleList() + + for layer_id in range(n_layers): + self.layers.append( + TransformerBlock( + layer_id=layer_id, + args=args, + ) + ) + + def forward( + self, + x: torch.Tensor, + rope: Callable, + ) -> torch.Tensor: + """ + Forward pass through the Vision Transformer Blocks. + + This method applies a series of Transformer blocks to the input tensor, + using the provided rotary position embedding function. 
+ + Args: + x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch size, + N is the number of patches, and D is the embedding dimension. + rope (Callable): Rotary position embedding function to be applied in each layer. + + Returns: + torch.Tensor: Output tensor after passing through all transformer layers, + with the same shape as the input. + """ + for layer in self.layers: + x = layer(x, input_pos=None, mask=None, rope=rope) + return x diff --git a/cosmos_predict1/autoregressive/tokenizer/__init__.py b/cosmos_predict1/autoregressive/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/autoregressive/tokenizer/discrete_video.py b/cosmos_predict1/autoregressive/tokenizer/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd0a832b264387c28932c7dbc4dcf77fbbd935b --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/discrete_video.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange + +from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer + +# Make sure jit model output consistenly during consecutive calls +# Check here: https://github.com/pytorch/pytorch/issues/74534 +torch._C._jit_set_texpr_fuser_enabled(False) + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + model = torch.jit.load(jit_filepath) + return model.eval().to(device) + + +class BaseDiscreteVideoFSQTokenizer(torch.nn.Module): + """ + A base class for Discrete Video FSQ Tokenizer that handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components into a encoder and decoder attributes. + + Attributes: + encoder (Module | Callable): Encoder loaded from storage. + decoder (Module | Callable): Decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__() + self.channel = latent_ch + self.name = name + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.pixel_chunk_duration = pixel_chunk_duration + self.latent_chunk_duration = latent_chunk_duration + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + self.levels = levels + self.compress_ratio = compression_ratio + self.fsq_quantizer = FSQuantizer(levels) + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the tokenizer. + """ + return self.channel + + @torch.no_grad() + def encode(self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, C, T, H, W = state.shape + if pixel_chunk_duration is None: + # Use the default pixel chunk duration and latent chunk duration + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + # Update the latent chunk duration based on the given pixel chunk duration + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + + assert ( + T % pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}" + state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration) + + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + quantized_out_list = [] + indices_list = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype)) + quantized_out_list.append(quantized_out) + indices_list.append(indices) + quantized_out = torch.cat(quantized_out_list, dim=0) + indices = torch.cat(indices_list, dim=0) + else: + indices, quantized_out, _ = self.encoder(state.to(self.dtype)) + assert quantized_out.shape[2] == latent_chunk_duration + return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange( + indices, "(b n) t h w -> b (n t) h w", b=B + ) + + @torch.no_grad() + def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, T, _, _ = indices.shape + if pixel_chunk_duration is None: + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + assert ( + T % latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}" + indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration) + + # use max_dec_batch_size to avoid OOM + if indices.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, 
indices.shape[0], self.max_dec_batch_size): + state.append(self.decoder(indices[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = self.decoder(indices) + + assert state.shape[2] == pixel_chunk_duration + return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer): + """ + A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__( + name, + latent_ch, + is_bf16, + pixel_chunk_duration, + latent_chunk_duration, + max_enc_batch_size, + max_dec_batch_size, + levels, + compression_ratio, + ) + + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = load_jit_model(enc_fp, device="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. 
+ """ + self.decoder = load_jit_model(dec_fp, device="cuda") + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer): + """ + A Discrete Video FSQ Tokenizer that loads weights from pre-trained JITed encoder + into as nn.Module so that encoder can be "torch.compile()" and JITed decoder, so it can be torch.compiled, + handles data type conversions, and normalization using provided mean and standard deviation values for latent + space representation. + + Attributes: + tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints + encoder (Callable): tokenizer_module's encode method + decoder (Callable): tokenizer_module's decode method + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + tokenizer_module (Module): Tokenizer module that will have it's weights loaded + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + tokenizer_module: torch.nn.Module, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__( + name, + latent_ch, + is_bf16, + pixel_chunk_duration, + latent_chunk_duration, + max_enc_batch_size, + max_dec_batch_size, + levels, + compression_ratio, + ) + + self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module) + + def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + - def_fp (str): File path to the decoder's JIT file on the remote store. 
+ - tokenizer_module (Module): Tokenizer module that was used to create JIT checkpoints + """ + self.decoder = load_jit_model(dec_fp) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + encoder_sd = load_jit_model(enc_fp).state_dict() + + del tokenizer_module.post_quant_conv + del tokenizer_module.decoder + + state_dict = { + k: v + for k, v in (encoder_sd).items() + # Variables captured by JIT + if k + not in ( + "encoder.patcher3d.wavelets", + "encoder.patcher3d._arange", + "encoder.patcher3d.patch_size_buffer", + "quantizer._levels", + "quantizer._basis", + "quantizer.implicit_codebook", + ) + } + + tokenizer_module.load_state_dict(state_dict) + + tokenizer_module.eval() + for param in tokenizer_module.parameters(): + param.requires_grad = False + tokenizer_module.to(self.dtype) + + self.tokenizer_module = tokenizer_module + self.encoder = self.tokenizer_module.encode + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.tokenizer_module.to(self.dtype) diff --git a/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..447092664fb89f7e205721571646b3ea29fe7d71 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
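+
+"""Image-text tokenizer that extends the text tokenizer with vision-token support for Pixtral-style vision-language models."""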
+ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import transformers +from transformers import AutoImageProcessor +from transformers.image_utils import ImageInput, is_valid_image, load_image + +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer +from cosmos_predict1.utils import log + +# Configuration for different vision-language models +IMAGE_CONFIGS = { + "pixtral": { + "patch_size": 16, + "image_token": "[IMG]", + "image_break_token": "[IMG_BREAK]", + "image_end_token": "[IMG_END]", + } +} + +# Chat template for Pixtral-12B-Instruct +PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}' + + +# Copied from transformers.models.pixtral.processing_pixtral.is_url +def is_url(val) -> bool: + """Check if the given value is a URL.""" + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url +def is_image_or_image_url(elem): + """Check if the given element is an image or an image URL.""" + return is_url(elem) or is_valid_image(elem) + + +def load_image_list( + image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None +) -> List["PIL.Image.Image"]: + """ + Load a list of images. + + Args: + image_list (List[Union[str, PIL.Image.Image]]): The list of images to load. + timeout (Optional[float]): The timeout for loading the image. + + Returns: + List[PIL.Image.Image]: The list of loaded images. + """ + return [load_image(image, timeout=timeout) for image in image_list] + + +class ImageTextTokenizer(TextTokenizer): + """ + Image-text tokenizer class that extends the text tokenizer to support vision tokens as well. + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + tokenizer_path: str, + image_processor_path: str, + ): + """ + Initialize the ImageTextTokenizer. + + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + + Raises: + AssertionError: If the model family is not supported or if the transformers version is incompatible. 
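+
+ Example (a construction sketch; the paths below are placeholders, not real checkpoint locations):
+ >>> tokenizer = ImageTextTokenizer(
+ ... model_family="pixtral",
+ ... is_instruct_model=True,
+ ... tokenizer_path="<path/to/tokenizer>",
+ ... image_processor_path="<path/to/image_processor>",
+ ... )  # doctest: +SKIP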
+ """ + super().__init__( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ) + assert model_family in ["pixtral"], f"Unsupported model family: {model_family}" + if model_family == "pixtral": + # Need transformers>=4.45.0 + assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0" + assert is_instruct_model, "Pixtral requires is_instruct_model=True" + if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None: + setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE) + log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}") + + # Set up image-specific configurations + image_config = IMAGE_CONFIGS[model_family] + self.patch_size = image_config["patch_size"] + self.image_token = image_config["image_token"] + self.image_break_token = image_config["image_break_token"] + self.image_end_token = image_config["image_end_token"] + + # Initialize the image processor + self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + **text_kwargs, + ) -> List[int]: + """ + Process the images and return the tokenized images and text. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + **text_kwargs: Additional keyword arguments for text processing. + + Returns: + A dictionary with the following fields: + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. + + Raises: + ValueError: If the input images are in an invalid format. + """ + + output_dict, image_inputs = {}, {} + + if images is not None: + if is_image_or_image_url(images): + images = [images] + elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]): + pass + elif ( + isinstance(images, (list, tuple)) + and isinstance(images[0], (list, tuple)) + and is_image_or_image_url(images[0][0]) + ): + images = [image for sublist in images for image in sublist] + else: + raise ValueError( + "Invalid input images. Please provide a single image, a list of images, or a list of lists of images." 
+ ) + images = [load_image(im) if isinstance(im, str) else im for im in images] + image_kwargs = image_kwargs or {} + image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs) + + # Validate image inputs + assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs" + assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs" + assert len(image_inputs.keys()) == 2, "Only one key is allowed in image_inputs, got {}".format( + image_inputs.keys() + ) + + # Extract pixel values and image sizes + pixel_values = image_inputs["pixel_values"] + image_sizes = image_inputs["image_sizes"] + unique_sizes = np.unique(image_sizes, axis=0) + + assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes) + + # Convert pixel values to PyTorch tensor + pixel_values = np.asarray(pixel_values) + pixel_values = torch.from_numpy(pixel_values) + output_dict["pixel_values"] = pixel_values + output_dict["image_sizes"] = image_sizes + + # Expand image tokens in text + if image_inputs.get("pixel_values") is not None: + replace_strings = [] + # Calculate the number of tokens needed for each image and create a placeholder + for image_size in image_sizes: + height, width = image_size + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + text = text.replace(self.image_token, "", 1) + + # Replace placeholders with actual image token sequences + while "" in text: + replace_str = replace_strings.pop(0) + text = text.replace("", replace_str, 1) + + # Encode the text + text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs) + + output_dict["input_ids"] = text_inputs + return output_dict + + def apply_chat_template( + self, + conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]], + *, + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = True, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Apply the chat template to the conversation. + + Args: + conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process. + images (Optional[ImageInput]): Images to include in the conversation. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + add_generation_prompt (bool): Whether to add a generation prompt. + tokenize (bool): Whether to tokenize the output. + padding (bool): Whether to pad the output. + truncation (bool): Whether to truncate the output. + max_length (Optional[int]): Maximum length of the output. + return_tensors (Optional[str]): The type of tensors to return. + return_dict (bool): Whether to return a dictionary. + return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask. + generation_prefix (str): Prefix to add before asking model to generate. 
Helpful to guide the generation. Defaults to "". + tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer. + **kwargs: Additional keyword arguments. + + Returns: + The processed conversation with applied chat template. + + Raises: + AssertionError: If return_dict is False or if the conversation format is invalid. + """ + assert return_dict, "return_dict must be True for ImageTextTokenizer" + assert isinstance(conversation, list), "conversation must be a list" + if isinstance(conversation[0], list): + assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation) + conversation = conversation[0] + + # Extract images from the conversation if not provided + if images is None: + images = [] + for msg in conversation: + if msg.get("images", None) is not None: + images = images + (msg["images"]) + images = load_image_list(images) + # In case the input does not have images, will ignore + # Useful in feeding VLM inputs with and without images + if isinstance(images, list) and len(images) == 0: + images = None + + # Apply the chat template to the text + text = super().apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=add_generation_prompt, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=False, + return_assistant_tokens_mask=return_assistant_tokens_mask, + generation_prefix=generation_prefix, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + + # Encode the text and images + output = self.encode( + text, + images=images, + image_kwargs=image_kwargs, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + return output + + @property + def model_input_names(self): + """ + Get the combined model input names from both the text tokenizer and image processor. + + Returns: + List[str]: A list of unique input names. + """ + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/cosmos_predict1/autoregressive/tokenizer/modules.py b/cosmos_predict1/autoregressive/tokenizer/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..68fee5493ec40f3ba4ba205eb2e3c26dd1c5c9f0 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/modules.py @@ -0,0 +1,560 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/ +magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cosmos_predict1.autoregressive.tokenizer.patching import Patcher3D, UnPatcher3D +from cosmos_predict1.autoregressive.tokenizer.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) +from cosmos_predict1.utils import log + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalHybridUpsample3d(nn.Module): + def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0) + if temporal_up + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1) + if spatial_up + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_up or temporal_up + else nn.Identity() + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. 
+ x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0) + if spatial_down + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0) + if temporal_down + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_down or temporal_down + else nn.Identity() + ) + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) 
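+ # q is now (b, h*w, c); k stays (b, c, h*w) below, so bmm(q, k) gives the (h*w, h*w) spatial attention matrix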
+ k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + temporal_compression: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, spatial_down=spatial_down, temporal_down=temporal_down + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: 
int, + temporal_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed in the encoder should correspond + # to the layer index, inreverse order, where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + # For example: + # Input tensor = (1, 3, 17, 32, 32) + # Patch size = 4 for 3D wavelet transform + # Compression rate = (8x16x16) + # + # We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)` + # + # if legacy_mode is True, the temporal upsampling is not perfectly mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)` + # + # Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored. + # Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling. 
+ i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/cosmos_predict1/autoregressive/tokenizer/networks.py b/cosmos_predict1/autoregressive/tokenizer/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..56b0c5fb7a1dec7e6282c66f7a34253925c11ffe --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/networks.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
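# Illustrative usage sketch for the CausalDiscreteVideoTokenizer defined below. All
# constructor values are made-up placeholders (real checkpoints ship their own channel
# counts, FSQ levels and resolutions); the keyword names are taken from the modules above.
import torch

from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer

tokenizer = CausalDiscreteVideoTokenizer(
    z_channels=16,
    z_factor=1,
    embedding_dim=6,                 # chosen equal to len(levels) so FSQuantizer needs no extra projection
    levels=[8, 8, 8, 5, 5, 5],       # FSQuantizer levels (assumed)
    in_channels=3,
    out_channels=3,
    channels=32,
    channels_mult=[2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],             # no intra-level attention, to keep the sketch light
    dropout=0.0,
    resolution=128,
    patch_size=4,
    patch_method="haar",
    spatial_compression=16,
    temporal_compression=8,
)

video = torch.randn(1, 3, 17, 128, 128)        # (B, C, T, H, W)
indices, codes, _ = tokenizer.encode(video)    # discrete indices, quantized codes, dummy loss
reconstruction = tokenizer.decode(codes)       # expected to match the input shape (1, 3, 17, 128, 128)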
+ +from collections import namedtuple + +import torch +from torch import nn + +from cosmos_predict1.autoregressive.tokenizer.modules import CausalConv3d, DecoderFactorized, EncoderFactorized +from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer +from cosmos_predict1.utils import log + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) + self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + self.quantizer = FSQuantizer(**kwargs) + + num_parameters = sum(param.numel() for param in self.parameters()) + log.debug(f"model={self.name}, num_parameters={num_parameters:,}") + log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) + return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) diff --git a/cosmos_predict1/autoregressive/tokenizer/patching.py b/cosmos_predict1/autoregressive/tokenizer/patching.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5b621f9d526cff7966c77225656e9327adde30 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/patching.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The patcher and unpatcher implementation for 2D and 3D data.""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. 
+ + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=self.patch_size, p2=self.patch_size).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", patch_size * torch.ones([1], dtype=torch.int32), persistent=_PERSISTENT + ) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
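+        # The spatial passes below split each of xl/xh along height and then width, giving
+        # eight subbands per level that are stacked along channels (8x channel growth);
+        # with `rescale`, the lowest (lll) subband is the average of each 2x2x2 block
+        # (up to boundary padding), keeping activations on the input scale.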
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
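+        # Synthesis runs in the reverse order of Patcher._dwt: the stride-2 transposed
+        # convolutions first merge the low/high bands along the height axis (yl, yh),
+        # then along the width axis; `rescale` multiplies by 2 to undo the per-level /2
+        # applied by the forward transform.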
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + return x + + def _iarrange(self, x): + x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=self.patch_size, p2=self.patch_size) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
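+        # Drop the first (patch_size - 1) frames: Patcher3D replicated the first input
+        # frame patch_size times before the forward transform, so the inverse removes
+        # the duplicates to restore the original causal frame count.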
+ return x diff --git a/cosmos_predict1/autoregressive/tokenizer/quantizers.py b/cosmos_predict1/autoregressive/tokenizer/quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..5618d1fbffa4fc68133332de23e3d8fb5dd03dc5 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/quantizers.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from cosmos_predict1.autoregressive.tokenizer.utils import default, pack_one, round_ste, unpack_one + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / 
half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) diff --git a/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8cb96aefc49922d24173cdb6af1c58fec231bc --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoTokenizer + +from cosmos_predict1.utils import log + + +def get_tokenizer_path(model_family: str, is_instruct_model: bool = False): + """ + Get the tokenizer path from the model family and instruct model flag. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + Returns: + str: The tokenizer path. + """ + model_family = model_family.lower() + if model_family == "mistral": + return "mistralai/Mistral-Nemo-Instruct-2407" + else: + assert model_family in ["llama3", "llama3.1"] + if model_family == "llama3": + model_path = "meta-llama/Meta-Llama-3-8B" + elif model_family == "llama3.1": + model_path = "meta-llama/Llama-3.1-8B" + else: + raise ValueError(f"Unsupported model family: {model_family}") + suffix = "-Instruct" if is_instruct_model else "" + model_path = f"{model_path}{suffix}" + return model_path + + +class TextTokenizer: + """ + Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based). + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + local_path: Optional[str] = None, + ): + """ + Initialize the TextTokenizer. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path. + """ + if local_path is None: + tokenizer_path = get_tokenizer_path(model_family, is_instruct_model) + else: + tokenizer_path = local_path + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + self.stop_tokens = { + self.tokenizer.eos_token_id, + } + self.model_family = model_family + self.is_instruct_model = is_instruct_model + self.eos_id = self.tokenizer.eos_token_id + if self.tokenizer.pad_token is None: + if model_family.startswith("llama"): + self.pad_id = 128004 # "<|finetune_right_pad_id|>" + elif model_family == "mistral": + self.pad_id = 10 # "" + elif model_family == "pixtral": + self.pad_id = 11 # "" + else: + raise ValueError(f"pad_id not defined for model_family {model_family}") + else: + self.pad_id = self.tokenizer.pad_token_id + + def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]: + """ + Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. + + Args: + text (`str`): + The sequence to be encoded. + add_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add the special tokens associated with the corresponding model. + Returns: + `List[str]`: The list of tokens. 
+ """ + return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[str] = None, + **kwargs, + ) -> List[int]: + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Args: + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding the sequences. This will use the underlying + `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are + automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens + automatically. + padding (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. 
+ stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + return self.tokenizer.encode( + text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + ) + + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"], + *, # Enforce keyword-only arguments + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + return self.tokenizer.decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + *, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = False, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting. 
+ + More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template + + Args: + conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts + with "role" and "content" keys, representing the chat history so far. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. + tokenize (`bool`, defaults to `True`): + Whether to tokenize the output. If `False`, the output will be a string. + padding (`bool`, defaults to `False`): + Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. + truncation (`bool`, defaults to `False`): + Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. + max_length (`int`, *optional*): + Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If + not specified, the tokenizer's `max_length` attribute will be used as a default. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable + values are: + - `'tf'`: Return TensorFlow `tf.Tensor` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "". + tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + + Returns: + `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is + set, will return a dict of tokenizer outputs instead. + """ + if not self.is_instruct_model: + raise ValueError( + "apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor." 
+ ) + # Since generation_prefix is added to the text in the end, ensure that the setting is correct + if generation_prefix: + assert not tokenize, "tokenize must be False when generation_prefix is provided." + assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided." + formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=return_dict, + return_assistant_tokens_mask=return_assistant_tokens_mask, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + if generation_prefix: + formatted_text: str = formatted_text + generation_prefix + log.debug( + f"Adding generation prefix: {generation_prefix} to the formatted text\n" + f"Formatted text: {formatted_text}" + ) + return formatted_text diff --git a/cosmos_predict1/autoregressive/tokenizer/tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbda5f7c4eccc36ba85eca89fe5dd841fcc0786 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/tokenizer.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
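# Illustrative usage sketch for the TextTokenizer defined in text_tokenizer.py above.
# The Llama model family requires access to the gated Hugging Face checkpoints, so treat
# the concrete model choice and prompts as placeholders.
from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer

text_tokenizer = TextTokenizer(model_family="llama3.1", is_instruct_model=True)

token_ids = text_tokenizer.encode("A robot arm stacks three blocks.", add_special_tokens=True)
text = text_tokenizer.decode(token_ids, skip_special_tokens=True)

# Chat-style prompting goes through apply_chat_template; generation_prefix requires
# tokenize=False and add_generation_prompt=True (enforced by the assertions above).
prompt = text_tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe the scene in one sentence."}],
    add_generation_prompt=True,
    tokenize=False,
    generation_prefix="The scene shows",
)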
+ +from collections import defaultdict +from typing import Optional + +import torch +from einops import rearrange + +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +def update_vocab_size( + existing_vocab_size, + to_be_added_vocab_size, + training_type, + add_special_tokens, + video_special_tokens={}, +): + # New vocab size + if add_special_tokens: + existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens) + # For text_to_video, we add one special token at the beginning of the video + elif training_type == "text_to_video": + existing_vocab_size += to_be_added_vocab_size + 1 + else: + existing_vocab_size += to_be_added_vocab_size + return existing_vocab_size + + +class DiscreteMultimodalTokenizer: + def __init__(self, tokenizer_config: TokenizerConfig): + self.tokenizer_config = tokenizer_config + self.vocab_size = 0 + self.total_seq_len = tokenizer_config.seq_len + self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of + self.training_type = tokenizer_config.training_type + assert self.training_type in [ + "text_only", + "text_to_video", + "video_to_video", + "image_text_interleaved", + ], f"{self.training_type} not supported" + + self._build_text_tokenizer() + self._build_video_tokenizer() + + def _build_text_tokenizer(self): + r"""Function to initialize the text tokenizer model.""" + if self.tokenizer_config.text_tokenizer is not None: + self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config) + self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size + else: + self.text_tokenizer = None + + def _build_video_tokenizer(self): + r"""Function to initialize the video tokenizer model.""" + if self.tokenizer_config.video_tokenizer is not None: + self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config) + self.video_tokenizer = self.video_tokenizer.to("cuda") + self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size + special_token_offset = ( + self.tokenizer_config.video_tokenizer.tokenizer_offset + + self.tokenizer_config.video_tokenizer.vocab_size + ) + self.video_special_tokens = { + "<|begin_of_video|>": special_token_offset, + "<|end_of_video|>": special_token_offset + 1, + "<|pad_token_video|>": special_token_offset + 2, + } + + self.vocab_size = update_vocab_size( + existing_vocab_size=self.vocab_size, + to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size, + training_type=self.training_type, + add_special_tokens=self.tokenizer_config.add_special_tokens, + video_special_tokens=self.video_special_tokens, + ) + else: + self.video_tokenizer = None + + @property + def pad_id(self): + r"""Returns the pad_id.""" + + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + pad_id = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + pad_id = self.video_special_tokens["<|pad_token_video|>"] + else: + raise ValueError(f"training_type {self.training_type} not defined") + return pad_id + + @property + def ignore_index(self): + r"""Returns which token should be ignored during loss computation.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id: + # If the PAD token is the same as the EOS token, we do not ignore it during loss + # computation, since we want the model to be able to 
predict EOS tokens in inference. + # The PyTorch default ignore_index for the cross-entropy loss is -100. + ignore_index = -100 + else: + ignore_index = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + ignore_index = self.pad_id + else: + raise ValueError(f"training_type {self.training_type} not defined") + return ignore_index + + @property + def stop_tokens(self): + r"""Returns the stop tokens.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + stop_tokens = self.text_tokenizer.stop_tokens + elif self.training_type in ["text_to_video", "video_to_video"]: + stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]]) + else: + raise ValueError(f"training_type {self.training_type} not defined") + return stop_tokens + + def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1): + r"""Function to tokenize text. + Args: + raw_text (list[str]): List of input strings + max_text_seq_len (int): Maximum sequence length returned by text tokenizer + Returns: + text_tokens (list[list[int]]): List of text tokens + """ + + batch_size = len(raw_text) + text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)] + + # Clipping the text tokens so that the sequence length does not exceed max_text_seq_len + if max_text_seq_len > -1: + for i in range(len(text_tokens)): + if len(text_tokens[i]) > max_text_seq_len: + # Simply clip and add end of seq token + text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id] + return text_tokens + + def _tokenize_class(self, cls_labels: list[str]): + r"""Function to tokenize the class label. + Args: + cls_labels (list[str]): List of class indices + Returns: + class_tokens (list[list[int]]): List of class tokens + """ + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. + class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels] + + return class_tokens + + def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None): + r"""Function to tokenize video. + Args: + videos (torch.Tensor): Input video data tensor + pixel_chunk_duration (Optional[float]): Pixel chunk duration. If provided, we pass it to the video tokenizer. + Returns: + video_tokens (list[list[int]]): List of video tokens + """ + + video_tokens = [] + batch_size = videos.shape[0] + + quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration) + indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1)) + + # Flatten the indices + indices = rearrange(indices, "B T H W -> B (T H W)") + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. 
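+        # For example, if the text vocabulary occupies ids [0, tokenizer_offset), raw
+        # video code k is remapped to tokenizer_offset + k, so text and video tokens
+        # live in one flat vocabulary without collisions.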
+ indices += self.tokenizer_config.video_tokenizer.tokenizer_offset + + # Add begin and end of video tokens + bov_token = self.video_special_tokens["<|begin_of_video|>"] + eov_token = self.video_special_tokens["<|end_of_video|>"] + + # Append bov and eov tokens + if self.tokenizer_config.add_special_tokens: + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist() + [eov_token]) + else: + if self.training_type == "text_to_video": + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist()) + else: + for i in range(batch_size): + video_tokens.append(indices[i].tolist()) + assert ( + len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len + ), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}" + + return video_tokens + + def tokenize(self, data_batch: dict): + r"""Function to tokenize data_dict. + Args: + data_batch (dict): Input data dict + Returns: + tokens (torch.LongTensor): Token tensor dict + """ + + if ( + self.training_type in ["text_only", "image_text_interleaved"] + and not self.tokenizer_config.text_tokenizer.tokenize_here + ): + # In case of pre-computed tokens, just return the data_batch + return data_batch["tokens"], None + + # Online tokenization + tokens = [] + token_boundaries = defaultdict(list) + + # Obtain maximum sequence length + max_text_seq_len = -1 + max_visual_seq_len = -1 + + if self.training_type in ["text_to_video", "video_to_video"]: + max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len + + # If max visual sequence length is specified, make sure that text is clipped so that + # the full video/image is always seen. + if max_visual_seq_len > -1: + if self.tokenizer_config.add_special_tokens: + max_visual_seq_len = max_visual_seq_len + 2 # Two special tokens is for [bov, eov] or [boi, eoi] token + elif self.training_type == "text_to_video": + max_visual_seq_len = max_visual_seq_len + 1 + else: + max_visual_seq_len = max_visual_seq_len + assert ( + max_visual_seq_len <= self.total_seq_len + ), f"max_visual_seq_len ({max_visual_seq_len}) is greater that total sequence length ({self.total_seq_len})" + max_text_seq_len = self.total_seq_len - max_visual_seq_len + + # Tokenize the text + if ( + "text" in self.training_type + and self.text_tokenizer is not None + and self.tokenizer_config.text_tokenizer.tokenize_here + ): + key = self.tokenizer_config.text_tokenizer.data_key + batch_size = len(data_batch[key]) + assert key in data_batch, f"Key {key} should be present in data for text tokenizer" + tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len) + + for i in range(batch_size): + token_boundaries["text"].append((0, len(tokens[i]))) + else: + tokens = [] + batch_size = None + + # Tokenize the class label + if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None: + key = self.tokenizer_config.class_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for class tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + tokens_class = self._tokenize_class(data_batch[key]) + if len(tokens) == 0: + tokens = tokens_class + for i in range(batch_size): + token_boundaries["class"].append((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i]))) + tokens[i] = tokens[i] + tokens_class[i] + + # Tokenize the video + if 
self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here: + key = self.tokenizer_config.video_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for video tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + + pixel_chunk_duration = ( + None # If not specified, we assume it's a video dataset and use the default chunk duration + ) + dataset_name = data_batch.get("dataset_name", None) + if dataset_name is not None and dataset_name.startswith("image"): + # If it's an image dataset, we use a pixel chunk duration of 1 + pixel_chunk_duration = 1 + tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration) + if len(tokens) == 0: + tokens = tokens_video + for i in range(batch_size): + token_boundaries["video"].append((0, len(tokens[i]))) + # [B,] each entry is ((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i]))) + tokens[i] = tokens[i] + tokens_video[i] + + # Combine the tokens and do padding + max_seq_len_in_batch = max([len(token) for token in tokens]) + if self.pad_to_multiple_of is not None: + # Pad the sequence length to the nearest multiple of pad_to_multiple_of + max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of + pad_to_len = min(max_seq_len_in_batch, self.total_seq_len) + for i in range(len(tokens)): + if len(tokens[i]) < pad_to_len: + tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i])) + else: + tokens[i] = tokens[i][0:pad_to_len] + + # Convert it to long tensor + tokens = torch.LongTensor(tokens) + return tokens, token_boundaries diff --git a/cosmos_predict1/autoregressive/tokenizer/utils.py b/cosmos_predict1/autoregressive/tokenizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dd58c7830e60e5a09a38b991ccb5fef3b13293 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/utils.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
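# Illustrative sketch of the pad_to_multiple_of arithmetic used in
# DiscreteMultimodalTokenizer.tokenize above: a plain round-up to the next multiple.
# The helper name and values here are illustrative only.
def _round_up_to_multiple(max_seq_len_in_batch: int, pad_to_multiple_of: int) -> int:
    return ((max_seq_len_in_batch - 1) // pad_to_multiple_of + 1) * pad_to_multiple_of

assert _round_up_to_multiple(1000, 64) == 1024   # longest sequence 1000 -> padded to 1024
assert _round_up_to_multiple(1024, 64) == 1024   # already a multiple -> unchanged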
+ +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() diff --git a/cosmos_predict1/autoregressive/train.py b/cosmos_predict1/autoregressive/train.py new file mode 100644 index 0000000000000000000000000000000000000000..aa95d4e4738e8e8c6147cdab9970f92b5e192597 --- /dev/null +++ b/cosmos_predict1/autoregressive/train.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
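# Illustrative sketch of the straight-through rounding defined in utils.py above:
# round_ste rounds in the forward pass but lets gradients flow through unchanged,
# which is what FSQuantizer.quantize relies on.
import torch

from cosmos_predict1.autoregressive.tokenizer.utils import round_ste

z = torch.tensor([0.2, 0.7, -1.4], requires_grad=True)
y = round_ste(z)
y.sum().backward()

print(y.detach())   # the rounded values: 0., 1., -1.
print(z.grad)       # tensor([1., 1., 1.])  -- identity gradient despite the rounding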
+ +import argparse +import importlib +import os + +from loguru import logger as logging +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from omegaconf import OmegaConf + +from cosmos_predict1.utils import misc +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig + + +@misc.timer("instantiate LLM") +def instantiate_model(config, trainer) -> None: + model_parallel_cuda_manual_seed(config.trainer.seed) + model = instantiate(config.model) + if not config.model["model_config"].set_parallel_mode: + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + + return model + + +@logging.catch(reraise=True) +def launch(config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + trainer = config.trainer.type(config) + # Create the model + model = instantiate_model(config, trainer) + + model.on_model_init_end() + dataloader_train = instantiate(config.dataloader_train) + dataloader_val = instantiate(config.dataloader_val) + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Training") + parser.add_argument( + "--config", default="projects.cosmos.ar.v1.configs.train_openhermes", help="Path to the config file" + ) + parser.add_argument("--cluster", default=None, help="Cluster name") + parser.add_argument( + "opts", + help="""Modify config options at the end of the command. For Yacs configs, use + space-separated "PATH.KEY VALUE" pairs. + For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. Useful for debugging the config.", + ) + args = parser.parse_args() + config = importlib.import_module(get_config_module(args.config)).make_config() + config = override(config, args.opts) + if args.dryrun: + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/autoregressive/trainer.py b/cosmos_predict1/autoregressive/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7684384cab86e58071d9eb4da8fc732504f4811b --- /dev/null +++ b/cosmos_predict1/autoregressive/trainer.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
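# Usage note for the training entrypoint above (hedged: the launch wrapper and the config
# module are deployment-specific, so treat both as placeholders):
#
#   python -m cosmos_predict1.autoregressive.train \
#       --config projects.cosmos.ar.v1.configs.train_openhermes --dryrun
#
# --dryrun only resolves and prints the config; dropping it launches training, and extra
# "path.key=value" pairs appended to the command override LazyConfig entries.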
+
+import signal
+
+import torch
+import torch.distributed as dist
+import torch.utils.data
+from megatron.core import parallel_state
+
+from cosmos_predict1.checkpointer.tp import Checkpointer as TensorParallelCheckpointer
+from cosmos_predict1.utils import distributed, ema, log, misc
+from cosmos_predict1.utils.checkpointer import Checkpointer
+from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
+from cosmos_predict1.utils.model import Model
+from cosmos_predict1.utils.trainer import Trainer
+
+
+class Trainer(Trainer):
+    """
+    Modified trainer that logs the average loss (averaged across all devices and gradient-accumulation steps).
+    """
+
+    def __init__(self, config):
+        super(Trainer, self).__init__(config)
+        if config.trainer.distributed_parallelism == "ddp":
+            if parallel_state.get_tensor_model_parallel_world_size() > 1:
+                self.checkpointer = TensorParallelCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
+                log.critical("Using Tensor Parallelism Checkpointer")
+            else:
+                self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
+
+        elif config.trainer.distributed_parallelism == "fsdp":
+            self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
+        else:
+            raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}")
+
+    def train(
+        self,
+        model: Model,
+        dataloader_train: torch.utils.data.DataLoader,
+        dataloader_val: torch.utils.data.DataLoader,
+    ) -> None:
+        """The training function.
+
+        Args:
+            model (Model): The PyTorch model.
+            dataloader_train (torch.utils.data.DataLoader): The training data loader.
+            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
+        """
+        # Kept here for backward compatibility for now; consider moving this to model.on_train_start for all models.
+        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
+        log.info(f"Model Architecture:\n {model}")
+        model.on_train_start(self.config.trainer.memory_format)
+        # Initialize the optimizer and scheduler.
+        self.callbacks.on_optimizer_init_start()
+
+        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
+
+        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
+        self.callbacks.on_optimizer_init_end()
+        # Load the model checkpoint and get the starting iteration number.
+        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
+        # Set the scheduler to the current iteration.
+        scheduler.last_epoch = iteration
+        scheduler._step_count = iteration + 1
+
+        grad_accum_iter = 0
+        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
+        if self.config.trainer.distributed_parallelism == "ddp":
+            # Create a DDP model wrapper.
+            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
+        elif self.config.trainer.distributed_parallelism == "fsdp":
+            model_ddp = model
+        else:
+            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
+        log.info("Starting training...")
+        self.callbacks.on_train_start(model, iteration=iteration)
+        # Initial validation.
+ if self.config.trainer.run_validation and iteration == 0: + self.validate(model, dataloader_val, iteration=iteration) + _end_training = False + self.callbacks.on_before_dataloading(iteration) + accumulated_loss = 0.0 + + while True: + dataloader_train_iter = iter(dataloader_train) + while True: + self.callbacks.on_before_dataloading(iteration) + try: + data_batch = next(dataloader_train_iter) + except StopIteration: + break + self.callbacks.on_after_dataloading(iteration) + # If max_iter is reached, exit the training loop. + if iteration >= self.config.trainer.max_iter: + _end_training = True + break + # Move all tensors in the data batch to GPU device. + + data_batch = misc.to(data_batch, device="cuda") + # The actual training step. + self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) + model_ddp.train() + output_batch, loss, grad_accum_iter = self.training_step( + model_ddp, + optimizer, + scheduler, + grad_scaler, + data_batch, + iteration=iteration, + grad_accum_iter=grad_accum_iter, + ) + + # Accumulate loss + accumulated_loss += loss.detach() + + # If the gradients are still being accumulated, continue to load the next training batch. + if grad_accum_iter != 0: + if self.enable_one_logger: + # Callback for skipped OneLoggerCallback.on_training_step_end() + self.one_logger.on_train_batch_end(set_barrier=False) + continue + # Do the following when an actual optimizer (update) step has been made. + iteration += 1 + + # Average loss over accumulation steps + grad_accum_avg_loss = accumulated_loss / self.config.trainer.grad_accum_iter + # Average loss across all devices + device_avg_loss = grad_accum_avg_loss.clone() + dist.all_reduce(device_avg_loss, op=dist.ReduceOp.SUM) + device_avg_loss /= dist.get_world_size() + # Reset accumulation variables + accumulated_loss = 0.0 + + self.callbacks.on_training_step_end( + model, data_batch, output_batch, device_avg_loss, iteration=iteration + ) + + # self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) + + # Validation. + if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: + self.validate(model, dataloader_val, iteration=iteration) + # Save checkpoint. + if iteration % self.config.checkpoint.save_iter == 0: + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + # This iteration is successful; reset the timeout signal. + signal.alarm(self.config.trainer.timeout_period) + if _end_training: + break + log.success("Done with training.") + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + self.callbacks.on_train_end(model, iteration=iteration) + self.checkpointer.finalize() + distributed.barrier() + self.callbacks.on_app_end() + + def training_step( + self, + model_ddp: torch.nn.Module | distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + data: dict[str, torch.Tensor], + iteration: int = 0, + grad_accum_iter: int = 0, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: + """The training step. + + Args: + model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare + module, depending on whether distributed training is enabled or not. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. 
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + grad_accum_iter (int): Number of gradient accumulation iterations. + + Returns: + output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). + loss (torch.Tensor): The total loss of the training data batch. + """ + # Only let DDP sync gradient at the last iteration of the gradient accumulation window + with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): + with self.training_timer("forward"): + output_batch, loss = model_ddp.training_step(data, iteration) + self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) + with self.training_timer("backward"): + loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) + loss_scaled.backward() + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_after_backward() + else: + model_ddp.on_after_backward() + self.callbacks.on_after_backward(model_ddp, iteration=iteration) + grad_accum_iter += 1 + if grad_accum_iter == self.config.trainer.grad_accum_iter: + with self.training_timer("optimizer_step"): + self.callbacks.on_before_optimizer_step( + model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration + ) + grad_scaler.step(optimizer) + grad_scaler.update() + scheduler.step() + self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + else: + model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + optimizer.zero_grad(set_to_none=True) + grad_accum_iter = 0 + return output_batch, loss, grad_accum_iter + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + with ema.ema_scope(model, enabled=getattr(model.config.ema, "enabled", False)): + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, loss = model.validation_step(data_batch, iteration) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/autoregressive/training/model.py b/cosmos_predict1/autoregressive/training/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3fccb2c37a8c703470005ef46a9a5b6ede146ee5 --- /dev/null +++ b/cosmos_predict1/autoregressive/training/model.py @@ -0,0 +1,1240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch +import torch.nn.functional as F +from megatron.core import InferenceParams, ModelParallelConfig, parallel_state +from safetensors.torch import load_file +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.configs.base.model import TrainingModelConfig as ModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector +from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config + +# from cosmos_predict1.autoregressive.training.networks.transformer_medusa import TransformerMedusa +from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer +from cosmos_predict1.autoregressive.training.networks.transformer import ( + Transformer, + TransformerBlock, + TransformerBlockTE, +) +from cosmos_predict1.autoregressive.utils.checkpoint import ( + get_partial_state_dict, + maybe_convert_checkpoint_to_backend, + obtain_tensor_parallel_state_dict, + process_state_dict, + substrings_to_ignore, +) +from cosmos_predict1.autoregressive.utils.misc import random_dropout +from cosmos_predict1.autoregressive.utils.parallel import broadcast_data_batch_in_tp_cp_group, get_batch_on_this_cp_rank +from cosmos_predict1.autoregressive.utils.sampling import ( + decode_n_tokens, + decode_one_token, + prefill, + sample_top_k, + sample_top_p, +) +from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.misc import download_from_s3_with_cache, sync_s3_dir_to_local +from cosmos_predict1.utils.model import Model + + +class AutoRegressiveTrainingModel(Model): + """ + A class to build and use a Llama model for text generation. + + Methods: + build: Build a Llama instance by initializing and loading a model checkpoint. + generate: Generate text sequences based on provided prompts using the language generation model. + """ + + def __init__( + self, + model: Transformer, + tokenizer: DiscreteMultimodalTokenizer, + config: ModelConfig, + model_parallel: ModelParallelConfig = None, + vision_encoder: VisionTransformer = None, + mm_projector: MultimodalProjector = None, + ): + """ + Initialize the Llama instance with a model and tokenizer. 
+ + Args: + model (Transformer): The Transformer model for text generation. + tokenizer (Tokenizer): The tokenizer for encoding and decoding text. + config (Config): The configuration for the Llama model. + """ + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + self.precision = self.model.precision + self.vision_encoder = vision_encoder + self.mm_projector = mm_projector + assert (self.vision_encoder is None) == (self.mm_projector is None), ( + "vision_encoder and mm_projector should be " "both None or not None simultaneously" + ) + self.model_parallel = model_parallel + self.monitor_output_logits = False + self.inference_params = None + # self.insert_medusa_head = self.config.insert_medusa_head + + if self.config.freeze_vision_encoder and vision_encoder is not None: + for param in self.vision_encoder.parameters(): + param.requires_grad = False + log.critical("Vision encoder parameters are frozen.") + + num_params = self.get_num_params() + log.info(f"Number of model parameters: {round(num_params / 1e9, 3)}B") + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + broadcast_data_batch_in_tp_cp_group(data_batch) + # get the context embedding and mask + context = data_batch.get("context", None) + context_mask = data_batch.get("context_mask", None) + if context is not None: + if self.config.embedding_dropout > 0: + context = random_dropout( + context, + self.config.embedding_dropout, + ) + context = misc.to(context, device="cuda") + if context_mask is not None: + context_mask = misc.to(context_mask, device="cuda") + action = data_batch.get("action", None) + if action is not None: + action = misc.to(action, device="cuda") + # Input tokens + tokens, token_boundaries = self.tokenizer.tokenize(data_batch) + tokens = misc.to(tokens, device="cuda") + # Tokens to predict + labels = data_batch.get("labels", None) + # Token Mask (Note: this is not attention mask) + masks = data_batch.get("token_mask", None) + apply_token_mask = masks is not None + if masks is None: + masks = torch.ones_like(tokens, dtype=torch.bool) + masks = misc.to(masks, device="cuda") + assert ( + data_batch.get("labels", None) is None or apply_token_mask + ), "The code is not tested for the case when both labels and token_mask are provided." 
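+        # Shape sketch (illustrative, not fixed by the config): for a batch of B sequences
+        # of length T, `tokens` and `masks` are both (B, T). When no explicit `labels` are
+        # provided, the targets below become tokens[:, 1:] and the logits are truncated to
+        # logits[:, :-1], i.e. standard next-token prediction.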
+ + if self.config.ignore_first_num_tokens > 0: + assert self.config.ignore_first_num_tokens < masks.shape[1] + masks[:, : self.config.ignore_first_num_tokens] = False + seq_len = tokens.shape[1] + + # Boradcast inputs to TP and CP ranks, alternatively we can use the `_broadcast` function from cosmos/diffusion/v1 + # Currently we only handled video tokens (with label and mask) and text tokens (with mask), action and other inputs might also need to be handled + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.model.enable_context_parallel(cp_group) + tokens = get_batch_on_this_cp_rank(tokens) + masks = get_batch_on_this_cp_rank(masks) + if labels is not None: + labels = get_batch_on_this_cp_rank(labels) + if self.vision_encoder is None: + logits = self.model.forward( + tokens=tokens, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + total_seq_len=seq_len, + ) + else: + assert "images" in data_batch + images = data_batch["images"] + if images.ndim == 5: + # The shape is (batch_size, n_images_per_sample, C, H, W). Flatten the first two dimensions. + images = images.view(-1, *images.shape[2:]) + assert images.ndim == 4, f"Invalid shape: {images.shape}" + token_embeddings = self.embed_vision_language_features(tokens, images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + total_seq_len=seq_len, + ) + + if labels is None: + # For auto-regressive models, the labels are the same as the + # input tokens shifted by one position + logits = logits[:, :-1] + masks = masks[:, :-1] + labels = tokens[:, 1:].clone() + + batch_size = tokens.shape[0] + # Apply ignore_index + for sample_num in range(batch_size): + if self.tokenizer.training_type == "text_to_video": + # For text-to-video training, we do not compute the loss of text part + # Hence, we set the labels of text tokens to that of ignore_index + if len(token_boundaries["text"]) > 0: + labels[sample_num][0 : token_boundaries["text"][sample_num][1] - 1] = self.tokenizer.ignore_index + elif self.tokenizer.training_type == "class_to_image": + # For class-to-image training, we do not compute the loss of class part + # Hence, we set the labels of class tokens to that of ignore_index + labels[sample_num][0 : token_boundaries["class"][sample_num][1] - 1] = self.tokenizer.ignore_index + + ignore_index = self.tokenizer.ignore_index + if self.config.ignore_first_num_tokens > 0 or apply_token_mask: + labels[~masks] = ignore_index + + output_batch = { + "encode_tokens": tokens, + "logits": logits.detach(), + "labels": labels.detach(), + "ignore_index": ignore_index, + } + + if self.monitor_output_logits: + self.gather_output_logits_stats(logits, labels, output_batch, ignore_index) + + logits = logits.flatten(0, 1) + labels = labels.flatten(0, 1) + + # Main cross entropy loss + ce_loss = F.cross_entropy( + input=logits, + target=labels, + ignore_index=ignore_index, # ignore prompt (turn prompt tokens into pad_id here) + ) + + # Z-loss + log_z = torch.logsumexp(logits, dim=-1) # shape: [B, seq_len] + z_loss = self.config.z_loss_coeff * (log_z**2).mean() + + # Combined loss + total_loss = ce_loss + z_loss + + return output_batch, total_loss # skip returning output logits + + @torch.no_grad() + def validation_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Perform a 
validation step for the model, which is the same as the training step (but without backpropagation). + """ + return self.training_step(data_batch, iteration) + + @torch.no_grad() + def gather_output_logits_stats( + self, logits: torch.Tensor, labels: torch.Tensor, output_batch: Dict, ignore_index: int = None + ): + """ + Gather statistics of the output logits, including mean, norm, and max values. + """ + bs, seq_len, dim = logits.shape + logits = logits.reshape(-1, dim) + if ignore_index is not None: + select_index = labels.view(-1) != ignore_index + acc = labels.view(-1)[select_index] == logits.argmax(dim=1)[select_index] + acc = acc.float().mean().view(-1, 1) + + logits = logits[select_index] + output_batch.update( + { + "logits_mean": logits.mean(dim=1).detach(), + "logits_norm": torch.linalg.vector_norm(logits, dim=1).detach(), + "logits_max": logits.max(dim=1).values.detach(), + "acc": acc.detach() * 100, + } + ) + + @torch.no_grad() + def image_encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the image input state to continuous latent and discrete indices. + """ + latent, indices = self.tokenizer.image_tokenizer.encode(state) + return latent, indices + + @torch.no_grad() + def image_decode(self, indices: torch.Tensor) -> torch.Tensor: + """ + Decode the discrete indices to RGB images. + """ + return self.tokenizer.image_tokenizer.decode(indices) + + @torch.no_grad() + def video_encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the video input state to continuous latent and discrete indices. + """ + latent, indices = self.tokenizer.video_tokenizer.encode(state) + return latent, indices + + @torch.no_grad() + def video_decode(self, indices: torch.Tensor) -> torch.Tensor: + """ + Decode the discrete indices to RGB videos. + """ + if self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap > 0: + return self.tokenizer.video_tokenizer.decode_with_overlap( + indices, temporal_overlap=self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap + ) + else: + return self.tokenizer.video_tokenizer.decode(indices) + + @staticmethod + def load_llm_checkpoint( + ckpt_path: str = "", + model: Transformer = None, + **kwargs, + ) -> None: + """ + Load a LLM checkpoint from the specified path. + """ + with misc.timer(f"loading checkpoint from {ckpt_path}"): + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + + @staticmethod + def build( + seed: int = 1, + train_from_scratch: bool = False, + model_config: ModelConfig = ModelConfig(), + fsdp_checkpointer: Any = None, + tokenizer_config: TokenizerConfig = None, + model_parallel: ModelParallelConfig = None, + shard_checkpoint: bool = True, + download_rank_sync: bool = True, + **kwargs, + ) -> "AutoRegressiveTrainingModel": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + seed (int, optional): Random seed for reproducibility. Defaults to 1. + train_from_scratch (bool, optional): Flag indicating whether to train the model from scratch. Defaults to False. 
+ model_config (ModelConfig, optional): The model configuration for the Llama instance. Defaults to ModelConfig(). + fsdp_checkpointer (Any, optional): The FSDP checkpointer for the Llama instance. Defaults to None. + tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the Llama instance. Defaults to None. + shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to False. + download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory. + + Note: + This method sets the device to CUDA and loads the pre-trained model and tokenizer. + """ + tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size + # seed must be the same in all processes + torch.manual_seed(seed) + + # Initialize model configuration parameters + llama_params = {} + + # Load checkpoint and model parameters + if not train_from_scratch: + if model_config.ckpt_path is None: + # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir + ckpt_dir = sync_s3_dir_to_local( + s3_dir=model_config.ckpt_dir, + s3_credential_path=model_config.s3_credential_path, + cache_dir=model_config.cache_dir, + ) + + # We prioritize safetensors version over the pytorch version, since the former is + # much faster for checkpoint loading. + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + if len(checkpoints) == 0: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert ( + len(checkpoints) == 1 + ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" + ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case + + if os.path.exists(Path(ckpt_dir) / "params.json"): + with open(Path(ckpt_dir) / "params.json", "r") as f: + llama_params = json.loads(f.read()) + else: + log.info( + f"No params.json found in the checkpoint directory ({ckpt_dir}). " + f"Using default model config." 
+ ) + + else: + # If ckpt_path is provided, we load the model from the specified path, + # and use the default model configuration + ckpt_path = download_from_s3_with_cache( + s3_path=model_config.ckpt_path, + s3_credential_path=model_config.s3_credential_path, + cache_dir=model_config.cache_dir, + rank_sync=download_rank_sync, + ) + + for key, value in llama_params.items(): + # Override the default model configuration with the parameters from the checkpoint + setattr(model_config, key, value) + + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + + # If the checkpoint backend is different from the model backend, convert the checkpoint + # to be compatible with the model backend + # If shard_checkpoint is True, the loaded checkpoint is the whole model checkpoint (will be sharded later) + # instead of a tensor-parallel sharded checkpoint + llm_checkpoint = maybe_convert_checkpoint_to_backend( + llm_checkpoint, + target_backend=model_config.backend, + model_config=model_config, + tensor_parallel_size=tensor_parallel_size if not shard_checkpoint else 1, + is_tensor_parallel_shard=tensor_parallel_size > 1 and not shard_checkpoint, + ) + if model_config.vision_encoder is not None: + # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` + # and `checkpoint['mm_projector']` are both for those weights + # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights + if "vision_encoder" in checkpoint: + log.info("Using pretrained vision_encoder") + vit_checkpoint = checkpoint["vision_encoder"] + else: + log.info("Using fine-tuned vision_encoder") + vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") + vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") + if "mm_projector" in checkpoint: + log.info("Using pretrained mm_projector") + projector_checkpoint = checkpoint["mm_projector"] + else: + log.info("Using fine-tuned mm_projector") + projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") + projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") + assert ( + len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 + ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." 
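+        # Note: torch.set_default_dtype below changes the process-wide default dtype, so the
+        # Transformer (and, if enabled, the vision encoder and mm projector) are constructed
+        # directly in model_config.precision; the explicit model.to(precision) further below
+        # acts as a safeguard.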
+ + tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.info(f"Setting torch default dtype to {precision}") + + # if model_config.insert_medusa_head: + # model = TransformerMedusa( + # params=model_config, + # model_parallel=model_parallel, + # tokenizer_config=tokenizer_config, + # init_weights=train_from_scratch, + # ) + # else: + model = Transformer( + params=model_config, + model_parallel=model_parallel, + tokenizer_config=tokenizer_config, + init_weights=train_from_scratch, + ) + model_kwargs = {} + # [Optional] Initialize vision encoder and multimodal projector (for vision-language tasks) + if model_config.vision_encoder is not None: + assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." + vit_config = get_vit_config(model_config.vision_encoder) + vision_encoder = VisionTransformer.build( + vit_config, + hidden_dropout=model_config["hidden_dropout"], + attention_dropout=model_config["attention_dropout"], + set_parallel_mode=model_config["set_parallel_mode"], + model_parallel=model_parallel, + attention_tp=tensor_parallel_size > 1, + ) + + mm_projector = MultimodalProjector( + mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] + ) + model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) + + # Perform vocab expansion + if tokenizer.vocab_size > model.vocab_size: + log.info(f"Expanding vocab size to {tokenizer.vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer.training_type == "text_to_video") + model.expand_vocab(tokenizer.vocab_size, init_method="gaussian", expand_output_layer=expand_output_layer) + + if not train_from_scratch: + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if model_parallel is not None: + assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + if model_config.vision_encoder is not None: + # Shard vision encoder and multimodal projector weights + vit_checkpoint = obtain_tensor_parallel_state_dict( + vit_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=vit_config, + ) + + if model_config.vision_encoder is not None: + # Take the LLM weights (starting with "model.") from the VLM checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + # Remove the "model." 
prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + if model_config.vision_encoder is not None: + # Load vision encoder and multimodal projector weights + vision_encoder.load_state_dict(vit_checkpoint) + mm_projector.load_state_dict(projector_checkpoint) + if model_config.vision_encoder_in_channels != 3: + vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) + + model = model.to(precision) # ensure model parameters are in the correct precision + log.info(f"Model config: {model_config}") + + # if model_config.insert_medusa_head: + # from projects.cosmos.ar.v1.model_medusa import LlamaMedusa + + # model_class = LlamaMedusa + # else: + model_class = AutoRegressiveTrainingModel + if model_config.fsdp_enabled: + raise NotImplementedError("FSDP is not implemented for AutoRegressiveTrainingModel") + # model_kwargs["fsdp_checkpointer"] = fsdp_checkpointer + # model_class = FSDPLlama + return model_class(model, tokenizer, model_config, **model_kwargs) + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + top_k: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + logit_clipping_range: list = [], + seed: int = 0, + images: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. If not None, top-k sampling will be used instead of top-p sampling. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified." 
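+        # Illustrative example (hypothetical numbers): top_p=0.9 keeps the smallest set of
+        # tokens whose cumulative probability reaches 0.9, whereas top_k=50 keeps only the 50
+        # highest-probability tokens; at most one of the two strategies may be given per call.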
+ if top_p is not None: + log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + self.model.set_inference_flag(True) + misc.set_random_seed(seed) + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.info( + f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + if self.config.backend == "transformer_engine": + self.inference_params = InferenceParams( + max_batch_size=params.max_batch_size, max_sequence_length=params.max_seq_len + ) + + # Calculate Prompt Lengths + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = params.max_seq_len + assert ( + max_gen_len + max_prompt_len <= total_len + ), f"max_gen_len + max_prompt_len={max_gen_len + max_prompt_len} exceeds max_seq_len={total_len}" + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + + # Fill tokens tensor with prompt tokens + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + + # Flag to check if image embeddings have been passed to the model - we only need to pass them once + # since we have KV cache. + passed_image_embeddings = False + + # If all prompts are of max length, compute initial logits and logprobs + if min_prompt_len == total_len: + input_pos = torch.arange(tokens.shape[1], dtype=torch.long, device="cuda") + if images is None: + logits = self.model.forward( + tokens=tokens, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + else: + token_embeddings = self.embed_vision_language_features(tokens, images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + passed_image_embeddings = True + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), dtype=torch.long, device="cuda") + + # Main generation loop + log.info(f"Start generating the next {total_len - min_prompt_len} tokens. 
This will take a while..") + for cur_pos in range(min_prompt_len, total_len): + input_pos = torch.arange(prev_pos, cur_pos, dtype=torch.long, device="cuda") + if images is not None and not passed_image_embeddings: + token_embeddings = self.embed_vision_language_features(tokens[:, prev_pos:cur_pos], images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + passed_image_embeddings = True + else: + logits = self.model.forward( + tokens=tokens[:, prev_pos:cur_pos], + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + + if self.config.backend == "transformer_engine": + self.inference_params.sequence_len_offset += logits.shape[1] + + # Apply temperature scaling and nucleus sampling + if len(logit_clipping_range) > 0: + min_clip_index = logit_clipping_range[0] + max_clip_index = logit_clipping_range[1] + logits_clipped = logits[:, :, min_clip_index:max_clip_index] + else: + logits_clipped = logits + min_clip_index = 0 + + if temperature > 0: + if top_p is not None: + next_token = sample_top_p(logits_clipped, temperature=temperature, top_p=top_p)[0] + else: + next_token = sample_top_k(logits_clipped, temperature=temperature, top_k=top_k)[0] + else: + next_token = torch.argmax(logits_clipped[:, -1, :], dim=-1) + + next_token += min_clip_index + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + # Calculate log probabilities if requested + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + # Check if end-of-sequence token is reached + eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) + prev_pos = cur_pos + # Break the loop if all sequences have reached an end-of-sequence token + if all(eos_reached): + log.info(f"Reach end of sequence, current pos: {cur_pos}; maximum pos: {total_len}") + break + # Convert log probabilities to list if required + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + + # Process and collect the output tokens and log probabilities + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + self.model.set_inference_flag(False) + return (out_tokens, out_logprobs if logprobs else None) + + @torch.no_grad() + def fast_generate( + self, + prompt_tokens: List[List[int]] | torch.Tensor, + max_gen_len: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + echo: bool = False, + seed: int = 0, + context: Optional[torch.Tensor] = None, + context_mask: 
Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + compile_decode: bool = True, + compile_prefill: bool = False, + verbose: bool = True, + stop_tokens: Optional[Set[int]] = None, + ): + """ + Fast auto-regressive generation. Currently only supports input batch size = 1. + Args: + prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len). + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. + num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + logit_clipping_range (list, optional): Range of logits to clip. Defaults to []. + seed (int, optional): Random seed for reproducibility. Defaults to 0. + compile_decode (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + verbose (bool, optional): Flag indicating whether to print the the time. Defaults to False. + """ + assert ( + top_p is None or top_k is None + ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" + if top_p is not None: + log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + # Experimental features to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + # torch._functorch.config.enable_autograd_cache = True + + self.model.set_inference_flag(True) + misc.set_random_seed(seed) + + assert not logprobs, "logprobs are not supported for fast_generate yet" + # Examine if the function prefil and decode_one_token functions are compiled yet. If not, compile them based on the flags + if compile_decode and not getattr(self, "inference_decode_compiled", False): + self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + self.inference_decode_compiled = True + log.critical("Compiled decode_one_token function. Note: the first run will be slower due to compilation") + if compile_prefill and not getattr(self, "inference_prefill_compiled", False): + self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + self.inference_prefill_compiled = True + log.critical("Compiled prefill function. Note: the first run will be slower due to compilation") + + if not hasattr(self, "decode_one_token"): + self.decode_one_token = decode_one_token + if not hasattr(self, "prefill"): + self.prefill = prefill + + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.info( + f"Find self.model.params is a list, use self.config instead. 
Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + if isinstance(prompt_tokens, list): + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") + if prompt_tokens.ndim == 1: + prompt_tokens = prompt_tokens.view(1, -1) + else: + assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" + batch_size, prompt_len = prompt_tokens.shape + total_len = min(params.max_seq_len, max_gen_len + prompt_len) + if max_gen_len + prompt_len > params.max_seq_len: + log.warning( + f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" + ) + max_gen_len = params.max_seq_len - prompt_len + + if context_mask is not None: + context_mask = context_mask.to(dtype=torch.bool) + if context_mask.ndim == 2: + assert ( + context_mask.shape[0] == batch_size + ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" + # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] + context_mask = context_mask.view(batch_size, 1, 1, -1) + + if num_gen_seq > 1: + assert ( + batch_size == 1 + ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" + log.critical(f"Generating {num_gen_seq} sequences with the same prompt") + assert ( + num_gen_seq <= params.max_batch_size + ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" + # repeat the prompt tokens for num_gen_seq times + prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) + assert prompt_tokens.shape == ( + num_gen_seq, + prompt_len, + ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" + batch_size = len(prompt_tokens) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) + empty[:, :prompt_len] = prompt_tokens + seq = empty + input_pos = torch.arange(0, prompt_len, device="cuda") + + if verbose: + prefill_start = time.time() + + # Prefill stage + next_token = self.prefill( + self.model, + prompt_tokens, + input_pos=input_pos, + temperature=temperature, + top_k=top_k, + top_p=top_p, + context=context, + context_mask=context_mask, + action=action, + ) + if verbose: + prefill_time = time.time() - prefill_start + + seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) + input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") + stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens + stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") + + if verbose: + decode_start = time.time() + # Decode stage + generated_tokens = decode_n_tokens( + self.model, + next_token.view(batch_size, -1), + input_pos, + max_gen_len - 1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, + decode_one_token_function=self.decode_one_token, + context=context, + context_mask=context_mask, + action=action, + ) + gen_len = len(generated_tokens) + if verbose: + decode_time = time.time() - decode_start + prefill_throughput = prompt_len / prefill_time + decode_throughput = gen_len / decode_time + log.info(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") + log.info(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") + + generated_tokens = 
torch.cat(generated_tokens, dim=1) + + log.critical(f"generated_tokens: {generated_tokens.shape}") + seq = seq[:, : prompt_len + 1 + gen_len] + seq[:, prompt_len + 1 :] = generated_tokens + if not echo: + seq = seq[:, prompt_len:] + return seq, None + + def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: + """ + Embed vision and language features into a combined representation. + + Args: + input_ids (torch.Tensor): Input token IDs. + images (torch.tensor): Input images. + + Returns: + torch.Tensor: Combined vision-language features. + + Raises: + AssertionError: If vision encoder or mm projector is not initialized, + or if dimensions mismatch. + """ + # Ensure vision encoder and mm projector are initialized + assert self.vision_encoder is not None + assert self.mm_projector is not None + + # Get image token ID and validate it + image_token_id = self.vision_encoder.image_token_id + assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" + + # Identify text and image locations in the input + text_locations = input_ids != image_token_id + image_locations = input_ids == image_token_id + + # Process text features + text_features = self.model.tok_embeddings(input_ids[text_locations]) + + # Process image features + images = images.to(device=text_features.device, dtype=text_features.dtype) + vit_outputs = self.vision_encoder(images) + image_features = self.mm_projector(vit_outputs) + + # Get dimensions + B, seq_len = input_ids.shape + N_total = B * seq_len + N_txt, D_txt = text_features.shape + N_img, N_patch, D_img = image_features.shape + + # Reshape image features + image_features = image_features.reshape(N_img * N_patch, D_img) + + # Validate dimensions + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + N_total == N_txt + N_img * N_patch + ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" + + # Combine text and image features + combined_features = torch.empty( + (B, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + combined_features[image_locations, :] = image_features + + return combined_features + + def on_after_backward(self, iteration: int = 0): + """ + Hook after loss.backward() is called. + + This method is called immediately after the backward pass, allowing for custom operations + or modifications to be performed on the gradients before the optimizer step. + + So far, this method is used to all-reduce layernorm grads for tensor/sequence parallelism. + + Args: + iteration (int): Current iteration number. + """ + for module in self.children(): + if hasattr(module, "on_after_backward"): + module.on_after_backward(iteration) + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + for module in self.children(): + if hasattr(module, "on_before_zero_grad"): + module.on_before_zero_grad(optimizer, scheduler, iteration) + + @property + def fsdp_wrap_block_cls(self): + """ + Return the transformer block class to wrap with FSDP. 
+ """ + if self.config.backend == "pytorch": + return TransformerBlock + elif self.config.backend == "transformer_engine": + return TransformerBlockTE + else: + raise ValueError(f"Unknown backend: {self.config.backend}") + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). + """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if strict: + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + return _IncompatibleKeys(actual_missing_keys, unexpected_keys) + + +# class FSDPLlama(Llama): +# def __init__( +# self, model: Transformer, tokenizer: DiscreteMultimodalTokenizer, config: ModelConfig, fsdp_checkpointer: Any +# ): +# self.fsdp_checkpointer = fsdp_checkpointer +# super().__init__(model, tokenizer, config) +# self.set_up_fsdp() + +# def set_up_fsdp(self): +# """ +# Set up FSDP for the model. +# """ + +# model = self.model +# # detach the model from the parent class +# self.model = None +# del self.model + +# # build FSDP sharding strategy and device_mesh +# strategy = { +# "full": ShardingStrategy.FULL_SHARD, +# "hybrid": ShardingStrategy.HYBRID_SHARD, +# "none": ShardingStrategy.NO_SHARD, +# }[self.config.fsdp["sharding_strategy"]] +# log.critical(f"Using {strategy} sharding strategy for FSDP") + +# if self.config.fsdp["sharding_strategy"] == "hybrid": +# sharding_group_size = self.config.fsdp["sharding_group_size"] +# device_mesh = hsdp_device_mesh( +# sharding_group_size=sharding_group_size, +# ) +# else: +# device_mesh = hsdp_device_mesh( +# sharding_group_size=distributed.get_world_size(), +# ) +# parallel_state.fsdp_device_mesh = device_mesh + +# if distributed.get_rank() == 0: +# # only load model in rank0 to reduce network traffic and sync later +# self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) + +# if not hasattr(self, "fsdp_wrap_block_cls"): +# raise ValueError("Networks does not have fsdp_wrap_block_cls attribute, please check the net definition") +# fsdp_blocks_cls = self.fsdp_wrap_block_cls +# fsdp_blocks_cls = ( +# list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] +# ) +# log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") + +# log.critical(f"Using wrap policy {self.config.fsdp['policy']}") + +# if self.config.fsdp["policy"] == "size": +# # Size based policy won't work for transformers because the tokenizers need to be accessible at multiple +# # layers (input / output). This is handled by this sharding strategy. +# min_num_params = self.config.fsdp["min_num_params"] +# log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") +# log.info("If using a Transformer model. 
Please use the transformer wrap policy.") +# wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) +# else: +# # Use the auto wrap policy for transformers +# wrap_policy = functools.partial( +# transformer_auto_wrap_policy, +# transformer_layer_cls=set(fsdp_blocks_cls), +# ) +# tensor_kwargs = {"device": "cuda", "dtype": model.precision} + +# # Wrap the model with FSDP and attach it back to this class +# self.model = FSDP( +# model.to(**tensor_kwargs), +# sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync +# sharding_strategy=strategy, +# auto_wrap_policy=wrap_policy, +# device_id=torch.cuda.current_device(), +# device_mesh=device_mesh, +# limit_all_gathers=True, +# use_orig_params=True, # Do not flatten the parameter structure. Useful for layer_dependent lrs, etc. +# ) + +# if self.config.act_ckpt_enabled: +# # Apply activation checkpointing +# apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) + +# # Clean up memory +# torch.cuda.empty_cache() + +# def state_dict(self) -> Dict: +# raise NotImplementedError("FSDPLlama does not support state_dict, use state_dict_model and FSDPCheckpointer") + +# @misc.timer("FSDP state_dict_model") +# def state_dict_model(self) -> Dict: +# """ +# Get the model state_dict for checkpoint saving in the FSDP mode. +# """ +# with FSDP.summon_full_params(self.model): +# pass +# with FSDP.state_dict_type( +# self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) +# ): +# model_state = self.model.state_dict() +# # No support for EMA yet. +# ema_model_state = None +# return { +# "model": model_state, +# "ema": ema_model_state, +# } + +# def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: +# raise NotImplementedError("FSDPLlama does not support load_state_dict, using FSDPCheckpointer") + +# def init_optimizer_scheduler( +# self, optimizer_config: LazyDict, scheduler_config: LazyDict +# ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: +# """ +# Initialize the optimizer and scheduler for FSDP model. + +# Args: +# optimizer_config (LazyDict): The optimizer configuration. +# scheduler_config (LazyDict): The scheduler configuration. + +# Returns: +# tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: The optimizer and scheduler. +# """ +# optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) +# self.fsdp_checkpointer.load_optim_scheduler_during_init( +# self.model, +# optimizer, +# scheduler, +# ) +# return optimizer, scheduler + +# def get_ckpt_postfix(self) -> Tuple[str, int]: +# """Get the checkpoint file postfix. check FSDPCheckpointer for more details + +# Returns: +# postfix (str): The postfix of the checkpoint file. +# replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ +# we will not save each ema model in each GPU, \ +# ema model with same rate will be saved once +# total_ema_num (int) +# """ +# replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() +# # !!! 
EMA is not supported +# if replicate_idx == 0: +# return "", 0, shard_idx, 0 +# return "", replicate_idx, shard_idx, 0 diff --git a/cosmos_predict1/autoregressive/training/modules/attention.py b/cosmos_predict1/autoregressive/training/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1b6c81cb10a34033da0ccf794991e89d2501df --- /dev/null +++ b/cosmos_predict1/autoregressive/training/modules/attention.py @@ -0,0 +1,734 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Optional, Tuple, Union + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from torch import nn +from torch.distributed import _functional_collectives as funcol +from transformer_engine.pytorch.attention import _SplitAlongDim, apply_rotary_pos_emb, check_set_window_size +from transformer_engine.pytorch.constants import AttnBiasTypes +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.module.linear import Linear as LinearTE +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE + +from cosmos_predict1.autoregressive.modules.embedding import RotaryPositionEmbedding +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, RowParallelLinear +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.utils.parallel import AllReduceBWDRMSNormTE + + +class GQA(nn.Module): + """ + Grouped Query Attention (GQA) with KV cache (only supported for inference). + """ + + def __init__( + self, + n_heads: int, + n_kv_heads: Union[int, None], + dim: int, + max_batch_size: int, + max_seq_len: int, + context_dim: Optional[int] = None, + inference: bool = True, + flash_attn: bool = True, + use_qk_normalization: bool = False, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + set_parallel_mode: Optional[bool] = False, + model_parallel: Optional[ModelParallelConfig] = None, + attention_tp: Optional[bool] = False, + causal_mask: Optional[bool] = True, + head_dim: Optional[int] = None, + fuse_qkv: bool = False, + precision: str = "bfloat16", + attention_type: str = "self", + ): + """ + Initializes the GQA module. + + Args: + n_heads (int): The number of attention heads. + n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. + dim (int): The dimensionality of the input and output. + max_batch_size (int): The maximum batch size. + max_seq_len (int): The maximum sequence length. + context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. + inference (bool, optional): Whether the model is in inference mode. Defaults to True. + flash_attn (bool, optional): Whether to use Flash attention. Defaults to True. 
+ use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False. + norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm". + norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5. + attention_dropout (float, optional): Dropout rate for attention. Defaults to 0.0. + tp_group (int, optional): The tensor parallel group. + set_parallel_mode (bool, optional): Whether to set parallel mode which enables parallel linear. Defaults to False. + model_parallel (ModelParallelConfig, optional): The Megatron model parallel configuration. + attention_tp (bool, optional): Whether to use tensor parallelism for attention layers. Defaults to False. + causal_mask (bool, optional): Whether to use causal mask. Defaults to True. + head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads. + fuse_qkv (bool, optional): Whether to fuse QKV projections. Defaults to False. + precision (str, optional): The precision of the model. Defaults to "bfloat16". + attention_type (str, optional): The type of attention. Defaults to "self". + """ + super().__init__() + assert attention_type in ["self", "cross", "full"], f"Invalid attention type: {attention_type}" + self.attention_type = attention_type + self.model_parallel = model_parallel + if self.model_parallel and self.model_parallel.tensor_model_parallel_size > 1 and attention_tp: + self.tp_size = self.model_parallel.tensor_model_parallel_size + else: + self.tp_size = 1 + + context_dim = dim if context_dim is None else context_dim + + self.dim = dim + self.context_dim = context_dim + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_kv_heads = self.n_kv_heads // self.tp_size + self.n_local_heads = n_heads // self.tp_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads if head_dim is None else head_dim + assert flash_attn, "Flash attention is required." 
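+ # Illustration of the head bookkeeping above (hypothetical numbers, not from any specific config):
+ # with n_heads=32, n_kv_heads=8, dim=4096 and tp_size=2, each tensor-parallel rank holds
+ # n_local_heads=16 query heads and n_local_kv_heads=4 key/value heads, so n_rep=4 query heads
+ # share one KV head and head_dim defaults to 4096 // 32 = 128.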
+ self.attention_dropout = attention_dropout + self.causal_mask = causal_mask + self.fuse_qkv = fuse_qkv + self.precision = precision + + if fuse_qkv: + assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" + self.total_head_dim = (n_heads + 2 * self.n_kv_heads) * self.head_dim + self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim + + if set_parallel_mode and attention_tp and not inference: + kwargs = {"bias": False, "init_method": lambda x: x, "config": self.model_parallel} + # Using column and row parallel linear layers + if fuse_qkv: + self.wqkv = ColumnParallelLinear(dim, self.total_head_dim, **kwargs) + else: + self.wq = ColumnParallelLinear(dim, n_heads * self.head_dim, **kwargs) + self.wk = ColumnParallelLinear(context_dim, self.n_kv_heads * self.head_dim, **kwargs) + self.wv = ColumnParallelLinear(context_dim, self.n_kv_heads * self.head_dim, **kwargs) + + # Linear layer for output projection + self.wo = RowParallelLinear( + n_heads * self.head_dim, dim, input_is_parallel=True, skip_bias_add=True, **kwargs + ) + + else: + # Linear layers for query, key, and value projections + if fuse_qkv: + self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) + else: + self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) + + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + if inference and self.attention_type == "self": + # Cache for key and value tensors + self.init_kv_cache() + + # QK normalization layers + if use_qk_normalization: + assert n_heads % self.tp_size == 0, "n_heads must be divisible by tensor_model_parallel_size" + assert self.n_kv_heads % self.tp_size == 0, "n_kv_heads must be divisible by tensor_model_parallel_size" + self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + self.use_qk_normalization = use_qk_normalization + self.inference = inference + + if fuse_qkv: + # Register hook to load fused QKV weights + self._register_load_state_dict_pre_hook(self.load_hook) + + self.to(dtype=getattr(torch, self.precision)) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def init_kv_cache(self, dtype=None): + cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) + if dtype is None: + dtype = getattr(torch, self.precision) + if self.attention_type == "self": + self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() + self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() + + def set_inference_flag(self, flag): + self.inference = flag + if flag and self.attention_type == "self": + if self.cache_k is None or self.cache_v is None: + self.init_kv_cache() + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbedding, + input_pos: torch.Tensor, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ): + """ + Forward pass of GQA. 
+ + Args: + x: The input tensor of shape (batch_size, seq_len, dim). + rope: The rotary positional embedding module. + input_pos: The starting position of the current sequence. + mask: The attention mask tensor. + context: The context tensor of shape (batch_size, context_len, dim). + + Returns: + The output tensor after applying GQA. + """ + bsz, seqlen, _ = x.shape + + # Use one single module to handle both self-attn and cross-attn + context = x if context is None else context + context_len = seqlen if context is None else context.shape[1] + + if self.fuse_qkv: + q_size = self.n_local_heads * self.head_dim + kv_size = self.n_local_kv_heads * self.head_dim + xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + else: + # Compute query, key, and value projections + xq = self.wq(x) + xk, xv = self.wk(context), self.wv(context) + + # Reshape projections + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + + # QK normalization + if self.use_qk_normalization: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Apply rotary positional embeddings to queries and keys + # Only apply RoPE to self-attention! + if self.attention_type in ["self", "full"]: + xq, xk = rope(xq, xk, input_pos, seqlen) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + # xq: (bs, n_local_heads, seqlen, head_dim) + # xk: (bs, n_kv_heads, cache_len + context_len, head_dim) + # xv: (bs, n_kv_heads, cache_len + context_len, head_dim) + if self.inference and self.attention_type == "self": + # Update cache with current key and value tensors + assert input_pos is not None + self.cache_k[:bsz, :, input_pos] = xk + self.cache_v[:bsz, :, input_pos] = xv + keys, values = ( + self.cache_k[:bsz, :, :], + self.cache_v[:bsz, :, :], + ) + else: + keys, values = xk, xv + + # Repeat keys and values if necessary + keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + + if self.attention_type == "self" and self.causal_mask: + # During inference, `is_causal` should be set to False when KV cache is pre-computed and used, + # since the masking is handled outside this attention module. + # During training, `is_causal` should be set to None to use the default behavior of FlashAttention. + is_causal = False if self.inference else None + else: + # This is used for full-attention transformer (e.g., ViT) + # also for the cross-attn, it's always full-attn w/o causal + is_causal = False + output = scaled_dot_product_attention( + xq, + keys, + values, + head_dim=self.head_dim, + mask=mask, + is_causal=is_causal, + dropout_p=self.attention_dropout if self.training else 0.0, + ) + output = output.view(bsz, seqlen, -1) + output = self.wo(output) + + if self.inference and self.tp_size > 1: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + def init_weights(self, init_std: float): + """ + Initializes the weights of all modules. 
+ """ + if self.fuse_qkv: + nn.init.trunc_normal_(self.wqkv.weight, mean=0.0, std=0.02) + else: + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + if self.use_qk_normalization: + torch.nn.init.ones_(self.q_norm.weight) + torch.nn.init.ones_(self.k_norm.weight) + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_dim: int, + mask: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + dropout_p: float = 0.0, +) -> torch.Tensor: + """ + PyTorch's native implementation of Flash Attention 2. + + If `is_causal` is given, then the causal attention mask is applied accordingly: + - If `is_causal` is True, the standard upper-left causal attention masking is applied. + - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is + provided (i.e., `mask is not None`). + + If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied + based on the provided mask tensor: + - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, + leading to the standard upper-left causal attention masking. + - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, + and `is_causal` is set to False. + + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + head_dim (int): Dimension of each attention head + mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. + is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. + dropout_p (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + torch.Tensor: Output tensor after applying scaled dot-product attention + """ + scale = 1.0 / math.sqrt(head_dim) + if is_causal is None: + is_causal = mask is None + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + ) + return y.transpose(1, 2).contiguous() + + +def enable_different_context_dim_in_te_ca( + te_mha_module, + context_dim, + args, +): + """ + Hijacks the MultiheadAttention (MHA) module from TransformerEngine (TE) to use a different context-dim for KV calculation. + """ + self = te_mha_module + + common_gemm_kwargs = { + "fuse_wgrad_accumulation": args["fuse_wgrad_accumulation"], + "tp_group": self.tp_group, + "tp_size": self.tp_size, + "get_rng_state_tracker": self.get_rng_state_tracker, + "sequence_parallel": self.sequence_parallel, + "params_dtype": self.params_dtype, + } + + self.key_value = LinearTE( + context_dim, + 2 * self.hidden_size_kv, + init_method=None, + bias=args["bias"], + return_bias=False, + parallel_mode="column" if args["set_parallel_mode"] else None, + parameters_split=("key", "value") if not args["fuse_qkv_params"] else None, + **common_gemm_kwargs, + ) + + +def enable_qk_normalization_in_te_mha( + te_mha_module, + norm_eps: float, + is_self_attn: bool = True, +): + """ + Hijacks the MultiheadAttention (MHA) module from TransformerEngine (TE) to use our `te_mha_forward_with_qk_norm`. + The `te_mha_forward_with_qk_norm` function is just a copy of the TE MHA's forward function (source code at + https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) with the addition + of several lines of code for the QK normalization operations. 
+ """ + self = te_mha_module + + # First, we add the QK norm layers (RMSNorm class) to the TE's MHA module in advance for our custom forward function. + if is_self_attn: + common_kwargs = dict( + eps=norm_eps, + device=self.layernorm_qkv.layer_norm_weight.device, + sequence_parallel=self.layernorm_qkv.sequence_parallel, + params_dtype=self.layernorm_qkv.layer_norm_weight.dtype, + zero_centered_gamma=self.layernorm_qkv.zero_centered_gamma, + ) + else: + common_kwargs = dict( + eps=norm_eps, + device=self.layernorm_query.query_weight.device, + sequence_parallel=self.layernorm_query.sequence_parallel, + params_dtype=self.layernorm_query.query_weight.dtype, + zero_centered_gamma=self.layernorm_query.zero_centered_gamma, + ) + if parallel_state.model_parallel_is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + self.q_norm = AllReduceBWDRMSNormTE( + self.hidden_size_per_attention_head, process_group=tp_group, **common_kwargs + ) + self.k_norm = AllReduceBWDRMSNormTE( + self.hidden_size_per_attention_head, process_group=tp_group, **common_kwargs + ) + else: + self.q_norm = RMSNormTE(self.hidden_size_per_attention_head, **common_kwargs) + self.k_norm = RMSNormTE(self.hidden_size_per_attention_head, **common_kwargs) + + # Second, we define the custom forward function for the TE's MHA module, with the QK normalization operations. + def te_mha_forward_with_qk_norm( + hidden_states: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + encoder_output: Optional[torch.Tensor] = None, + attn_mask_type: Optional[str] = None, + window_size: Optional[Tuple[int, int]] = None, + is_first_microbatch: Optional[bool] = None, + checkpoint_core_attention: bool = False, + inference_params: Optional[Any] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + fast_zero_fill: bool = True, + ) -> Tuple[Union[torch.Tensor, None], ...]: + """ + Forward propagation for MultiheadAttention layer. + + """ + # hidden_states: [sq, b, h] + + if attn_mask_type is None: + attn_mask_type = self.attn_mask_type + if window_size is None: + window_size = self.window_size + window_size = check_set_window_size(attn_mask_type, window_size) + + if "padding" in attn_mask_type and attention_mask is not None: + for mask in attention_mask: + assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" + + assert ( + core_attention_bias_type in AttnBiasTypes + ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" 
+ + # ================================================= + # Pre-allocate memory for key-values for inference + # ================================================= + + if inference_params and self.layer_number is not None: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype) + inference_value_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + else: + ( + inference_key_memory, + inference_value_memory, + ) = inference_params.key_value_memory_dict[self.layer_number] + + # ====================== + # Query, Key, and Value + # ====================== + + # fp8_mha = FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + # fp8_kwargs = {"fp8_output": fp8_mha and rotary_pos_emb is None} + fp8_kwargs = {} + + layernorm_output = None + if self.attention_type == "self": + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] + layernorm_qkv_outputs = self.layernorm_qkv( + hidden_states, is_first_microbatch=is_first_microbatch, **fp8_kwargs + ) + mixed_x_layer = layernorm_qkv_outputs + + num_queries_per_key_value = self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition + # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + (num_queries_per_key_value + 2), + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ) + # split along third last dimension + split_dim = -3 + + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, (np/ng + 2), ng, hn] + # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) + ) + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) + elif self.attention_type == "cross": + # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] + mixed_kv_layer = self.key_value(encoder_output, is_first_microbatch=is_first_microbatch, **fp8_kwargs) + + # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + 2 * self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ) + # split along second last dimension + split_dim = -2 + + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # mixed_kv_layer --> 2 [sk, b, ng, hn] + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, + ) + key_layer, value_layer = ( + x.reshape( + x.size(0), + x.size(1), + -1, + self.hidden_size_per_attention_head, + ) + for x in (key_layer, value_layer) + ) + + # Attention head [sq, b, h] --> [sq, b, hp] + layernorm_query_outputs = self.layernorm_query( + hidden_states, is_first_microbatch=is_first_microbatch, **fp8_kwargs + ) + query_layer = layernorm_query_outputs + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 
self.hidden_size_per_attention_head, + ) + query_layer = query_layer.view(*new_tensor_shape) + + # ====================================================== + # Apply QK normalization (RMSNorm) + # ====================================================== + + # Must use torch.reshape to flatten the tensor, otherwise an error will be triggered in TE's RMSNorm module. + query_layer = self.q_norm(query_layer.reshape(-1, self.hidden_size_per_attention_head)).view(query_layer.shape) + key_layer = self.k_norm(key_layer.reshape(-1, self.hidden_size_per_attention_head)).view(key_layer.shape) + + # ====================================================== + # Apply relative positional encoding (rotary embedding) + # ====================================================== + + if rotary_pos_emb is not None: + assert not isinstance(query_layer, Float8Tensor) and not isinstance( + key_layer, Float8Tensor + ), "RoPE is not supported for Float8Tensors!" + # duplicate the pos_emb for self attention + if not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + q_pos_emb, k_pos_emb = rotary_pos_emb + + # adjust key and value for inference + if inference_params is not None: + if self.qkv_format == "sbhd": + sequence_length = key_layer.size(0) + elif self.qkv_format == "bshd": + sequence_length = key_layer.size(1) + else: + raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + sequence_length + + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + + # =========================== + # Core attention computation + # =========================== + context_layer = self.core_attention( + query_layer, + key_layer, + value_layer, + qkv_format=self.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attention_mask=attention_mask, + attn_mask_type=attn_mask_type, + window_size=window_size, + checkpoint_core_attention=checkpoint_core_attention, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + alibi_slopes=alibi_slopes, + fast_zero_fill=fast_zero_fill, + inference_params=inference_params, + ) + + # =================== + # Output. [sq, b, h] + # =================== + + projection_output = self.proj( + context_layer, + is_first_microbatch=is_first_microbatch, + ) + + if self.return_bias: + attention_output, attention_bias = projection_output + else: + attention_output, attention_bias = projection_output, None + + outputs = (attention_output,) + if self.return_bias: + outputs += (attention_bias,) + if self.input_layernorm and self.return_layernorm_output: + outputs += (layernorm_output,) + return outputs if len(outputs) > 1 else outputs[0] + + # Finally, we replace the forward method of given TE's MHA module with our custom forward function. + self.forward = te_mha_forward_with_qk_norm + + +def create_group_causal_attn_mask( + num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal" +) -> torch.Tensor: + """ + Creates a group-based attention mask for scaled dot-product attention with two modes: + 'causal' and 'group_diagonal'. 
+ + Parameters: + - num_temporal_groups (int): The number of temporal groups (e.g., frames in a video sequence). + - num_query_per_group (int): The number of query tokens per temporal group. (e.g., latent tokens in a frame, H x W). + - num_key_per_group (int): The number of key tokens per temporal group. (e.g., action tokens per frame). + - mode (str): The mode of the attention mask. Options are: + - 'causal': Query tokens can attend to key tokens from the same or previous temporal groups. + - 'group_diagonal': Query tokens can attend only to key tokens from the same temporal group. + + Returns: + - attn_mask (torch.Tensor): A boolean tensor of shape (L, S), where: + - L = num_temporal_groups * num_query_per_group (total number of query tokens) + - S = num_temporal_groups * num_key_per_group (total number of key tokens) + The mask indicates where attention is allowed (True) and disallowed (False). + + Example: + Input: + num_temporal_groups = 3 + num_query_per_group = 4 + num_key_per_group = 2 + Output: + Causal Mask Shape: torch.Size([12, 6]) + Group Diagonal Mask Shape: torch.Size([12, 6]) + if mode='causal': + tensor([[ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True]]) + + if mode='group_diagonal': + tensor([[ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + [False, False, False, False, True, True]]) + + """ + assert mode in ["causal", "group_diagonal"], f"Mode {mode} must be 'causal' or 'group_diagonal'" + + # Total number of query and key tokens + total_num_query_tokens = num_temporal_groups * num_query_per_group # Total number of query tokens (L) + total_num_key_tokens = num_temporal_groups * num_key_per_group # Total number of key tokens (S) + + # Generate time indices for query and key tokens (shape: [L] and [S]) + query_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_query_per_group) # Shape: [L] + key_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group) # Shape: [S] + + # Expand dimensions to compute outer comparison + query_time_indices = query_time_indices.unsqueeze(1) # Shape: [L, 1] + key_time_indices = key_time_indices.unsqueeze(0) # Shape: [1, S] + + if mode == "causal": + # Causal Mode: Query can attend to keys where key_time <= query_time + attn_mask = query_time_indices >= key_time_indices # Shape: [L, S] + elif mode == "group_diagonal": + # Group Diagonal Mode: Query can attend only to keys where key_time == query_time + attn_mask = query_time_indices == key_time_indices # Shape: [L, S] + + assert attn_mask.shape == (total_num_query_tokens, total_num_key_tokens), "Attention mask shape mismatch" + return attn_mask diff --git 
a/cosmos_predict1/autoregressive/training/networks/transformer.py b/cosmos_predict1/autoregressive/training/networks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a69228fe6c46f019d0abf2ecb9371c4bd5f57 --- /dev/null +++ b/cosmos_predict1/autoregressive/training/networks/transformer.py @@ -0,0 +1,1295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +import transformer_engine as te +from megatron.core import InferenceParams, ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region +from torch.distributed import ProcessGroup +from torch.distributed import _functional_collectives as funcol +from torch.distributed import broadcast, get_process_group_ranks +from torch.nn.modules.module import _IncompatibleKeys +from transformer_engine.pytorch.module.linear import Linear as LinearTE +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE + +from cosmos_predict1.utils import log + +_ACTION_DIM = 8 +from cosmos_predict1.autoregressive.modules.embedding import ( + RotaryPositionEmbeddingPytorch, + RotaryPositionEmbeddingPytorchV2, + RotaryPositionEmbeddingTE, + SinCosPosEmbAxisTE, + get_pos_emb_on_this_cp_rank, + get_pos_emb_on_this_sptp_rank, +) +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, TrainingVocabParallelEmbedding +from cosmos_predict1.autoregressive.modules.mlp import TrainingMLP, compute_llama3_ffn_hidden_dim +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.training.modules.attention import ( + GQA, + create_group_causal_attn_mask, + enable_different_context_dim_in_te_ca, + enable_qk_normalization_in_te_mha, +) +from cosmos_predict1.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore +from cosmos_predict1.autoregressive.utils.misc import maybe_convert_to_namespace +from cosmos_predict1.autoregressive.utils.parallel import ( + AllReduceBWDRMSNormTE, + allreduce_layernorm_grads, + sync_1d_parameters, +) + +_MLP_HIDDEN_DIM_DIVISOR = ( + 4 # hidden dim of the action embedding layer is action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR +) + +_T5_NUM_TOKENS = 512 + + +class TransformerBlock(nn.Module): + """ + A single transformer block consisting of an attention layer and a feed-forward layer. + """ + + def __init__(self, layer_id: int, model_parallel: Optional[ModelParallelConfig] = None, args=None): + """ + Initializes the TransformerBlock module. + + Args: + layer_id: The ID of the transformer block. + args: The model arguments containing hyperparameters. 
+ """ + super().__init__() + args = maybe_convert_to_namespace(args) + attention_args = { + "n_heads": args["n_heads"], + "n_kv_heads": args["n_kv_heads"], + "dim": args["dim"], + "context_dim": None, + "max_batch_size": args["max_batch_size"], + "max_seq_len": args["max_seq_len"], + "inference": args["inference"], + "flash_attn": args["flash_attn"], + "use_qk_normalization": args["use_qk_normalization"], + "attention_dropout": getattr(args, "attention_dropout", 0.0), + "set_parallel_mode": args["set_parallel_mode"], + "model_parallel": model_parallel, + "attention_tp": args["attention_tp"], + "causal_mask": args["causal_mask"], + "head_dim": args["head_dim"], + "fuse_qkv": getattr(args, "fuse_qkv", False), + "precision": getattr(args, "precision", "bfloat16"), + "attention_type": getattr(args, "attention_type", "self"), + } + self.attention = GQA(**attention_args) + + self.has_cross_attention = False + self.cross_attention, self.cross_attention_norm = None, None + + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + cross_attention_args = attention_args.copy() + cross_attention_args.update( + {"context_dim": args["context_dim"], "fuse_qkv": False, "attention_type": "cross"} + ) + self.cross_attention = GQA(**cross_attention_args) + self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + self.feed_forward = TrainingMLP( + dim=args["dim"], + hidden_dim=( + compute_llama3_ffn_hidden_dim( + dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"] + ) + if args["ffn_hidden_size"] is None + else args["ffn_hidden_size"] + ), + hidden_dropout=getattr(args, "hidden_dropout", 0.0), + set_parallel_mode=args["set_parallel_mode"], + model_parallel=model_parallel, + inference=args["inference"], + ) + self.layer_id = layer_id + self.num_layers = args["n_layers"] + self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + # If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + # total number of transformer blocks. Default is `True` (following the TorchTitan implementation of Llama3). + if getattr(args, "depth_init", True): + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbeddingPytorch, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the TransformerBlock module. + + Args: + x: The input tensor. + input_pos: The position of the current sequence. Used in inference (with KV cache) only. + freqs_cis: The precomputed frequency values for rotary position embeddings. + mask: The attention mask tensor. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + The output tensor after applying the transformer block. 
+ """ + # Apply attention and residual connection + h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask) + + # If insert cross-attention, apply CA and residual connection + if self.has_cross_attention: + h = h + self.cross_attention( + self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context + ) + + # Apply feed-forward network and residual connection + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + """ + Initializes the weights of the transformer block. + """ + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + if self.has_cross_attention: + self.cross_attention_norm.reset_parameters() + self.cross_attention.init_weights(self.weight_init_std) + # zero-init the final output layer of cross-attention + # nn.init.zeros_(self.cross_attention.wo.weight) + + +class TransformerBlockTE(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. + + Args: + layer_id (int): The ID of the transformer block. + args: The model arguments containing hyperparameters. + """ + + def __init__( + self, + layer_id: int, + args, + tp_group: Optional[ProcessGroup] = None, + set_parallel_mode: bool = False, + attn_input_format: str = "bshd", + ): + attention_args = { + "hidden_size": args["dim"], + "ffn_hidden_size": ( + compute_llama3_ffn_hidden_dim( + dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"] + ) + if args["ffn_hidden_size"] is None + else args["ffn_hidden_size"] + ), + "num_attention_heads": args["n_heads"], + "bias": False, + "layernorm_epsilon": args["norm_eps"], + "hidden_dropout": getattr(args, "hidden_dropout", 0.0), + "attention_dropout": getattr(args, "attention_dropout", 0.0), + "normalization": "RMSNorm", + "activation": "swiglu", + "attn_input_format": attn_input_format, + "num_gqa_groups": args["n_kv_heads"], + "fuse_wgrad_accumulation": False, + "fuse_qkv_params": False, + "tp_group": tp_group, + "sequence_parallel": args["sequence_parallel"], + "set_parallel_mode": set_parallel_mode, + "layer_number": layer_id + 1, + "self_attn_mask_type": "causal" if args["causal_mask"] else "no_mask", + "kv_channels": args["head_dim"], # If None, te.pytorch.TransformerLayer defaults it to dim // n_heads + "layer_type": "encoder", + } + self.has_cross_attention = False + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + attention_args["layer_type"] = "decoder" + super().__init__(**attention_args) + if args["use_qk_normalization"]: + # Add QK normalization layers and replace the forward function of original Multi-Head Attention module with + # our custom one to add QK normalization operations. + enable_qk_normalization_in_te_mha(self.self_attention, norm_eps=args["norm_eps"], is_self_attn=True) + + if self.has_cross_attention: + enable_qk_normalization_in_te_mha(self.inter_attention, norm_eps=args["norm_eps"], is_self_attn=False) + + if self.has_cross_attention: + enable_different_context_dim_in_te_ca( + self.inter_attention, context_dim=args["context_dim"], args=attention_args + ) + + self.layer_id = layer_id + self.num_layers = args["n_layers"] + # If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + # total number of transformer blocks. 
Default is `True` (following the TorchTitan implementation of Llama3). + if getattr(args, "depth_init", True): + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + self.args = args + self.inference = args["inference"] + + def set_inference_flag(self, flag: bool): + """ + Set the inference flag for the transformer layers. + """ + self.inference = flag + + def forward( + self, + x: torch.Tensor, + rotary_pos_emb: torch.Tensor, + mask: Optional[torch.Tensor], + inference_params: Optional[InferenceParams] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Custom forward to make sure we only pass relevant arguments to the + forward pass of the `TransformerLayer`. + + Args: + x (torch.Tensor): The input tensor. + mask (Optional[torch.Tensor]): The attention mask tensor. + inference_params (Optional[InferenceParams]): Inference parameters used for caching key-value pairs in the TE backend. + It is not applicable for the PyTorch backend and should be set to None in that case. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + torch.Tensor: The output tensor after applying the transformer block + """ + + inference_params = None if not self.inference else inference_params + output = super().forward( + x, + attention_mask=mask, + rotary_pos_emb=rotary_pos_emb.to(x.device), + inference_params=inference_params, + encoder_output=context, + enc_dec_attn_mask=context_mask, + ) + return output + + def init_weights(self): + """ + Initializes the weights of the transformer block. + """ + # Self Attention + attn_layer = self.self_attention.layernorm_qkv + for linear_weight in [attn_layer.query_weight, attn_layer.key_weight, attn_layer.value_weight]: + nn.init.trunc_normal_(linear_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.self_attention.proj.weight, mean=0.0, std=self.weight_init_std) + + # Cross Attention + if self.has_cross_attention: + nn.init.trunc_normal_(self.inter_attention.layernorm_query.query_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.inter_attention.key_value.key_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.inter_attention.key_value.value_weight, mean=0.0, std=0.02) + # zero-init the final output layer of cross-attention + if self.args["zero_init_cross_attn_proj"]: + nn.init.zeros_(self.inter_attention.proj.weight) + else: + nn.init.trunc_normal_(self.inter_attention.proj.weight, mean=0.0, std=self.weight_init_std) + + # RMS Normalization + for norm_weight in (self.layernorm_mlp.layer_norm_weight, self.self_attention.layernorm_qkv.layer_norm_weight): + torch.nn.init.ones_(norm_weight) + + # In the case of QK Normalization, we also reset the parameters of the QK normalization layers. + if self.args["use_qk_normalization"]: + for norm_weight in [self.self_attention.q_norm.weight, self.self_attention.k_norm.weight]: + torch.nn.init.ones_(norm_weight) + + # MLP + for linear_weight in (self.layernorm_mlp.fc1_weight, self.layernorm_mlp.fc2_weight): + nn.init.trunc_normal_(linear_weight, mean=0.0, std=self.weight_init_std) + # The fc1_weight is a fused weight of w1 and w2 in the MLP of the PyTorch backend, where w1 is initialized with + # a different std (0.02 by TorchTitan). So we re-initialize the w1 part of the fused weight below. 
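+ # (Assuming the fused layout implied above, i.e., the w1/gate rows come first: fc1_weight stacks the
+ # two input projections along dim 0, so rows [:split_point] are re-initialized with std=0.02 while
+ # the remaining rows keep the depth-scaled std applied in the loop above.)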
+ split_point = self.layernorm_mlp.fc1_weight.shape[0] // 2 + nn.init.trunc_normal_(self.layernorm_mlp.fc1_weight[:split_point], mean=0.0, std=0.02) + + +class Transformer(nn.Module): + """ + The Transformer network consisting of transformer blocks. + """ + + def __init__(self, params, model_parallel=None, tokenizer_config=None, init_weights: bool = True): + """ + Initializes the Transformer module. + + Args: + params: The model parameters containing hyperparameters. + model_parallel: The model parallel configuration. + tokenizer_config: The model tokenizer configuration. + init_weights (bool): Whether to initialize the weights of the transformer following + TorchTitan's Llama3 initialization scheme. + """ + super().__init__() + # Check if self.params is an OmegaConf DictConfig instance + self.params = maybe_convert_to_namespace(params) + self.vocab_size = params["vocab_size"] + self.n_layers = params["n_layers"] + self.precision = getattr(torch, params["precision"]) + self.inference = params["inference"] + self.backend = params["backend"] + self.tokenizer_config = tokenizer_config + self.model_parallel = model_parallel + self.num_video_frames = params["num_video_frames"] + + self.token_emb_dropout = nn.Dropout(getattr(params, "embedding_dropout", 0.0)) + + tp_group = self._get_tp_group() + + # Sequence parallelism requires the first dimension to be the sequence dimension. When sequence parallelism + # is enabled, we transpose the first two dimensions of the input tensor, and specify the format as "sbhd", + # (sequence, batch, head, dim). Otherwise, the input format is "bshd" (batch, sequence, head, dim). + self.attn_input_format = "bshd" if not params["sequence_parallel"] else "sbhd" + + # Token embeddings + self.tok_embeddings = self._create_token_embeddings(self.model_parallel) + self.rope_config = self._create_rope_config() + + if self.backend == "pytorch": + self._initialize_pytorch_backend(model_parallel) + elif self.backend == "transformer_engine": + self._initialize_transformer_engine_backend(tp_group) + else: + raise ValueError(f"Unknown backend: {self.backend}") + + self.output = self._create_output_projection(model_parallel) + + # Action conditioning + self.use_action_condition = getattr(params, "use_action_condition", False) + if self.use_action_condition: + self.action_dim = getattr( + params, "action_dim", _ACTION_DIM + ) # e.g., [Δx, Δy, Δz, rx, ry, rz, gripper_open, zero_pad] + self.action_embedding_dim = self.params["action_embedding_dim"] # 1024 + self.action_embedding_mode = getattr(params, "action_embedding_mode", "mlp") # Default to mlp mode + self.group_causal_mask_mode = getattr( + params, "group_causal_mask_mode", None + ) # Default to None, 'causal' or 'group_diagonal' + self.action_embedding_layers = self._create_action_projection() + + if params["sequence_parallel"]: + if model_parallel is None: + setattr(params, "sequence_parallel", False) + log.critical("model_parallel is None. 
Disabling sequence parallelism.") + self.sequence_parallel_enabled = False + else: + assert self.backend == "transformer_engine", f"Invalid backend: {self.backend} for sequence parallelism" + assert ( + params["tensor_model_parallel_size"] > 1 + ), f"Invalid tensor_model_parallel_size: {params['tensor_model_parallel_size']}" + self.sequence_parallel_enabled = True + else: + self.sequence_parallel_enabled = False + + if init_weights: + self.init_weights() + + # Set default value for peft_last_n_layers and peft_every_n_layers + self.peft_last_n_layers = getattr(params, "peft_last_n_layers", 0) + self.peft_every_n_layers = getattr(params, "peft_every_n_layers", 0) + if self.peft_last_n_layers > 0 or self.peft_every_n_layers > 0: + self._setup_peft() + + # Freeze network parameters for finetuning w/ cross-attention + self.has_cross_attention = getattr(params, "insert_cross_attn", False) + if self.has_cross_attention: + self.ca_every_k_layers = getattr(params, "insert_cross_attn_every_k_layers", 1) + self.finetune_layers_with_cross_attn = getattr(params, "finetune_layers_with_cross_attn", False) + self.finetune_layers_without_cross_attn = getattr(params, "finetune_layers_without_cross_attn", False) + self._setup_cross_attn_ft() + + if self.params["apply_abs_pos_emb"]: + self.pos_emb_config = self._create_abs_pos_emb_config() + self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb() + if self.attn_input_format == "sbhd": + self.abs_pos_emb = self.abs_pos_emb.transpose(0, 1).contiguous() + self._broadcast_pos_emb(self.abs_pos_emb, tp_group) + + def _initialize_pytorch_backend(self, model_parallel): + self.layers = nn.ModuleList( + [ + TransformerBlock(layer_id, model_parallel, self.params).to(self.precision) + for layer_id in range(self.n_layers) + ] + ) + self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to( + self.precision + ) + pytorch_rope_version = getattr(self.params, "pytorch_rope_version", "v2") + if pytorch_rope_version == "v1": + self.rope = RotaryPositionEmbeddingPytorch(**self.rope_config) + elif pytorch_rope_version == "v2": + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + self.rope = RotaryPositionEmbeddingPytorchV2( + seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config + ) + self._broadcast_pos_emb(self.rope.cos_cached, tp_group=self._get_tp_group()) + self._broadcast_pos_emb(self.rope.sin_cached, tp_group=self._get_tp_group()) + else: + raise ValueError(f"Unknown pytorch_rope_version: {pytorch_rope_version}") + + self.causal_mask = torch.tril( + torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool) + ).cuda() + + def _initialize_transformer_engine_backend(self, tp_group): + self.layers = self._create_transformer_layers(tp_group) + if self.params["sequence_parallel"]: + tp_group = parallel_state.get_tensor_model_parallel_group() + self.norm = AllReduceBWDRMSNormTE( + self.params["dim"], + process_group=tp_group, + eps=self.params["norm_eps"], + sequence_parallel=True, + ).to(self.precision) + else: + self.norm = RMSNormTE(self.params["dim"], eps=self.params["norm_eps"]).to(self.precision) + self.rope, self.rotary_pos_emb = self._initialize_rope() + self._broadcast_pos_emb(self.rotary_pos_emb, tp_group) + + def _create_rope_config(self) -> Dict: + shape_map = { + "3D": self.params["video_latent_shape"], + "2D": self.params["image_latent_shape"], + "1D": None, + } + latent_shape = 
shape_map.get(self.params["rope_dim"], None) + head_dim = self.params["head_dim"] + if head_dim is None: + head_dim = self.params["dim"] // self.params["n_heads"] + return { + "dim": head_dim, + "max_position_embeddings": self.params["max_seq_len"], + "original_max_position_embeddings": self.params["original_seq_len"], + "rope_theta": self.params["rope_theta"], + "apply_yarn": self.params["apply_yarn"], + "scale": self.params["yarn_scale"], + "beta_fast": self.params["yarn_beta_fast"], + "beta_slow": self.params["yarn_beta_slow"], + "rope_dim": self.params["rope_dim"], + "latent_shape": latent_shape, + "original_latent_shape": self.params["original_latent_shape"], + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_abs_pos_emb_config(self): + shape_map = { + "3D": self.params["video_latent_shape"], + "2D": self.params["image_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + return { + "dim": self.params["dim"], + "latent_shape": latent_shape, + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_token_embeddings(self, model_parallel=None, vocab_size: int = None): + """ + Create token embeddings. + + Args: + model_parallel: The model parallel configuration. + + Returns: + nn.Module: Token embeddings module. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + if tp_size > 1: + # For inference in the PyTorch backend, we use PyTorch's allreduce (tracable) in the forward pass to enable torch.compile. + use_inference_allreduce = self.inference and self.params["backend"] == "pytorch" + emb = TrainingVocabParallelEmbedding( + vocab_size, + self.params["dim"], + init_method=lambda x: x, + config=model_parallel, + sequence_parallel=self.params["sequence_parallel"], + batch_first=not self.params["sequence_parallel"], + use_inference_allreduce=use_inference_allreduce, + ).to(self.precision) + return emb + else: + return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision) + + def _create_action_projection(self): + """ + Create the action projection layer. + + Returns: + nn.Module: Action projection layer. + """ + assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}" + + # This method is not working well. (option 1. default) exp102e + hidden_dim = self.action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR + action_embedding_layers = nn.Sequential( + nn.Linear(self.action_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.action_embedding_dim), + ) + + return action_embedding_layers + + def _get_tp_group( + self, + ): + """ + Get tensor parallel process group if applicable. + + Returns: + torch.distributed.ProcessGroup or None: Tensor parallel process group if tensor parallelism is enabled, else None. + """ + if self.params["tensor_model_parallel_size"] > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + log.info(f"Using tensor model parallel group: {tp_group}") + return tp_group + + return None + + def _create_transformer_layers(self, tp_group): + """ + Create the transformer layers. + + Args: + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + + Returns: + nn.ModuleList: List of transformer layers. 
+ """ + return nn.ModuleList( + [ + TransformerBlockTE( + layer_id, + self.params, + tp_group, + set_parallel_mode=self.params["set_parallel_mode"], + attn_input_format=self.attn_input_format, + ).to(self.precision) + for layer_id in range(self.params["n_layers"]) + ] + ) + + def _create_output_projection(self, model_parallel=None, vocab_size: int = None): + """ + Create the output projection layer. + + Args: + model_parallel: The model parallel configuration. + vocab_size (int): Vocabulary size (to override the default vocab size). + Returns: + LinearTE: Output projection layer. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + if self.params["tensor_model_parallel_size"] > 1: + if self.params["backend"] == "pytorch" and self.inference: + tp_size = self.params["tensor_model_parallel_size"] + layer = nn.Linear(self.params["dim"], vocab_size // tp_size, bias=False).to(self.precision) + return layer + else: + layer = ColumnParallelLinear( + self.params["dim"], + vocab_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + config=model_parallel, + ).to(self.precision) + return layer + else: + # No Tensor Parallelism + if self.params["backend"] == "pytorch": + return nn.Linear(self.params["dim"], vocab_size, bias=False).to(self.precision) + elif self.params["backend"] == "transformer_engine": + return LinearTE(self.params["dim"], vocab_size, bias=False).to(self.precision) + else: + raise ValueError("Unknown backend: " + self.params["backend"]) + + def _initialize_rope( + self, + ): + """ + Initialize the rotary position embedding. + + Returns: + tuple: (RotaryPositionEmbeddingTE, torch.Tensor) The RoPE module and the rotary position embeddings. + """ + rope = RotaryPositionEmbeddingTE(**self.rope_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + rotary_pos_emb = rope.forward(seq_len=self.params["max_seq_len"], training_type=training_type) + return rope, rotary_pos_emb + + def _initialize_abs_pos_emb(self): + pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + abs_pos_emb = pos_emb.forward(training_type=training_type) + return pos_emb, abs_pos_emb + + def _broadcast_pos_emb(self, pos_emb, tp_group): + """ + Broadcast the position embeddings across the tensor parallel group. + + Args: + pos_emb (torch.Tensor): Position embeddings to broadcast. + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + """ + if self.params["tensor_model_parallel_size"] > 1: + broadcast(pos_emb, min(get_process_group_ranks(tp_group)), group=tp_group) + + def _setup_peft(self): + """ + Set up Parameter Efficient Fine-Tuning (PEFT) by selectively freezing and unfreezing layers. + + This method configures the model for fine-tuning by: + 1. Freezing all parameters in the model. + 2. Unfreezing the embedding, normalization and output layers. + 3. Unfreezing the first and last (peft_last_n_layers - 1) transformer layers if peft_last_n_layers is set, + or unfreezing every n layers (flamingo style) if peft_every_n_layers is set. + """ + # Ensure only one of peft_last_n_layers and peft_every_n_layers is set + assert ( + self.peft_last_n_layers == 0 or self.peft_every_n_layers == 0 + ), "Only one of peft_last_n_layers and peft_every_n_layers can be set." 
+ + # First, freeze all parameters + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze embedding, normalization and output layers + for param in self.tok_embeddings.parameters(): + param.requires_grad = True + for param in self.norm.parameters(): + param.requires_grad = True + for param in self.output.parameters(): + param.requires_grad = True + + # PEFT last n layers + if self.peft_last_n_layers > 0: + # Ensure peft_last_n_layers is at least 2 + assert self.peft_last_n_layers >= 2, "peft_last_n_layers must be at least 2" + + # Unfreeze specific transformer layers + total_layers = len(self.layers) + for i, layer in enumerate(self.layers): + if i == 0 or i >= total_layers - self.peft_last_n_layers + 1: + # Unfreeze the first layer and the last (peft_last_n_layers - 1) layers + for param in layer.parameters(): + param.requires_grad = True + + log.info( + f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, " + f"first transformer layer, last {self.peft_last_n_layers - 1} transformer layers." + ) + # PEFT every n layers (flamingo style, e.g. every 4 layers = layer 0,1,2,4,5,6,... frozen, layer 3,7,11,... is trainable) + else: + trainable_layers = [] + for i, layer in enumerate(self.layers, 1): + if i % self.peft_every_n_layers == 0: + for param in layer.parameters(): + param.requires_grad = True + trainable_layers.append(i - 1) + + log.info( + f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, " + f"every {self.peft_every_n_layers} transformer layers (layer idx {trainable_layers}; total {len(trainable_layers)} layers)." + ) + + def _setup_cross_attn_ft(self): + """ + Set up Cross Attention Fine-Tuning by selectively freezing and unfreezing layers. + + This method configures the model for fine-tuning by: + 1. Freezing all parameters in the model. + 2. Unfreezing the embedding, normalization and output layers. + 3. Unfreezing all the added cross-attention layers. + 4. If `finetune_layers_with_cross_attn` is True, unfreeze the transformer layers for layers with cross attention. + 5. If `finetune_layers_without_cross_attn` is True, unfreeze the transformer layers for layers without cross attention. + 6. If 'use_action_condition' is True, unfreeze the action embedding layers. + """ + assert self.has_cross_attention, "Must insert cross-attention layers for finetuning." 
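+ # NOTE: only layers whose index is a multiple of insert_cross_attn_every_k_layers carry a
+ # cross-attention sub-layer; by default just those sub-layers (plus embeddings, the final norm, the
+ # output projection and, if enabled, the action embedding layers) are unfrozen, and the
+ # finetune_layers_with/without_cross_attn flags widen the trainable set to full transformer layers.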
+ finetune_layer_num = 0 + + # First, freeze all parameters + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze embedding, normalization and output layers + for param in self.tok_embeddings.parameters(): + param.requires_grad = True + for param in self.norm.parameters(): + param.requires_grad = True + for param in self.output.parameters(): + param.requires_grad = True + + # Unfreeze all the added cross-attention layers + total_layers = len(self.layers) + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers == 0: + if self.params["backend"] == "pytorch": + for param in layer.cross_attention.parameters(): + param.requires_grad = True + elif self.params["backend"] == "transformer_engine": + for param in layer.inter_attention.parameters(): + param.requires_grad = True + else: + raise ValueError("Unknown backend: " + self.params["backend"]) + + # Unfreeze the transformer layers for layers with cross attention + if self.finetune_layers_with_cross_attn: + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers == 0: + for param in layer.parameters(): + param.requires_grad = True + finetune_layer_num += 1 + + # Unfreeze the transformer layers for layers without cross attention + if self.finetune_layers_without_cross_attn: + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers != 0: + for param in layer.parameters(): + param.requires_grad = True + finetune_layer_num += 1 + + # Unfreeze the action embedding layers + if self.use_action_condition: + for param in self.action_embedding_layers.parameters(): + param.requires_grad = True + + log.info( + f"cross attention finetune setup complete. Trainable components: cross-attention layer, " + f"fully trainable transformer layer number is {finetune_layer_num}." + ) + + def enable_context_parallel(self, cp_group: ProcessGroup): + """ + Enable context parallelism for the transformer model. + + This method sets up context parallelism by configuring the context parallel group + and updating each transformer layer to support context parallelism. + + Args: + cp_group (ProcessGroup): The process group for context parallelism. + + Notes: + - Updates the model's context parallel group and size. + - Configures each transformer layer for context parallelism. + - Enables context parallelism for the rotary position embedding if using the transformer engine backend. + """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + for layer_idx, layer in enumerate(self.layers): + if isinstance(layer, TransformerBlockTE): + layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + elif hasattr(layer, "module") and isinstance(layer.module, TransformerBlockTE): + layer.module.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + else: + log.warning(f"Layer {layer_idx} does not support context parallelism") + + def set_inference_flag(self, flag: bool): + """ + Set the inference flag for the transformer layers. 
+ """ + log.info(f"Setting inference flag to {flag}") + self.inference = flag + if self.inference: + self.eval() + if self.params["backend"] == "pytorch": + for layer in self.layers: + layer.attention.set_inference_flag(flag) + elif self.params["backend"] == "transformer_engine": + for layer in self.layers: + layer.set_inference_flag(flag) + + self._maybe_change_sequence_parallel_status(enable=False) + + def _maybe_change_sequence_parallel_status(self, enable: bool): + """ + Change the sequence parallel status of the transformer layers. + """ + if enable and not self.sequence_parallel_enabled: + for name, module in self.named_modules(): + if hasattr(module, "sequence_parallel"): + assert isinstance( + module.sequence_parallel, bool + ), f"Invalid type of {name}: {type(module.sequence_parallel)}" + setattr(module, "sequence_parallel", True) + self.sequence_parallel_enabled = True + elif not enable and self.sequence_parallel_enabled: + for name, module in self.named_modules(): + if hasattr(module, "sequence_parallel"): + assert isinstance( + module.sequence_parallel, bool + ), f"Invalid type of {name}: {type(module.sequence_parallel)}" + setattr(module, "sequence_parallel", False) + self.sequence_parallel_enabled = False + + def forward( + self, + tokens: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + token_embeddings: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + total_seq_len: Optional[int] = None, + return_hidden_states: bool = False, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the Transformer module. + + Args: + tokens (torch.Tensor, optional): The input tensor of token IDs. + input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. PyTorch backend only. + inference_params (InferenceParams, optional): Parameters for inference. + token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + action (Optional[torch.Tensor]): The robot action tensor for conditioning. + total_seq_len (Optional[int]): The total sequence length (before applying context parallelism). + return_hidden_states (bool): Whether to return hidden states. + Returns: + The output tensor after applying the transformer layers. + """ + + # Turn on/off sequence parallelism based on the training status + self._maybe_change_sequence_parallel_status(enable=self.training and self.params["sequence_parallel"]) + + # Token embeddings + assert ( + tokens is None or token_embeddings is None + ), "Either tokens or token_embeddings should be provided, not both." 
+ + if token_embeddings is None: + seq_len = tokens.shape[1] + h = self.token_emb_dropout(self.tok_embeddings(tokens)) + else: + seq_len = token_embeddings.shape[1] + h = self.token_emb_dropout(token_embeddings) + + if mask is None: + # Create attention mask + mask = self._create_attention_mask(input_pos=input_pos) + + # Action embedding + if self.use_action_condition and action is not None: + assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}" + # change action type to bfloat16, of shape [batch_size, action_dim] + action = action.to(torch.bfloat16) + # action_emb shape: [batch_size, action_dim, action_embedding_dim] + action_emb = self.action_embedding_layers(action).unsqueeze(1).repeat(1, self.action_dim, 1) + + # Use action_emb as context + if self.params["concat_action_to_context"]: + context = torch.zeros( + (action_emb.shape[0], _T5_NUM_TOKENS, self.action_embedding_dim), device=h.device, dtype=h.dtype + ) + # context[:, -1, :] = action_emb[:, 0, :] # overwrite the last token with action_emb + context = torch.cat([context, action_emb[:, 0:1, :]], dim=1) + else: + context = action_emb # [batch_size, action_dim, action_embedding_dim] + + # Create context mask + if self.group_causal_mask_mode is not None: + num_temporal_groups = self.num_video_frames - 1 # number of latent frames + num_query_per_group = seq_len // num_temporal_groups # number of latent tokens per frame + num_key_per_group = self.action_dim // num_temporal_groups + context_mask = create_group_causal_attn_mask( + num_temporal_groups=num_temporal_groups, + num_query_per_group=num_query_per_group, + num_key_per_group=num_key_per_group, + mode=self.group_causal_mask_mode, + ) # [L (query), S (key)] + context_mask = context_mask.unsqueeze(0) # [1, L (query), S (key)] + context_mask = context_mask.repeat(context.shape[0], 1, 1) # [batch_size, L (query), S (key)] + context_mask = context_mask.to(context.device) + else: + context_mask = torch.ones( + (context.shape[0], context.shape[1]), device=context.device, dtype=torch.bool + ) # [batch_size, action_dim] + + # Prepare layer arguments + layer_kwargs = self._prepare_layer_kwargs( + total_seq_len=total_seq_len, + input_pos=input_pos, + mask=mask, + inference_params=inference_params, + context=context, + context_mask=context_mask, + ) + + # Apply transformer layers + for layer in self.layers: + if self.params["apply_abs_pos_emb"]: + h = self.apply_abs_pos_emb(h, input_pos=input_pos, total_seq_len=total_seq_len) + h = layer(h, **layer_kwargs) + + # Apply final layer normalization + h = self.norm(h) + if return_hidden_states: + return h + + # Output linear projection + output = self.output(h) + output = self.process_output(output) + return output + + def process_output(self, output: torch.Tensor) -> torch.Tensor: + """ + Adjusts the shape and layout of tensor based on tensor parallelism and attention input format. + + The function performs two operations: + 1. If the tensor model parallelism is enabled (`tensor_model_parallel_size > 1`), it gathers the tensor from + the tensor-parallel regions and reshapes it accordingly. + 2. If the attention input format is `"sbhd"` (Sequence, Batch, Hidden Dimension), it transposes the tensor + to the format `(Batch, Sequence, Hidden Dimension)` for further processing. + + Args: + output [torch.Tensor]: The tensor before modification. + + Returns: + output [torch.Tensor]: The tensor after modification. 
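+
+        Example (illustrative shapes; ``attn_input_format == "sbhd"`` and
+        ``tensor_model_parallel_size == 2`` are assumed):
+            input:  [seq_len, batch_size, last_dim // 2]  (per tensor-parallel rank)
+            output: [batch_size, seq_len, last_dim]       (gathered, then transposed)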
+ + """ + if self.params["tensor_model_parallel_size"] > 1: + if self.params["backend"] == "pytorch" and self.inference: + # Use PyTorch all gather + output = funcol.all_gather_tensor( + output, gather_dim=-1, group=parallel_state.get_tensor_model_parallel_group() + ) + else: + # [*, *, hidden_dim // tp_size] --> [*, *, hidden_dim] + output = gather_from_tensor_model_parallel_region(output) + if self.attn_input_format == "sbhd": + # [seq_len, batch_size, hidden_dim] --> [batch_size, seq_len, hidden_dim] + output = output.transpose(0, 1).contiguous() + return output + + def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """ + Creates an attention mask for the transformer layers. + + Args: + input_pos[torch.Tensor]: The position of input sequence (used for inference only). + + Returns: + Optional[torch.Tensor]: The attention mask, or None for causal mask. + """ + + if self.backend == "pytorch" and self.inference: + assert input_pos is not None, "input_pos must be provided for inference" + mask = self.causal_mask[input_pos] + return mask + else: + return None # None means causal mask + + def _prepare_layer_kwargs( + self, + total_seq_len: Optional[int], + input_pos: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + inference_params: Optional[InferenceParams], + context: Optional[torch.Tensor], + context_mask: Optional[torch.Tensor], + ) -> Dict[str, Any]: + """ + Prepares the keyword arguments for transformer layers. + + Args: + total_seq_len (Optional[int]): The total sequence length (before applying context parallelism). + seq_len (Optional[int]): The length of the input sequence. + input_pos (Optional[torch.Tensor]): The position of the current sequence. + mask (Optional[torch.Tensor]): The attention mask. + inference_params (Optional[InferenceParams]): Parameters for inference. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. 
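+
+        Example (illustrative; keys of the returned dict only):
+            PyTorch backend:           mask, context, context_mask, input_pos, rope
+            TransformerEngine backend: mask, context, context_mask, rotary_pos_emb, inference_params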
+ """ + if context is not None: + context = context.to(self.precision) + + if self.attn_input_format == "sbhd": + context = context.transpose(0, 1).contiguous() + if self.backend == "pytorch": + if isinstance(mask, torch.Tensor) and mask.ndim == 2: + mask = mask[None, None, :, :] + if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2: + context_mask = context_mask[None, None, :, :] + + layer_kwargs = { + "mask": mask, + "context": context, + "context_mask": context_mask, + } + + if self.backend == "pytorch": + layer_kwargs["input_pos"] = input_pos + layer_kwargs["rope"] = self.rope + elif self.backend == "transformer_engine": + rotary_pos_emb = self.rotary_pos_emb + try: + cp_size = parallel_state.get_context_parallel_world_size() + except (AssertionError, RuntimeError): + # Fallback if context parallel group isn't initialized + cp_size = 1 + log.warning("Context parallel group not initialized, falling back to size 1") + else: + cp_size = 1 + if cp_size > 1: + assert input_pos is None, "input_pos must be None for context parallelism" + rotary_pos_emb = rotary_pos_emb[:total_seq_len] + rotary_pos_emb = get_pos_emb_on_this_cp_rank(rotary_pos_emb, 0) + + layer_kwargs["rotary_pos_emb"] = rotary_pos_emb + layer_kwargs["inference_params"] = inference_params + + return layer_kwargs + + def apply_abs_pos_emb( + self, x: torch.Tensor, input_pos: int = None, total_seq_len: Optional[int] = None + ) -> torch.Tensor: + """ + Applies the absolute position embeddings to the input tensor. + """ + abs_pos_emb = self.abs_pos_emb + if total_seq_len is not None: + # Truncate the absolute position embeddings to the total sequence length + abs_pos_emb = ( + abs_pos_emb[:total_seq_len, :, :] + if self.attn_input_format == "sbhd" + else abs_pos_emb[:, :total_seq_len, :] + ) + cp_size = parallel_state.get_context_parallel_world_size() if self.training else 1 + if cp_size > 1: + assert input_pos is None + seq_dim = 0 if self.attn_input_format == "sbhd" else 1 + abs_pos_emb = get_pos_emb_on_this_cp_rank(abs_pos_emb, seq_dim=seq_dim) + if self.attn_input_format == "sbhd": + if self.sequence_parallel_enabled: + # Training + assert input_pos is None, "input_pos must be None when training with sequence parallelism" + abs_pos_emb = get_pos_emb_on_this_sptp_rank(abs_pos_emb, seq_dim=0) + else: + # Inference or Evaluation + abs_pos_emb = abs_pos_emb[input_pos, :, :] if input_pos is not None else abs_pos_emb + else: + abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb + return x + abs_pos_emb + + @torch.no_grad() + def expand_vocab( + self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True + ): + """ + Expands the vocabulary of the model to the new size. + + Args: + new_vocab_size (int): The new vocabulary size. + init_method (str): The initialization method for new embeddings. + Can be "zero" or "gaussian". Default is "gaussian". + multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully + leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377, + source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc) + expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. 
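+
+        Example (illustrative rounding; the numbers are assumed):
+            With multiple_of=64, a requested new_vocab_size of 64010 is rounded up to
+            (64010 // 64 + 1) * 64 = 64064 before the embedding and output layers are resized.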
+ + Returns: + None + """ + + tp_size = self.params["tensor_model_parallel_size"] + if new_vocab_size <= self.vocab_size: + raise ValueError( + f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})" + ) + if new_vocab_size % multiple_of != 0: + log.critical(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.") + new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of + log.critical(f"Rounded vocabulary size to {new_vocab_size}.") + # Resize token embeddings + old_embeddings = self.tok_embeddings + old_embeddings_requires_grad = old_embeddings.weight.requires_grad + tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype} + self.tok_embeddings = self._create_token_embeddings( + model_parallel=self.model_parallel, vocab_size=new_vocab_size + ).to(**tensor_kwargs) + # Initialize new embeddings + if init_method not in ["zero", "gaussian"]: + raise ValueError(f"Unknown initialization method: {init_method}") + # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything + # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings. + if init_method == "zero": + self.tok_embeddings.weight.data[self.vocab_size // tp_size :].zero_() + + # Copy old embeddings + log.info( + f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}" + ) + self.tok_embeddings.weight.data[: self.vocab_size // tp_size] = old_embeddings.weight.data + self.tok_embeddings.weight.requires_grad = old_embeddings_requires_grad + # Resize output layer + old_output = self.output + old_output_requires_grad = old_output.weight.requires_grad + self.output = self._create_output_projection( + self.model_parallel, vocab_size=new_vocab_size if expand_output_layer else None + ) + + # Initialize new output weights + if init_method == "zero": + self.output.weight.data[self.vocab_size // tp_size :].zero_() + elif init_method == "gaussian": + # Follows the parameter initialization in TorchTitan: + # https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py + final_out_std = self.params["dim"] ** -0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + # Copy old output weights + self.output.weight.data[: self.vocab_size // tp_size] = old_output.weight.data + self.output.weight.requires_grad = old_output_requires_grad + + # Update vocab size + self.vocab_size = new_vocab_size + log.critical(f"Expanded vocabulary size to {new_vocab_size}") + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters`` (copied from github.com/pytorch/torchtitan)] + Modules may define ``reset_parameters`` to initialize parameter values. ``reset_parameters`` is meant to only + initialize directly owned parameters/buffers, not those of their child modules, and it can be used to give the + initial values for these tensors. Separately, users may want custom initialization for their modules, different + from that in ``reset_parameters``. For this, we define ``init_weights``. We only call it in the constructor of + this ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers: + layer.init_weights() + if self.backend == "pytorch": + self.norm.reset_parameters() + elif self.backend == "transformer_engine": + nn.init.ones_(self.norm.weight) + else: + raise ValueError(f"Unknown backend: {self.backend}") + final_out_std = self.params["dim"] ** -0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + if self.use_action_condition: + for layer in self.action_embedding_layers: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + nn.init.zeros_(layer.bias) + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). + """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + if strict: + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + missing_keys = actual_missing_keys + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def on_after_backward(self, *args, **kwargs): + """ + All-reduce layernorm grads for tensor/sequence parallelism. + Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py + """ + allreduce_layernorm_grads( + [self], + tensor_model_parallel_size=self.params["tensor_model_parallel_size"], + sequence_parallel=self.params["sequence_parallel"], + ) + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + if self.params["sync_1d_parameters"]: + if self.params["tensor_model_parallel_size"] > 1: + sync_1d_parameters(self, process_group=parallel_state.get_tensor_model_parallel_group()) + if self.params["context_parallel_size"] > 1: + sync_1d_parameters(self, process_group=parallel_state.get_context_parallel_group()) diff --git a/cosmos_predict1/autoregressive/utils/__init__.py b/cosmos_predict1/autoregressive/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/utils/checkpoint.py b/cosmos_predict1/autoregressive/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fb49e9e03173caccc8473c513d70124dc371d38e --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/checkpoint.py @@ -0,0 +1,594 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch + +from cosmos_predict1.utils import log + +# Substrings to ignore when processing state dicts +substrings_to_ignore = [ + "_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling +] + + +def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str: + """ + Identify the backend of the checkpoint (PyTorch or TransformerEngine) + + Args: + state_dict (dict[str, torch.Tensor]): The state dict to check + + Returns: + str: The backend of the checkpoint + """ + for key in state_dict.keys(): + if "self_attention.layernorm_qkv.query_weight" in key: + return "transformer_engine" + elif "attention.wq.weight" in key: + return "pytorch" + raise ValueError("Could not identify the backend of the checkpoint") + + +def get_partial_state_dict( + state_dict: dict[str, torch.Tensor], + prefix: str, +) -> dict[str, torch.Tensor]: + """ + Get a partial state dict with keys starting with the given prefix + """ + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + +def process_state_dict( + state_dict: dict[str, torch.Tensor], + device: str = None, + dtype: torch.dtype = None, + prefix_to_remove: Optional[str] = None, +) -> dict[str, torch.Tensor]: + """ + - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8) + - Move tensors to specified device and dtype if provided + + Args: + state_dict (dict[str, torch.Tensor]): The state dict to process + device (str, optional): The device to move tensors to. Defaults to None. + dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None. + prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None. 
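+
+    Example (illustrative; ``w`` and ``b`` stand in for real tensors):
+        >>> sd = {"model.norm.weight": w, "model.layers.0._extra_state": b}
+        >>> process_state_dict(sd, prefix_to_remove="model.")
+        {'norm.weight': w}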
+ + Returns: + dict[str, torch.Tensor]: The processed state dict + """ + new_state_dict = {} + tensor_kwargs = {} + if device is not None: + tensor_kwargs["device"] = device + if dtype is not None: + tensor_kwargs["dtype"] = dtype + + for key, value in state_dict.items(): + # Check if any of the substrings to ignore are in the key + skip = False + for substr in substrings_to_ignore: + if substr in key: + skip = True + break + if skip: + continue + if len(tensor_kwargs) > 0: + value = value.to(**tensor_kwargs) + if prefix_to_remove is not None and key.startswith(prefix_to_remove): + key = key[len(prefix_to_remove) :] + new_state_dict[key] = value + return new_state_dict + + +def obtain_tensor_parallel_state_dict( + whole_model_state_dict: dict[str, torch.Tensor], + tensor_parallel_size: int, + tensor_parallel_rank: int, + model_config, + target_backend: str = None, +) -> dict[str, torch.Tensor]: + """ + Obtain the tensor parallel state dict shard for the current rank. + + Args: + whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict. + tensor_parallel_size (int): The number of tensor parallel devices. + tensor_parallel_rank (int): The rank of the current tensor parallel device. + model_config: The model configuration. + target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used. + + Returns: + dict[str, torch.Tensor]: The updated state dict shard for the current tensor parallel rank. + """ + new_state_dict_shard = {} + whole_model_state_dict = process_state_dict(whole_model_state_dict) + source_backend = identify_checkpoint_backend(whole_model_state_dict) + if source_backend != "pytorch": + # Convert the checkpoint to PyTorch backend for checkpoint sharding + whole_model_state_dict = maybe_convert_checkpoint_to_backend( + whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend + ) + + n_heads = model_config["n_heads"] + n_kv_heads = model_config["n_kv_heads"] + dim = model_config["dim"] + context_dim = model_config["context_dim"] + for key, value in whole_model_state_dict.items(): + prefix = "model." if key.startswith("model.") else "" # LLM's model prefix + prefix = "transformer." if key.startswith("transformer.") else prefix # VIT's model prefix + key = key.replace(prefix, "") + if key.startswith("layers."): + layer_index = int(key.split("layers.")[1].split(".")[0]) + if layer_index >= model_config["n_layers"]: + log.warning( + f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer." 
+ ) + continue + if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key: + value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank] + value = value.reshape(-1, dim) + elif ".attention.wk.weight" in key or ".attention.wv.weight" in key: + value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank] + value = value.reshape(-1, dim) + elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key: + assert context_dim is not None + value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[ + tensor_parallel_rank + ] + value = value.reshape(-1, context_dim) + elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key: + value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank] + elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key: + value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank] + else: + # Handle non-layer weights + if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key: + value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank] + new_state_dict_shard[prefix + key] = value + + if target_backend is None: + target_backend = source_backend + + new_state_dict_shard = maybe_convert_checkpoint_to_backend( + new_state_dict_shard, + target_backend=target_backend, + model_config=model_config, + is_tensor_parallel_shard=True, + tensor_parallel_size=tensor_parallel_size, + ) + + return new_state_dict_shard + + +def merge_tensor_parallel_state_dicts( + state_dict_shards: list[dict[str, torch.Tensor]], + model_config, + target_backend: str = None, +) -> dict[str, torch.Tensor]: + """ + Merge tensor parallel state dict shards into a whole model state dict. + + Args: + state_dict_shards (List[Dict[str, torch.Tensor]]): The list of state dict shards to merge. + model_config: The model configuration. + target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used. + + Returns: + Dict[str, torch.Tensor]: The merged state dict. + """ + state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards] + tensor_parallel_size = len(state_dict_shards) + source_backend = identify_checkpoint_backend(state_dict_shards[0]) + if source_backend != "pytorch": + log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.") + state_dict_shards = [ + maybe_convert_checkpoint_to_backend( + shard, + target_backend="pytorch", + model_config=model_config, + source_backend=source_backend, + is_tensor_parallel_shard=True, + tensor_parallel_size=tensor_parallel_size, + ) + for shard in state_dict_shards + ] + + n_heads = model_config["n_heads"] + n_kv_heads = model_config["n_kv_heads"] + n_local_heads = n_heads // tensor_parallel_size + n_local_kv_heads = n_kv_heads // tensor_parallel_size + dim = model_config["dim"] + context_dim = model_config["context_dim"] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + query_dim = head_dim * n_heads + key_value_dim = head_dim * n_kv_heads + merged_state_dict = {} + + for key in state_dict_shards[0].keys(): + prefix = "model." 
if key.startswith("model.") else "" + key_without_prefix = key[len(prefix) :] + if key_without_prefix.startswith("layers."): + layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0]) + if layer_index >= model_config["n_layers"]: + log.warning( + f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer." + ) + continue + if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight": + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0) + elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key: + chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim) + elif ".attention.wk.weight" in key or ".attention.wv.weight" in key: + chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim) + elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key: + chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim) + elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key: + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0) + elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key: + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1) + else: + avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0) + # make sure shard-0 is close to the average tensor + assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), ( + f"Shard-0 tensor {key} is not close to the average tensor. " + f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}, " + ) + merged_state_dict[key] = avg_tensor + assert "norm" in key, f"Assumed the key {key} is a norm layer, which should be the same across shards." + + if target_backend is None: + target_backend = source_backend + return maybe_convert_checkpoint_to_backend( + merged_state_dict, target_backend=target_backend, model_config=model_config + ) + + +def te_to_pytorch_state_dict( + te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a TransformerEngine state dict to PyTorch state dict + + Args: + te_state_dict (Mapping[str, torch.Tensor]): The TransformerEngine state dict + model_config: The model configuration + tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard). 
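+
+    Example (illustrative key mapping performed by this function):
+        "layers.0.self_attention.layernorm_qkv.query_weight" -> "layers.0.attention.wq.weight"
+        "layers.0.layernorm_mlp.fc1_weight" -> split into "layers.0.feed_forward.w1.weight"
+        and "layers.0.feed_forward.w3.weight"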
+ + Returns: + Mapping[str, torch.Tensor]: The PyTorch state dict + """ + + if hasattr(model_config, "asdict"): + model_config = model_config.asdict() + + pytorch_state_dict = {} + replacement_rules = [ + # Self-attention modules + (".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"), + (".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"), + (".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"), + (".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"), + (".self_attention.proj.weight", ".attention.wo.weight"), + (".self_attention.", ".attention."), # Handle the rest modules such as q_norm and k_norm + # MLP modules + (".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"), + (".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"), + # Cross-attention modules + (".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"), + (".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"), + (".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"), + (".inter_attention.proj.weight", ".cross_attention.wo.weight"), + (".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"), + (".inter_attention.", ".cross_attention."), # Handle the rest modules such as q_norm and k_norm + ] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + for old_key, value in te_state_dict.items(): + new_key = old_key + for old_substr, new_substr in replacement_rules: + if old_substr in new_key: + new_key = new_key.replace(old_substr, new_substr) + break + + # Handle the fused w1 and w3 case + if "layernorm_mlp.fc1_weight" in old_key: + fused_weight = value + split_point = fused_weight.shape[0] // 2 + w1_weight = fused_weight[:split_point] + w3_weight = fused_weight[split_point:] + + w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight") + w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight") + + pytorch_state_dict[w1_key] = w1_weight + pytorch_state_dict[w3_key] = w3_weight + else: + if model_config["pytorch_rope_version"] == "v1": + # If the model use qk normalization, we will use the same PyTorch RoPE operations as the TE version. + # Thus, we do not need to permute the weights. + if "query_weight" in old_key: + value = inverse_permute_weight( + value, + n_heads=model_config["n_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_heads"] // tensor_parallel_size, + dim2=model_config["dim"], + ) + elif "key_weight" in old_key: + value = inverse_permute_weight( + value, + n_heads=model_config["n_kv_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size, + dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"], + ) + pytorch_state_dict[new_key] = value + + return pytorch_state_dict + + +def pytorch_to_te_state_dict( + pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a PyTorch state dict to TransformerEngine state dict + + Args: + pytorch_state_dict (Mapping[str, torch.Tensor]): The PyTorch state dict + model_config: The model configuration + tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard). 
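+
+    Example (illustrative key mapping performed by this function):
+        "layers.0.attention.wq.weight" -> "layers.0.self_attention.layernorm_qkv.query_weight"
+        "layers.0.feed_forward.w1.weight" and "layers.0.feed_forward.w3.weight" -> fused into
+        "layers.0.layernorm_mlp.fc1_weight"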
+ + Returns: + Mapping[str, torch.Tensor]: The TransformerEngine + """ + + if hasattr(model_config, "asdict"): + model_config = model_config.asdict() + + te_state_dict = {} + + replacement_rules = [ + # Self-attention modules + (".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"), + (".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"), + (".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"), + (".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"), + (".attention.wo.weight", ".self_attention.proj.weight"), + (".attention.", ".self_attention."), + # MLP modules + (".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"), + (".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"), + # Cross-attention modules + (".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"), + (".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"), + (".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"), + (".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"), + (".cross_attention.wo.weight", ".inter_attention.proj.weight"), + (".cross_attention.", ".inter_attention."), + ] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + for old_key, value in pytorch_state_dict.items(): + new_key = old_key + for new_substr, old_substr in replacement_rules: + if new_substr in new_key: + new_key = new_key.replace(new_substr, old_substr) + break + + # Handle the split w1 and w3 case + if "feed_forward.w1.weight" in old_key: + w1_weight = value + w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight") + if w3_key in pytorch_state_dict: + w3_weight = pytorch_state_dict[w3_key] + fused_weight = torch.cat([w1_weight, w3_weight], dim=0) + new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight") + te_state_dict[new_key] = fused_weight + else: + te_state_dict[new_key] = value + elif "feed_forward.w3.weight" in old_key: + # Skip w3 weights as they're handled with w1 + continue + else: + if model_config["pytorch_rope_version"] == "v1": + # If the model use qk normalization, we will use the same PyTorch RoPE operations as the TE version. + # Thus, we do not need to permute the weights. 
+ if "attention.wq" in old_key: + value = permute_weight( + value, + n_heads=model_config["n_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_heads"] // tensor_parallel_size, + dim2=model_config["dim"], + ) + elif "attention.wk" in old_key: + value = permute_weight( + value, + n_heads=model_config["n_kv_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size, + dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"], + ) + te_state_dict[new_key] = value + + return te_state_dict + + +def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + """ + Helper function for converting checkpoints from PyTorch to TransformerEngine + Permute the query weight or key weight of each attention layer + Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + + Args: + w (torch.Tensor): The weight tensor to permute + n_heads (int): The number of attention heads + dim1 (int): The first dimension of the weight tensor + dim2 (int): The second dimension of the weight tensor + + Returns: + torch.Tensor: The permuted weight tensor + """ + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + """ + Helper function for converting checkpoints from TransformerEngine to PyTorch + Permute the query weight or key weight of each attention layer + + Args: + w (torch.Tensor): The weight tensor to permute + n_heads (int): The number of attention heads + dim1 (int): The first dimension of the weight tensor + dim2 (int): The second dimension of the weight tensor + + Returns: + torch.Tensor: The permuted weight tensor + """ + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def pytorch_to_hf_state_dict( + state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a PyTorch state dict to HuggingFace format for LLM models. + + Args: + state_dict (Mapping[str, torch.Tensor]): + The original PyTorch model's state dictionary. + This is a mapping where keys are layer names and values are the corresponding PyTorch tensors + containing the model weights. + + model_config (Mapping[str, Any]): + The configuration of the model. This dictionary contains parameters such as: + - n_layers: (int) The number of transformer layers. + - n_heads: (int) The number of attention heads. + - dim: (int) The hidden size of the model. + - n_kv_heads: (int, optional) The number of key-value heads for multi-query attention. + + Returns: + Mapping[str, torch.Tensor]: + The converted HuggingFace state dictionary. This dictionary maps HuggingFace transformer-compatible + layer names to the corresponding model weights. + """ + not_supported_key_substrings = ["cross_attention", "q_norm", "k_norm"] + for key in state_dict.keys(): + if any(substr in key for substr in not_supported_key_substrings): + raise ValueError(f"Key {key} is not supported in HuggingFace format.") + assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export." 
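+
+    # Key mapping applied below (illustrative examples for layer 0):
+    #   layers.0.attention.wq.weight     -> model.layers.0.self_attn.q_proj.weight
+    #   layers.0.feed_forward.w1.weight  -> model.layers.0.mlp.gate_proj.weight
+    #   tok_embeddings.weight            -> model.embed_tokens.weight
+    #   output.weight                    -> lm_head.weight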
+ + hf_state_dict = {} + + n_layers = model_config["n_layers"] + n_heads = model_config["n_heads"] + dim = model_config["dim"] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + + num_key_value_heads = model_config.get("n_kv_heads", n_heads) + key_value_dim = head_dim * num_key_value_heads + + for layer_i in range(n_layers): + pt_prefix = f"layers.{layer_i}." + hf_prefix = f"model.layers.{layer_i}." + + wq = state_dict[f"{pt_prefix}attention.wq.weight"] + wk = state_dict[f"{pt_prefix}attention.wk.weight"] + if model_config["pytorch_rope_version"] == "v1": + wq = permute_weight( + wq, + n_heads=n_heads, + dim1=dim, + dim2=dim, + ) + wk = permute_weight( + wk, + n_heads=num_key_value_heads, + dim1=key_value_dim, + dim2=dim, + ) + hf_state_dict[f"{hf_prefix}self_attn.q_proj.weight"] = wq + hf_state_dict[f"{hf_prefix}self_attn.k_proj.weight"] = wk + hf_state_dict[f"{hf_prefix}self_attn.v_proj.weight"] = state_dict[f"{pt_prefix}attention.wv.weight"] + hf_state_dict[f"{hf_prefix}self_attn.o_proj.weight"] = state_dict[f"{pt_prefix}attention.wo.weight"] + hf_state_dict[f"{hf_prefix}mlp.gate_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w1.weight"] + hf_state_dict[f"{hf_prefix}mlp.down_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w2.weight"] + hf_state_dict[f"{hf_prefix}mlp.up_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w3.weight"] + hf_state_dict[f"{hf_prefix}input_layernorm.weight"] = state_dict[f"{pt_prefix}attention_norm.weight"] + hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"] = state_dict[f"{pt_prefix}ffn_norm.weight"] + + # Add non-layer weights + hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"] + hf_state_dict["model.norm.weight"] = state_dict["norm.weight"] + hf_state_dict["lm_head.weight"] = state_dict["output.weight"] + + return hf_state_dict + + +def maybe_convert_checkpoint_to_backend( + state_dict: Dict[str, torch.Tensor], + target_backend: str, + model_config, + source_backend: str = None, + is_tensor_parallel_shard: bool = False, + tensor_parallel_size: int = None, +): + """ + Identify the backend of the checkpoint and convert to the target backend if necessary. + + This function checks the current backend of the state_dict and converts it to the target backend + if they don't match. It supports conversions between PyTorch, TransformerEngine, and HuggingFace backends. + + Args: + state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert. + target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface'). + model_config: Configuration of the model, used in conversion process. + source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend. + is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False. + tensor_parallel_size (int, optional): The tensor parallel size. If not specified, the model_config will be modified. + Returns: + Dict[str, torch.Tensor]: The converted state dictionary in the target backend format. + + Raises: + ValueError: If the conversion between the identified backend and target backend is not supported. 
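+
+    Note:
+        Supported conversions (as implemented below): pytorch <-> transformer_engine and
+        pytorch -> huggingface; any other source/target pair raises ValueError.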
+ """ + # Identify the current backend of the checkpoint + state_dict = process_state_dict(state_dict) # Remove unnecessary keys + if source_backend is None: + source_backend = identify_checkpoint_backend(state_dict) + if source_backend == target_backend: + return state_dict + else: + if tensor_parallel_size is None: + tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1 + # Convert to target backend + if source_backend == "pytorch" and target_backend == "transformer_engine": + return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + elif source_backend == "transformer_engine" and target_backend == "pytorch": + return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + elif source_backend == "pytorch" and target_backend == "huggingface": + return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + else: + raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.") diff --git a/cosmos_predict1/autoregressive/utils/inference.py b/cosmos_predict1/autoregressive/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bca159c50db62a6938b7dbe7041423958f014c3c --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/inference.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import math +import os +from pathlib import Path +from typing import List + +import numpy as np +import torch +import torchvision +from PIL import Image + +from cosmos_predict1.autoregressive.configs.inference.inference_config import SamplingConfig +from cosmos_predict1.utils import log + +_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"] +_VIDEO_EXTENSIONS = [".mp4"] +_SUPPORTED_CONTEXT_LEN = [1, 9] # Input frames +NUM_TOTAL_FRAMES = 33 + + +def add_common_arguments(parser): + """Add common command line arguments. 
+ + Args: + parser (ArgumentParser): Argument parser to add arguments to + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos") + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input path for input image or video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Input folder containing all input images or videos", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=9, + help="Number of input frames for world generation", + choices=_SUPPORTED_CONTEXT_LEN, + ) + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling") + parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") + parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder") + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + parser.add_argument( + "--offload_diffusion_decoder", + action="store_true", + help="Offload diffusion decoder after inference", + ) + parser.add_argument( + "--offload_ar_model", + action="store_true", + help="Offload AR model after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload discrete tokenizer model after inference", + ) + parser.add_argument( + "--disable_guardrail", + action="store_true", + help="Disable guardrail models", + ) + + +def validate_args(args: argparse.Namespace, inference_type: str): + """Validate command line arguments for base and video2world generation.""" + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1: + args.num_input_frames = 1 + log.info(f"Set num_input_frames to 1 for {args.input_type} input") + + if args.num_input_frames == 1: + if "4B" in args.ar_model_dir: + log.warning( + "The failure rate for 4B model with image input is ~15%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + elif "5B" in args.ar_model_dir: + log.warning( + "The failure rate for 5B model with image input is ~7%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + + # Validate prompt/image/video args for single or batch generation + assert ( + args.input_image_or_video_path or args.batch_input_path + ), "--input_image_or_video_path or --batch_input_path must be provided." + if inference_type == "video2world" and (not args.batch_input_path): + assert args.prompt, "--prompt is required for single video generation." 
+ args.data_resolution = [640, 1024] + + # Create output folder + Path(args.video_save_folder).mkdir(parents=True, exist_ok=True) + + sampling_config = SamplingConfig( + echo=True, + temperature=args.temperature, + top_p=args.top_p, + compile_sampling=True, + ) + return sampling_config + + +def resize_input(video: torch.Tensor, resolution: list[int]): + r""" + Function to perform aspect ratio preserving resizing and center cropping. + This is needed to make the video into target resolution. + Args: + video (torch.Tensor): Input video tensor + resolution (list[int]): Data resolution + Returns: + Cropped video + """ + + orig_h, orig_w = video.shape[2], video.shape[3] + target_h, target_w = resolution + + scaling_ratio = max((target_w / orig_w), (target_h / orig_h)) + resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w))) + video_resized = torchvision.transforms.functional.resize(video, resizing_shape) + video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution) + return video_cropped + + +def load_image_from_list(flist, data_resolution: List[int]) -> dict: + """ + Function to load images from a list of image paths. + Args: + flist (List[str]): List of image paths + data_resolution (List[int]): Data resolution + Returns: + Dict containing input images + """ + all_videos = dict() + for img_path in flist: + ext = os.path.splitext(img_path)[1] + if ext in _IMAGE_EXTENSIONS: + # Read the image + img = Image.open(img_path) + + # Convert to tensor + img = torchvision.transforms.functional.to_tensor(img) + static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1) + static_vid = static_vid * 2 - 1 + + log.debug( + f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + static_vid = resize_input(static_vid, data_resolution) + fname = os.path.basename(img_path) + all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input images from a JSONL file. + + Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + + Returns: + Dict containing input images + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_image(input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input image. + Args: + input_path (str): Path to input image + data_resolution (List[int]): Data resolution + Returns: + Dict containing input image + """ + flist = [input_path] + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + r""" + Function to read input videos. 
+ Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + Returns: + Dict containing input videos + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to read input video. + Args: + input_path (str): Path to input video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input video + """ + flist = [input_path] + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to load videos from a list of video paths. + Args: + flist (List[str]): List of video paths + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + all_videos = dict() + + for video_path in flist: + ext = os.path.splitext(video_path)[-1] + if ext in _VIDEO_EXTENSIONS: + video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec") + video = video.float() / 255.0 + video = video * 2 - 1 + + # Resize the videos to the required dimension + nframes_in_video = video.shape[0] + if nframes_in_video < num_input_frames: + fname = os.path.basename(video_path) + log.warning( + f"Video {fname} has {nframes_in_video} frames, less than the requried {num_input_frames} frames. Skipping." + ) + continue + + video = video[-num_input_frames:, :, :, :] + + # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES) + video = torch.cat( + (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)), + dim=0, + ) + + video = video.permute(0, 3, 1, 2) + + log.debug( + f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + video = resize_input(video, data_resolution) + + fname = os.path.basename(video_path) + all_videos[fname] = video.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def load_vision_input( + input_type: str, + batch_input_path: str, + input_image_or_video_path: str, + data_resolution: List[int], + num_input_frames: int, +): + """ + Function to load vision input. + Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model. 
+ Args: + input_type (str): Type of input + batch_input_path (str): Folder containing input images or videos + input_image_or_video_path (str): Path to input image or video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + if batch_input_path: + log.info(f"Reading batch inputs from path: {batch_input_path}") + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_images(batch_input_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_videos( + batch_input_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + else: + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_video( + input_image_or_video_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + return input_videos + + +def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]: + """ + Function to convert output tensors to numpy format for saving. + Args: + video_batch (List[torch.Tensor]): List of output tensors + Returns: + List of numpy arrays + """ + return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch] diff --git a/cosmos_predict1/autoregressive/utils/misc.py b/cosmos_predict1/autoregressive/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8e781bcf3538bd26b2aa45adce8a2921b9a14f --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/misc.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from omegaconf import DictConfig, OmegaConf + + +class CustomSimpleNamespace: + """ + A simple namespace class that supports both attribute-style and dictionary-style access. + """ + + def __init__(self, d): + self._d = d + + def __getattr__(self, attr): + # Attribute-style access: config.key + try: + return self._d[attr] + except KeyError: + raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'") + + def __getitem__(self, key): + # Dictionary-style access: config['key'] + return self._d[key] + + +def maybe_convert_to_namespace(config): + """ + This function cast a OmegaConf's DictConfig or a standard dict to CustomSimpleNamespace, which supports both + attribute-style and dictionary-style access. + Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile. 
+ """ + # If input is OmegaConf's DictConfig, convert to a standard dict + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + + if isinstance(config, dict): + return CustomSimpleNamespace(config) + else: + return config + + +def random_dropout(embeddings, drop_rate): + r""" + Function to perform random dropout for embeddings. + When we drop embeddings, we zero them out. + Args: + embeddings (tensor): Input embeddings + drop_rate (float): Rate of dropping the embedding. + """ + num_samples = embeddings.shape[0] + # Create a shape (num_samples, 1, 1, 1, 1, ...) depending on embeddings dim. + # This is done to ensure we can broadcast the zero_flag to the embeddings. + # embeddings.ndim is 3 for images, and 4 for videos, and the corresponding + # shapes are (num_samples, 1, 1) and (num_samples, 1, 1, 1) respectively. + tensor_shape = (num_samples,) + tuple([1] * (embeddings.ndim - 1)) + zero_flag = torch.ones(tensor_shape).to(embeddings.dtype) * (1 - drop_rate) + zero_flag = torch.bernoulli(zero_flag).to(embeddings.device) + embeddings = embeddings * zero_flag + return embeddings diff --git a/cosmos_predict1/autoregressive/utils/parallel.py b/cosmos_predict1/autoregressive/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..05f7733aea75175eae7d3b1b68d2b004b377a6c8 --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/parallel.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +import torch.distributed as dist +from megatron.core import mpu, parallel_state +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Function +from torch.distributed import broadcast, get_process_group_ranks +from transformer_engine.pytorch.jit import no_torch_dynamo +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE +from transformer_engine.pytorch.module.rmsnorm import _RMSNorm + +from cosmos_predict1.utils import log + + +def get_batch_on_this_cp_rank(inputs): + """Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + + # With causal masking, each token only attends to its prior tokens. Simply split + # sequence into CP chunks can result in severe load imbalance. That's to say, chunks + # at the end of sequence have bigger workload than others. To address this issue, + # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 + # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so + # that we can get balanced workload among GPUs in a context parallel group. 
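To make the chunk assignment concrete, here is a standalone sketch (not part of the patch) that assumes cp_size=2 and a toy [1, 8] token tensor in place of the Megatron parallel state used by the code below:

import torch

cp_size, seq_dim = 2, 1
tokens = torch.arange(8).unsqueeze(0)        # toy batch: 1 sample, 8 tokens
chunks = tokens.view(1, 2 * cp_size, -1)     # split into 2*CP = 4 chunks of 2 tokens each
for cp_rank in range(cp_size):
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    picked = chunks.index_select(seq_dim, index).reshape(1, -1)
    print(cp_rank, picked.tolist())
# rank 0 keeps chunks 0 and 3 -> [[0, 1, 6, 7]]
# rank 1 keeps chunks 1 and 2 -> [[2, 3, 4, 5]]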
+ cp_size = parallel_state.get_context_parallel_world_size() + + if cp_size > 1: + cp_rank = mpu.get_context_parallel_rank() + seq_dim = 1 # if key != 'attention_mask' else 2 + inputs = inputs.view( + *inputs.shape[0:seq_dim], + 2 * cp_size, + inputs.shape[seq_dim] // (2 * cp_size), + *inputs.shape[(seq_dim + 1) :], + ) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( + non_blocking=True + ) + inputs = inputs.index_select(seq_dim, index) + inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :]) + + return inputs + + +def gather_batch_from_cp_ranks(outputs): + """ + Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + + if cp_size > 1: + seq_dim = 1 # Assuming sequence dimension is 1 + + try: + # Reshape output to separate the two chunks + chunk_size = outputs.shape[seq_dim] // 2 + outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :]) + + # Prepare a list to gather all chunks from all ranks + gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)] + + # Gather all chunks + dist.barrier() + dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group()) + dist.barrier() + + # Reorder chunks + reordered_chunks = [None] * (2 * cp_size) + for i in range(cp_size): + reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0) + reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1) + + # Concatenate all chunks + outputs = torch.cat(reordered_chunks, dim=seq_dim) + except Exception as e: + log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}") + raise + + return outputs + + +def broadcast_data_batch_in_tp_cp_group(data_batch): + """ + Broadcast data batch across tensor model parallel and context parallel groups. + """ + keys = sorted(data_batch.keys()) + tp_size = parallel_state.get_tensor_model_parallel_world_size() + cp_size = parallel_state.get_context_parallel_world_size() + tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None + cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None + tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None + cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None + if tp_size > 1 or cp_size > 1: + for key in keys: + tensor = data_batch[key] + if isinstance(tensor, torch.Tensor): + tensor = tensor.contiguous() + if tp_size > 1: + broadcast(tensor, min(tp_ranks), group=tp_group) + if cp_size > 1: + broadcast(tensor, min(cp_ranks), group=cp_group) + + +def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool): + """ + All-reduce layernorm grads (for sequence parallelism). + Note: + - We skip QK Normalization layers and the last normalization layer of Transformer, + since we use AllReduceBWDRMSNormTE for these layers, which already applies all-reduce in the backward pass. + - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters that + we must all-reduce in the backward pass as well. So we implement this function to cover these parameters. 
+ """ + # All-reduce layernorm parameters across model parallel nodes + # when sequence parallelism is used + if tensor_model_parallel_size > 1 and sequence_parallel: + grads = [] + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + if name.endswith(".layer_norm_weight"): # TP # Q-layernorm # K-layernorm + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def sync_1d_parameters(model: torch.nn.Module, process_group=None): + """ + Synchronize layernorm parameters (1D) across ranks by performing all-reduce with mean operation. + LayerNorm parameters are identified by having ndim==1. + Note: If parameters other than LayerNorm are 1D, they will also be synchronized. + + Args: + model (torch.nn.Module): The model containing layernorm parameters + process_group (optional): The process group to perform all-reduce. + If None, uses the default process group. + """ + if not torch.distributed.is_initialized(): + return + # Synchronize each 1D parameter (layernorm parameters) + for name, param in model.named_parameters(): + if param.ndim == 1 and param.requires_grad: # LayerNorm weights/biases are 1D + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group) + + +class AllReduceBWD(Function): + """ + Custom autograd Function that performs an all-reduce operation during the backward pass. + + Args: + tensor (Tensor): The input tensor. + process_group: The process group to perform the all-reduce operation. + + Returns: + Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass. + """ + + @staticmethod + def forward(ctx, tensor, process_group): + ctx.process_group = process_group + return tensor + + @staticmethod + def backward(ctx, grad_output): + dist.all_reduce(grad_output, group=ctx.process_group) + return grad_output, None + + +class AllReduceBWDRMSNormTE(RMSNormTE): + """ + A custom RMSNorm layer that applies all-reduce operation during backward pass. + Used in tensor parallel training with Transformer Engine. + + Args: + hidden_size (int): The size of the hidden dimension. + process_group: Megatron Core's process group. + **kwargs: Additional arguments to be passed to RMSNormTE. + """ + + def __init__(self, hidden_size, process_group, **kwargs): + super().__init__(hidden_size, **kwargs) + self.process_group = process_group + + @no_torch_dynamo() + def forward(self, inp: torch.Tensor) -> torch.Tensor: + """RMSNorm FWD""" + + # Set the activation type for AMP. 
+ TransformerEngineBaseModule.set_activation_dtype(self, inp) + + if torch.is_grad_enabled(): + fwd_fn = _RMSNorm.apply + args = [] + else: + fwd_fn = _RMSNorm.forward + args = [None] + + args += ( + inp, + AllReduceBWD.apply(self.weight, self.process_group), + self.eps, + self.fwd_rmsnorm_sm_margin, + self.bwd_rmsnorm_sm_margin, + self.inf_rmsnorm_sm_margin, + self.zero_centered_gamma, + torch.is_grad_enabled(), + self.activation_dtype, + ) + + return fwd_fn(*args) diff --git a/cosmos_predict1/autoregressive/utils/sampling.py b/cosmos_predict1/autoregressive/utils/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..46d364c3e4b3f84f2f29baa4551f8e17a2f601ad --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/sampling.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from torch.nn.attention import SDPBackend, sdpa_kernel + +from cosmos_predict1.autoregressive.networks.transformer import Transformer + + +def sample_top_p(logits, temperature, top_p, return_probs: bool = False): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + logits (torch.Tensor): Logits of the probability distribution. + temperature (float): Temperature for sampling. + top_p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1) + # Sort the probabilities in descending order and get their indices. + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + # Compute the cumulative sum of the sorted probabilities. + probs_sum = torch.cumsum(probs_sort, dim=-1) + # Create a mask where the cumulative probability exceeds the threshold p. + mask = probs_sum - probs_sort > top_p + # Set the probabilities that exceed the threshold to 0. + probs_sort[mask] = 0.0 + # Renormalize the remaining probabilities so they sum to 1. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + # Sample from the renormalized probability distribution. + # next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64) + # Gather the indices of the sampled tokens. 
+    next_token = torch.gather(probs_idx, -1, next_token)
+    if return_probs:
+        # Initialize a tensor for unsorted probabilities
+        probs_unsorted = torch.zeros_like(probs_sort)
+        # Scatter the sorted probabilities back to their original order
+        probs_unsorted.scatter_(-1, probs_idx, probs_sort)
+    else:
+        probs_unsorted = None
+    return next_token, probs_unsorted
+
+
+def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
+    """
+    Multinomial sampling without a cuda synchronization.
+    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+    """
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)
+
+
+def logits_to_probs(
+    logits,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+):
+    logits = logits / max(temperature, 1e-5)
+
+    if top_k is not None:
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+        pivot = v.select(-1, -1).unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    """
+    Sample from the logits using top-k sampling.
+    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+    """
+    # logits: [batch_size, seq_len, vocab_size]
+    if temperature == 0.0:
+        idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+        probs = None
+    else:
+        probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
+        idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+
+def prefill(
+    model: Transformer,
+    input_pos: torch.Tensor,
+    tokens: torch.Tensor = None,
+    token_embeddings: torch.Tensor = None,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    **kwargs,
+) -> torch.Tensor:
+    logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs)
+    # Only top-p or top-k can be provided
+    assert (
+        top_p is None or top_k is None
+    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
+    if top_p is not None:
+        return sample_top_p(logits, temperature=temperature, top_p=top_p)[0]
+    else:
+        return sample_top_k(logits, temperature=temperature, top_k=top_k)[0]
+
+
+def decode_one_token(
+    model: Transformer,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Decode a single token from the autoregressive model.
+    """
+    logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
+    if top_p is not None:
+        return sample_top_p(logits, temperature=temperature, top_p=top_p)
+    else:
+        return sample_top_k(logits, temperature=temperature, top_k=top_k)
+
+
+def decode_n_tokens(
+    model: Transformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    stop_tokens: torch.Tensor = None,
+    temperature: float = 1.0,
+    top_p: Optional[float] = None,
+    top_k: Optional[int] = None,
+    return_probs: bool = False,
+    decode_one_token_function=decode_one_token,
+    **kwargs,
+):
+    """
+    Decode n tokens from the autoregressive model.
+    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+    """
+    new_tokens, new_probs = [], []
+    batch_size = cur_token.shape[0]
+    assert (
+        top_p is None or top_k is None
+    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
+    if stop_tokens is not None:
+        # Indicator for whether the EOS token (stop token) has been reached for each sample in the batch
+        eos_reached = torch.tensor([False] * batch_size, device="cuda")
+    for t in range(num_new_tokens):
+        with sdpa_kernel([SDPBackend.MATH]):  # Actually better for Inductor to codegen attention here
+            next_token, next_prob = decode_one_token_function(
+                model,
+                tokens=cur_token,
+                input_pos=input_pos,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                **kwargs,
+            )
+            input_pos += 1
+            if stop_tokens is not None and len(stop_tokens) > 0:
+                eos_reached = eos_reached | (torch.isin(next_token, stop_tokens))
+                if eos_reached.all():
+                    break
+            new_tokens.append(next_token.clone())
+            if return_probs:
+                new_probs.append(next_prob.clone())
+            cur_token = next_token.clone()
+
+    if return_probs:
+        return new_tokens, new_probs
+    else:
+        return new_tokens
diff --git a/cosmos_predict1/auxiliary/guardrail/__init__.py b/cosmos_predict1/auxiliary/guardrail/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9
--- /dev/null
+++ b/cosmos_predict1/auxiliary/guardrail/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py b/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9
--- /dev/null
+++ b/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
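A minimal usage sketch for the sample_top_p helper defined in cosmos_predict1/autoregressive/utils/sampling.py above (illustrative only; the shapes are toy values and the package is assumed to be importable from the repository root):

import torch
from cosmos_predict1.autoregressive.utils.sampling import sample_top_p

logits = torch.randn(2, 1, 32)  # [batch_size, seq_len, vocab_size]
next_token, probs = sample_top_p(logits, temperature=0.9, top_p=0.8, return_probs=True)
print(next_token.shape)   # torch.Size([2, 1]): one sampled token id per sample
print(probs.sum(dim=-1))  # the renormalized nucleus distribution sums to ~1 per sample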
diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py b/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py new file mode 100644 index 0000000000000000000000000000000000000000..ade12a7cf8588768511aaa6f282850c0c2252d1d --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +from cosmos_predict1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.utils import log, misc + +SAFE = misc.Color.green("SAFE") +UNSAFE = misc.Color.red("UNSAFE") + + +class Aegis(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = torch.bfloat16 + + base_model_id = "meta-llama/LlamaGuard-7b" + aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" + base_model_dir = os.path.join(self.checkpoint_dir, base_model_id) + aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter) + + base_model = AutoModelForCausalLM.from_pretrained(base_model_dir) + self.tokenizer = AutoTokenizer.from_pretrained(base_model_dir) + self.model = PeftModel.from_pretrained(base_model, aegis_adapter_dir) + + self.model.to(self.device, dtype=self.dtype).eval() + + def get_moderation_prompt(self, user_prompt: str) -> str: + """Create the moderation prompt for the Aegis model.""" + unsafe_categories = "".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES]) + full_prompt = ( + "[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n" + "\n" + f"{unsafe_categories}\n" + "\n" + "\n" + ) + full_prompt += f"User: {user_prompt}\n\n" + full_prompt += """ + + Provide your safety assessment for User in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" + + return full_prompt + + def get_aegis_block_message(self, moderation_output: str) -> str: + """Extract the blocked category and reason from the Aegis model output.""" + block_msg = "Prompt blocked by Aegis." 
+ try: + # Extract blocked category and reason + if len(moderation_output.split("\n")) >= 2: + moderation_output_list = moderation_output.split("\n") + blocked_category = int(moderation_output_list[1].lstrip("O")) - 1 + if 0 <= blocked_category < len(UNSAFE_CATEGORIES): + blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"] + blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"] + block_msg = f"{blocked_category_name}: {blocked_reason}" + except Exception as e: + log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}") + return block_msg + + def filter_aegis_output(self, prompt: str) -> tuple[bool, str]: + """Filter the Aegis model output and return the safety status and message.""" + full_prompt = self.get_moderation_prompt(prompt) + inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device) + output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id) + prompt_len = inputs["input_ids"].shape[-1] + moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) + + if "unsafe" in moderation_output.lower(): + block_msg = self.get_aegis_block_message(moderation_output) + return False, block_msg + else: + return True, "" + + def is_safe(self, prompt: str) -> tuple[bool, str]: + """Check if the input prompt is safe according to the Aegis model.""" + try: + return self.filter_aegis_output(prompt) + except Exception as e: + log.error(f"Unexpected error occurred when running Aegis guardrail: {e}") + return True, "Unexpected error occurred when running Aegis guardrail." + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Aegis checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + aegis = Aegis(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[aegis]) + with misc.timer("aegis safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/categories.py b/cosmos_predict1/auxiliary/guardrail/aegis/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..f9236da33c8d2669b701cc21709925bacbbe6907 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/aegis/categories.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +VIOLENCE = { + "category": "violence", + "blocked_reason": """This prompt is blocked due to the presence of violent content.""", + "instruction": """O1: Violence. 
+Should not +- Help people plan or engage in violence. +- Encourage people to engage in violence. +- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader. +Can +- Provide information on violence. +- Discuss the topics of violence, and can discuss historical events involving violence. +""", +} + +SEXUAL = { + "category": "sexual", + "blocked_reason": """This prompt is blocked due to the presence of sexual content.""", + "instruction": """O2: Sexual. +Should not +- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. +- Engage or encourage sexual activity that is considered illegal. +Can +- Discuss topics related to sexuality and sexual education. +- Provide general information regarding sexual activities. +- Provide guidance on topics related to sexual health. +""", +} + +CRIMINAL_PLANNING = { + "category": "criminal_planning", + "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""", + "instruction": """O3: Criminal Planning/Confessions. +Should not +- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in criminal activities or condone criminal behavior. +- Have a confession of crime that was committed or will be committed in future. +Can +- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. +- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. +""", +} + +WEAPONS = { + "category": "weapons", + "blocked_reason": """This prompt is blocked due to the presence of weapons content.""", + "instruction": """O4: Guns and Illegal Weapons. +Should not +- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. +- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them. +Can +- Discuss firearms and the arguments for and against firearm ownership. +""", +} + +SUBSTANCE_ABUSE = { + "category": "substance_abuse", + "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""", + "instruction": """O5: Controlled/Regulated Substances. +Should not +- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. +- Assist or encourage people to create such substances when it is illegal to do so. +Can +- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). +- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. 
+""", +} + +SUICIDE = { + "category": "suicide", + "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.", + "instruction": """O6: Suicide and Self Harm. +Should not +- Encourage people to harm themselves. +- Romanticize or condone self-harm or suicide. +- Provide information regarding the methods of suicide or self-harm. +- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. +- Contain confessions or intent of self harm +Can +- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). +Should +- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. +""", +} + +CSAM = { + "category": "csam", + "blocked_reason": "This prompt contains material related to Child Sexual Abuse", + "instruction": """O7: Sexual (minor). +Should not +- Assist, engage or encourage intention to engage in sexual activity with an underage/minor. +- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor. +- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media. +- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children. +Can +- Provide information about sex education for use in schools. +""", +} + +HATRED = { + "category": "hatred", + "blocked_reason": "This prompt is blocked due to the presence of hate speech content.", + "instruction": """O8: Hate/Identity Hate. +Should not +- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their +race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste. +- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country. +- Should not contain slurs. +- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear. +Can +- Provide information on discrimination. +- Discuss the topics of discrimination, and can discuss historical events involving discrimination. +- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate). +""", +} + +HARASSMENT = { + "category": "harassment", + "blocked_reason": "This prompt is blocked due to the presence of harassment content.", + "instruction": """O10: Harassment. +Should not +- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur. 
+- Engage in harassment, misconduct and unlawful discrimination +- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex, +Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled), +Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation, +- Taking or requesting statutorily protected leave Body characteristics. +- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct. +- Bullying. +- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law. +- Making rude or insulting comments, demeaning, and objectifying terms toward an individual. +Can +- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting. +- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are. +- Provide information or refer or cite material that informs what harassment is. +""", +} + +THREAT = { + "category": "threat", + "blocked_reason": "This prompt is blocked due to the presence of threat content.", + "instruction": """O11: Threat. +Should not +- Discuss a communication of intent to inflict harm or loss on another person. +- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control. +- Should not contain intention of knowingly putting another person in fear of bodily injury. +- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person. +- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will. +Can +- Provide information about threat is and the meaning of threat. +""", +} + +PROFANITY = { + "category": "profanity", + "blocked_reason": "This prompt is blocked due to the presence of profanity content.", + "instruction": """O12: Profanity. +Should not +- Contain swear words, curse words, or other obscene or profane language +- This includes offensive words used without any intention to act on them or without any explicit targeted recipient +""", +} + +UNSAFE_CATEGORIES = [ + VIOLENCE, + SEXUAL, + CRIMINAL_PLANNING, + WEAPONS, + SUBSTANCE_ABUSE, + SUICIDE, + CSAM, + HATRED, + HARASSMENT, + THREAT, + PROFANITY, +] diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py b/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py b/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py new file mode 100644 index 0000000000000000000000000000000000000000..d0fa7aafc609e69677f2ffb319fa11eb9a689fe5 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re +import string +from difflib import SequenceMatcher + +import nltk +from better_profanity import profanity + +from cosmos_predict1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.utils import log, misc + +CENSOR = misc.Color.red("*") + + +class Blocklist(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + guardrail_partial_match_min_chars: int = 6, + guardrail_partial_match_letter_count: float = 0.4, + ) -> None: + self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/blocklist") + nltk.data.path.append(os.path.join(self.checkpoint_dir, "nltk_data")) + self.lemmatizer = nltk.WordNetLemmatizer() + self.profanity = profanity + self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars + self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count + + # Load blocklist and whitelist keywords + self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom")) + self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist")) + self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match")) + + self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words) + log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist") + log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist") + log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist") + + def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str: + """Explicitly uncensor words that are in the whitelist.""" + input_words = input_prompt.split() + censored_words = censored_prompt.split() + whitelist_words = set(self.whitelist_words) + for i, token in enumerate(input_words): + if token.strip(string.punctuation).lower() in 
whitelist_words: + censored_words[i] = token + censored_prompt = " ".join(censored_words) + return censored_prompt + + def censor_prompt(self, input_prompt: str) -> tuple[bool, str]: + """Censor the prompt using the blocklist with better-profanity fuzzy matching. + + Args: + input_prompt: input prompt to censor + + Returns: + bool: True if the prompt is blocked, False otherwise + str: A message indicating why the prompt was blocked + """ + censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR) + # Uncensor whitelisted words that were censored from blocklist fuzzy matching + censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt) + if CENSOR in censored_prompt: + return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}" + return False, "" + + @staticmethod + def check_partial_match( + normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float + ) -> tuple[bool, str]: + """ + Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters. + + Args: + normalized_prompt: a string with many words + normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt + guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters) + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + prompt_words = normalized_prompt.split() + word_length = len(normalized_word.split()) + max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float( + len(normalized_word) + ) + + for i in range(len(prompt_words) - word_length + 1): + # Extract a substring from the prompt with the same number of words as the normalized_word + substring = " ".join(prompt_words[i : i + word_length]) + similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio() + if similarity_ratio >= max_similarity_ratio: + return ( + True, + f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}", + ) + + return False, "" + + @staticmethod + def check_against_whole_word_blocklist( + prompt: str, + blocklist: list[str], + guardrail_partial_match_min_chars: int = 6, + guardrail_partial_match_letter_count: float = 0.4, + ) -> bool: + """ + Check if the prompt contains any whole words from the blocklist. + The match is case insensitive and robust to multiple spaces between words. 
+ + Args: + prompt: input prompt to check + blocklist: list of words to check against + guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match + guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + # Normalize spaces and convert to lowercase + normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower() + + for word in blocklist: + # Normalize spaces and convert to lowercase for each blocklist word + normalized_word = re.sub(r"\s+", " ", word).strip().lower() + + # Use word boundaries to ensure whole word match + if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt): + return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}" + + # Check for partial match if the word is long enough + if len(normalized_word) >= guardrail_partial_match_min_chars: + match, message = Blocklist.check_partial_match( + normalized_prompt, normalized_word, guardrail_partial_match_letter_count + ) + if match: + return True, message + + return False, "" + + def is_safe(self, input_prompt: str = "") -> tuple[bool, str]: + """Check if the input prompt is safe using the blocklist.""" + # Check if the input is empty + if not input_prompt: + return False, "Input is empty" + input_prompt = to_ascii(input_prompt) + + # Check full sentence for censored words + censored, message = self.censor_prompt(input_prompt) + if censored: + return False, message + + # Check lemmatized words for censored words + tokens = nltk.word_tokenize(input_prompt) + lemmas = [self.lemmatizer.lemmatize(token) for token in tokens] + lemmatized_prompt = " ".join(lemmas) + censored, message = self.censor_prompt(lemmatized_prompt) + if censored: + return False, message + + # Check for exact match blocklist words + censored, message = self.check_against_whole_word_blocklist( + input_prompt, + self.exact_match_words, + self.guardrail_partial_match_min_chars, + self.guardrail_partial_match_letter_count, + ) + if censored: + return False, message + + # If all these checks pass, the input is safe + return True, "Input is safe" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Blocklist checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[blocklist]) + with misc.timer("blocklist safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py b/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8555af872b03dd8a9dad0dd2699550bdcdd5b1 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from cosmos_predict1.utils import log + + +def read_keyword_list_from_dir(folder_path: str) -> list[str]: + """Read keyword list from all files in a folder.""" + output_list = [] + file_list = [] + # Get list of files in the folder + for file in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, file)): + file_list.append(file) + + # Process each file + for file in file_list: + file_path = os.path.join(folder_path, file) + try: + with open(file_path, "r") as f: + output_list.extend([line.strip() for line in f.readlines()]) + except Exception as e: + log.error(f"Error reading file {file}: {str(e)}") + + return output_list + + +def to_ascii(prompt: str) -> str: + """Convert prompt to ASCII.""" + return re.sub(r"[^\x00-\x7F]+", " ", prompt) diff --git a/cosmos_predict1/auxiliary/guardrail/common/__init__.py b/cosmos_predict1/auxiliary/guardrail/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/auxiliary/guardrail/common/core.py b/cosmos_predict1/auxiliary/guardrail/common/core.py new file mode 100644 index 0000000000000000000000000000000000000000..f4deeaa2ca0eb99d8b778665963221365ad3927d --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/core.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Tuple + +import numpy as np + +from cosmos_predict1.utils import log + + +class ContentSafetyGuardrail: + def is_safe(self, **kwargs) -> Tuple[bool, str]: + raise NotImplementedError("Child classes must implement the is_safe method") + + +class PostprocessingGuardrail: + def postprocess(self, frames: np.ndarray) -> np.ndarray: + raise NotImplementedError("Child classes must implement the postprocess method") + + +class GuardrailRunner: + def __init__( + self, + safety_models: list[ContentSafetyGuardrail] | None = None, + generic_block_msg: str = "", + generic_safe_msg: str = "", + postprocessors: list[PostprocessingGuardrail] | None = None, + ): + self.safety_models = safety_models + self.generic_block_msg = generic_block_msg + self.generic_safe_msg = generic_safe_msg if generic_safe_msg else "Prompt is safe" + self.postprocessors = postprocessors + + def run_safety_check(self, input: Any) -> Tuple[bool, str]: + """Run the safety check on the input.""" + if not self.safety_models: + log.warning("No safety models found, returning safe") + return True, self.generic_safe_msg + + for guardrail in self.safety_models: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + safe, message = guardrail.is_safe(input) + if not safe: + reasoning = self.generic_block_msg if self.generic_block_msg else f"{guardrail_name}: {message}" + return False, reasoning + return True, self.generic_safe_msg + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Run the postprocessing on the video frames.""" + if not self.postprocessors: + log.warning("No postprocessors found, returning original frames") + return frames + + for guardrail in self.postprocessors: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + frames = guardrail.postprocess(frames) + return frames diff --git a/cosmos_predict1/auxiliary/guardrail/common/io_utils.py b/cosmos_predict1/auxiliary/guardrail/common/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..129f049233191368a6dee4ef202088fdd851e3e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/io_utils.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import glob +from dataclasses import dataclass + +import imageio +import numpy as np + +from cosmos_predict1.utils import log + + +@dataclass +class VideoData: + frames: np.ndarray # Shape: [B, H, W, C] + fps: int + duration: int # in seconds + + +def get_video_filepaths(input_dir: str) -> list[str]: + """Get a list of filepaths for all videos in the input directory.""" + paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True) + paths = sorted(paths) + log.debug(f"Found {len(paths)} videos") + return paths + + +def read_video(filepath: str) -> VideoData: + """Read a video file and extract its frames and metadata.""" + try: + reader = imageio.get_reader(filepath, "ffmpeg") + except Exception as e: + raise ValueError(f"Failed to read video file: {filepath}") from e + + # Extract metadata from the video file + try: + metadata = reader.get_meta_data() + fps = metadata.get("fps") + duration = metadata.get("duration") + except Exception as e: + reader.close() + raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e + + # Extract frames from the video file + try: + frames = np.array([frame for frame in reader]) + except Exception as e: + raise ValueError(f"Failed to extract frames from video file: {filepath}") from e + finally: + reader.close() + + return VideoData(frames=frames, fps=fps, duration=duration) + + +def save_video(filepath: str, frames: np.ndarray, fps: int) -> None: + """Save a video file from a sequence of frames.""" + try: + writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1) + for frame in frames: + writer.append_data(frame) + except Exception as e: + raise ValueError(f"Failed to save video file to {filepath}") from e + finally: + writer.close() diff --git a/cosmos_predict1/auxiliary/guardrail/common/presets.py b/cosmos_predict1/auxiliary/guardrail/common/presets.py new file mode 100644 index 0000000000000000000000000000000000000000..245fe445496e9c265742023f3c61ece7d82ee49e --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/presets.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from pathlib import Path + +import numpy as np + +from cosmos_predict1.auxiliary.guardrail.blocklist.blocklist import Blocklist +from cosmos_predict1.auxiliary.guardrail.common.core import GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter +from cosmos_predict1.auxiliary.guardrail.llamaGuard3.llamaGuard3 import LlamaGuard3 +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import ( + VideoContentSafetyFilter, +) +from cosmos_predict1.utils import log + + +def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the text guardrail runner.""" + return GuardrailRunner(safety_models=[Blocklist(checkpoint_dir), LlamaGuard3(checkpoint_dir)]) + + +def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the video guardrail runner.""" + return GuardrailRunner( + safety_models=[VideoContentSafetyFilter(checkpoint_dir)], + postprocessors=[RetinaFaceFilter(checkpoint_dir)], + ) + + +def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool: + """Run the text guardrail on the prompt, checking for content safety. + + Args: + prompt: The text prompt. + guardrail_runner: The text guardrail runner. + + Returns: + bool: Whether the prompt is safe. + """ + is_safe, message = guardrail_runner.run_safety_check(prompt) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return is_safe + + +def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None: + """Run the video guardrail on the frames, checking for content safety and applying face blur. + + Args: + frames: The frames of the generated video. + guardrail_runner: The video guardrail runner. + + Returns: + The processed frames if safe, otherwise None. + """ + is_safe, message = guardrail_runner.run_safety_check(frames) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return None + + frames = guardrail_runner.postprocess(frames) + return frames diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
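For context, a hedged usage sketch of the preset runners defined above; the checkpoint path and the stand-in frames are placeholders, and the checkpoints referenced in presets.py are assumed to already be downloaded under checkpoint_dir.

# Sketch only: running the text and video guardrails around a generation step.
import numpy as np

from cosmos_predict1.auxiliary.guardrail.common import presets

checkpoint_dir = "checkpoints"  # placeholder path to the downloaded guardrail checkpoints
text_runner = presets.create_text_guardrail_runner(checkpoint_dir)
video_runner = presets.create_video_guardrail_runner(checkpoint_dir)

prompt = "A robot waters plants in a greenhouse."
if presets.run_text_guardrail(prompt, text_runner):
    frames = np.zeros((8, 256, 256, 3), dtype=np.uint8)  # stand-in for generated video frames
    frames = presets.run_video_guardrail(frames, video_runner)
    if frames is None:
        print("Generated video was blocked by the video guardrail.")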
diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d52f69d220444a53027b3b4acc3bd192fc6eb76f --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray: + """ + Pixelate a face region by reducing resolution and then upscaling. + + Args: + face_img: Face region to pixelate + blocks: Number of blocks to divide the face into (in each dimension) + + Returns: + Pixelated face region + """ + h, w = face_img.shape[:2] + # Shrink the image and scale back up to create pixelation effect + temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR) + pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST) + return pixelated diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..eb26b9c4d0b0dca930336487c11f26373b6ab293 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
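A small sanity-check sketch for pixelate_face above, using a synthetic crop (not part of the change):

# Sketch: pixelate a random 48x64 crop and confirm the output keeps the input shape.
import numpy as np

from cosmos_predict1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face

face = np.random.randint(0, 256, size=(48, 64, 3), dtype=np.uint8)
blurred = pixelate_face(face, blocks=5)
assert blurred.shape == face.shape  # downscaled to 5x5 blocks, then upscaled back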
+ +import argparse +import os +import warnings + +import numpy as np +import torch +from retinaface.data import cfg_re50 +from retinaface.layers.functions.prior_box import PriorBox +from retinaface.models.retinaface import RetinaFace +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from cosmos_predict1.auxiliary.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail +from cosmos_predict1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video, save_video +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.retinaface_utils import ( + decode_batch, + filter_detected_boxes, + load_model, +) +from cosmos_predict1.utils import log, misc + +# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +TOP_K = 5_000 +KEEP_TOP_K = 750 +NMS_THRESHOLD = 0.4 + + +class RetinaFaceFilter(PostprocessingGuardrail): + def __init__( + self, + checkpoint_dir: str, + batch_size: int = 1, + confidence_threshold: float = 0.7, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + """ + Initialize the RetinaFace model for face detection and blurring. + + Args: + checkpoint: Path to the RetinaFace checkpoint file + batch_size: Batch size for RetinaFace inference and processing + confidence_threshold: Minimum confidence score to consider a face detection + """ + self.checkpoint = f"{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth" + self.cfg = cfg_re50 + self.batch_size = batch_size + self.confidence_threshold = confidence_threshold + self.device = device + self.dtype = torch.float32 + + # Disable loading ResNet pretrained weights + self.cfg["pretrain"] = False + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.net = RetinaFace(cfg=self.cfg, phase="test") + cpu = self.device == "cpu" + + # Load from RetinaFace pretrained checkpoint + self.net = load_model(self.net, self.checkpoint, cpu) + self.net.to(self.device, dtype=self.dtype).eval() + + def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor: + """Preprocess a sequence of frames for face detection. + + Args: + frames: Input frames + + Returns: + Preprocessed frames tensor + """ + with torch.no_grad(): + frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype) # Shape: [T, H, W, C] + frames_tensor = frames_tensor.permute(0, 3, 1, 2) # Shape: [T, C, H, W] + frames_tensor = frames_tensor[:, [2, 1, 0], :, :] # RGB to BGR to match RetinaFace model input + means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1) + frames_tensor = frames_tensor - means # Subtract mean BGR values for each channel + return frames_tensor + + def blur_detected_faces( + self, + frames: np.ndarray, + batch_loc: torch.Tensor, + batch_conf: torch.Tensor, + prior_data: torch.Tensor, + scale: torch.Tensor, + min_size: tuple[int] = (20, 20), + ) -> list[np.ndarray]: + """Blur detected faces in a batch of frames using RetinaFace predictions. 
+ + Args: + frames: Input frames + batch_loc: Batched location predictions + batch_conf: Batched confidence scores + prior_data: Prior boxes for the video + scale: Scale factor for resizing detections + min_size: Minimum size of a detected face region in pixels + + Returns: + Processed frames with pixelated faces + """ + with torch.no_grad(): + batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"]) + batch_boxes = batch_boxes * scale + + blurred_frames = [] + for i, boxes in enumerate(batch_boxes): + boxes = boxes.detach().cpu().numpy() + scores = batch_conf[i, :, 1].detach().cpu().numpy() + + filtered_boxes = filter_detected_boxes( + boxes, + scores, + confidence_threshold=self.confidence_threshold, + nms_threshold=NMS_THRESHOLD, + top_k=TOP_K, + keep_top_k=KEEP_TOP_K, + ) + + frame = frames[i] + for box in filtered_boxes: + x1, y1, x2, y2 = map(int, box) + # Ignore bounding boxes smaller than the minimum size + if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]: + continue + max_h, max_w = frame.shape[:2] + face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] + blurred_face = pixelate_face(face_roi) + frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face + blurred_frames.append(frame) + + return blurred_frames + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Blur faces in a sequence of frames. + + Args: + frames: Input frames + + Returns: + Processed frames with pixelated faces + """ + # Create dataset and dataloader + frames_tensor = self.preprocess_frames(frames) + dataset = TensorDataset(frames_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) + processed_frames, processed_batches = [], [] + + prior_data, scale = None, None + for i, batch in enumerate(dataloader): + batch = batch[0] + h, w = batch.shape[-2:] # Batch shape: [C, H, W] + + with torch.no_grad(): + # Generate priors for the video + if prior_data is None: + priorbox = PriorBox(self.cfg, image_size=(h, w)) + priors = priorbox.forward() + priors = priors.to(self.device, dtype=self.dtype) + prior_data = priors.data + + # Get scale for resizing detections + if scale is None: + scale = torch.Tensor([w, h, w, h]) + scale = scale.to(self.device, dtype=self.dtype) + + batch_loc, batch_conf, _ = self.net(batch) + + # Blur detected faces in each batch of frames + start_idx = i * self.batch_size + end_idx = min(start_idx + self.batch_size, len(frames)) + processed_batches.append( + self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale) + ) + + processed_frames = [frame for batch in processed_batches for frame in batch] + return np.array(processed_frames) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the RetinaFace checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + face_blur = RetinaFaceFilter(checkpoint_dir=args.checkpoint_dir) + postprocessing_runner = GuardrailRunner(postprocessors=[face_blur]) + os.makedirs(args.output_dir, exist_ok=True) + + for filepath in tqdm(filepaths): + video_data = read_video(filepath) + with
misc.timer("face blur filter"): + frames = postprocessing_runner.postprocess(video_data.frames) + + output_path = os.path.join(args.output_dir, os.path.basename(filepath)) + save_video(output_path, frames, video_data.fps) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea87e5818d643fb13bf59950c40c24d0cf36acf --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from retinaface.utils.nms.py_cpu_nms import py_cpu_nms + +from cosmos_predict1.utils import log + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k): + """Filter boxes based on confidence score and remove overlapping boxes using NMS.""" + # Keep detections with confidence above threshold + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # Sort by confidence and keep top K detections + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # Run non-maximum-suppression (NMS) to remove overlapping boxes + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + dets = dets[:keep_top_k, :] + boxes = dets[:, :-1] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs +def decode_batch(loc, priors, variances): + """Decode batched locations from predictions using priors and variances. + + Args: + loc (tensor): Batched location predictions for loc layers. + Shape: [batch_size, num_priors, 4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors, 4] + variances: (list[float]): Variances of prior boxes. 
+ + Return: + Decoded batched bounding box predictions + Shape: [batch_size, num_priors, 4] + """ + batch_size = loc.size(0) + priors = priors.unsqueeze(0).expand(batch_size, -1, -1) + + boxes = torch.cat( + ( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]), + ), + dim=2, + ) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + log.debug("Missing keys:{}".format(len(missing_keys))) + log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys))) + log.debug("Used keys:{}".format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" + return True + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _remove_prefix(state_dict, prefix): + """Old version of the model is stored with all names of parameters sharing common prefix 'module.'""" + log.debug("Removing prefix '{}'".format(prefix)) + + def f(x): + return x.split(prefix, 1)[-1] if x.startswith(prefix) else x + + return {f(key): value for key, value in state_dict.items()} + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def load_model(model, pretrained_path, load_to_cpu): + log.debug("Loading pretrained model from {}".format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load( + pretrained_path, map_location=lambda storage, loc: storage.cuda(device), weights_only=True + ) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.") + else: + pretrained_dict = _remove_prefix(pretrained_dict, "module.") + _check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
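To illustrate the batched decoding added in retinaface_utils.py, a synthetic shape check (values are arbitrary; the variances mirror the typical RetinaFace ResNet-50 config):

# Sketch: decode_batch broadcasts [num_priors, 4] priors over [batch, num_priors, 4] predictions.
import torch

from cosmos_predict1.auxiliary.guardrail.face_blur_filter.retinaface_utils import decode_batch

batch_size, num_priors = 2, 10
loc = torch.zeros(batch_size, num_priors, 4)  # zero offsets -> decoded boxes equal the priors
priors = torch.rand(num_priors, 4)  # center-offset form: cx, cy, w, h
variances = [0.1, 0.2]  # typical RetinaFace variance values (cfg_re50)
boxes = decode_batch(loc, priors, variances)
print(boxes.shape)  # torch.Size([2, 10, 4]), boxes in x1, y1, x2, y2 form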
diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d5a95d4dce1202e3acec0e10177c97c1e5924e --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +UNSAFE_CATEGORIES = { + "S1": "Violent Crimes.", + "S2": "Non-Violent Crimes.", + "S3": "Sex Crimes.", + "S4": "Child Exploitation.", + "S5": "Defamation.", + "S6": "Specialized Advice.", + "S7": "Privacy.", + "S8": "Intellectual Property.", + "S9": "Indiscriminate Weapons.", + "S10": "Hate.", + "S11": "Self-Harm.", + "S12": "Sexual Content.", + "S13": "Elections.", + "S14": "Code Interpreter Abuse.", +} diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py new file mode 100644 index 0000000000000000000000000000000000000000..cae384738ed228bf4455d95c3c6fe011d790c01f --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import argparse +import os + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES +from cosmos_predict1.utils import log, misc + +SAFE = misc.Color.green("SAFE") +UNSAFE = misc.Color.red("UNSAFE") + + +class LlamaGuard3(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = torch.bfloat16 + + model_id = "meta-llama/Llama-Guard-3-8B" + model_dir = os.path.join(self.checkpoint_dir, model_id) + + self.model = AutoModelForCausalLM.from_pretrained(model_dir) + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + + self.model.to(self.device, dtype=self.dtype).eval() + + def get_llamaGuard3_block_message(self, moderation_output: str) -> str: + """Extract the blocked category from the Llama Guard 3 model output.""" + block_msg = "Prompt blocked by Llama Guard 3." + try: + lines = moderation_output.splitlines() + categories_detected = [] + for line in lines[1:]: + line_stripped = line.split("<|eot_id|>")[0].strip() + for category in line_stripped.split(","): + category = category.strip() + if category not in UNSAFE_CATEGORIES: + log.warning(f"Unrecognized category from moderation output: {category}") + else: + categories_detected.append(category) + if len(categories_detected) > 0: + blocked_categories = ", ".join([UNSAFE_CATEGORIES[category][:-1] for category in categories_detected]) + block_msg = f"{block_msg} Violations: {blocked_categories}." + except Exception as e: + log.warning(f"Unable to extract blocked category from Llama Guard 3 output: {e}") + return block_msg + + def filter_llamaGuard3_output(self, prompt: str) -> tuple[bool, str]: + """Filter the Llama Guard 3 model output and return the safety status and message.""" + conversation = [{"role": "user", "content": prompt}] + input_ids = self.tokenizer.apply_chat_template( + conversation, categories=UNSAFE_CATEGORIES, return_tensors="pt" + ).to(self.device) + prompt_len = input_ids.shape[1] + output = self.model.generate( + input_ids=input_ids, + max_new_tokens=100, + return_dict_in_generate=True, + pad_token_id=0, + ) + generated_tokens = output.sequences[:, prompt_len:] + moderation_output = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=False).strip() + + if "unsafe" in moderation_output.lower(): + block_msg = self.get_llamaGuard3_block_message(moderation_output) + return False, block_msg + else: + return True, "" + + def is_safe(self, prompt: str) -> tuple[bool, str]: + """Check if the input prompt is safe according to the Llama Guard 3 model.""" + try: + return self.filter_llamaGuard3_output(prompt) + except Exception as e: + log.error(f"Unexpected error occurred when running Llama Guard 3 guardrail: {e}") + return True, "Unexpected error occurred when running Llama Guard 3 guardrail."
+ + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Llama Guard 3 checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + llamaGuard3 = LlamaGuard3(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[llamaGuard3]) + with misc.timer("Llama Guard 3 safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + if not safety: + log.info(f"Message: {message}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6dabaa352257260cb6f6462e86f4d966d1b67118 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import attrs +import torch +import torch.nn as nn + +from cosmos_predict1.utils.config import make_freezable + + +@make_freezable +@attrs.define(slots=False) +class ModelConfig: + input_size: int = 1152 + num_classes: int = 7 + + +class SafetyClassifier(nn.Module): + def __init__(self, input_size: int = 1024, num_classes: int = 2): + super().__init__() + self.input_size = input_size + self.num_classes = num_classes + self.layers = nn.Sequential( + nn.Linear(self.input_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, self.num_classes), + # Note: No activation function here; CrossEntropyLoss expects raw logits + ) + + def forward(self, x): + return self.layers(x) + + +class VideoSafetyModel(nn.Module): + def __init__(self, config: ModelConfig) -> None: + super().__init__() + self.config = config + self.num_classes = config.num_classes + self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes) + + @torch.inference_mode() + def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + logits = self.network(data_batch["data"].cuda()) + return {"logits": logits} diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..6bccb8dad67e1baf2649d6c6d83f29e5a09f8445 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
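For reference, an illustrative CPU-only shape check of the classifier head defined in model.py above (random inputs, untrained weights):

# Sketch: the classifier maps SigLIP embeddings of size 1152 to logits over 7 classes.
import torch

from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.model import SafetyClassifier

classifier = SafetyClassifier(input_size=1152, num_classes=7)
classifier.eval()  # eval mode so BatchNorm1d uses running statistics
with torch.no_grad():
    logits = classifier(torch.randn(4, 1152))
print(logits.shape)  # torch.Size([4, 7])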
+ +import argparse +import json +import os +from typing import Iterable, Tuple, Union + +import torch +from PIL import Image + +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder +from cosmos_predict1.utils import log, misc + +# Define the class index to class name mapping for multi-class classification +CLASS_IDX_TO_NAME = { + 0: "Safe", + 1: "Sexual_Content", + 2: "Violence", + 3: "Drugs", + 4: "Child_Abuse", + 5: "Hate_and_Harassment", + 6: "Self-Harm", +} + + +class VideoContentSafetyFilter(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.device = device + self.dtype = torch.float32 + + # Initialize the SigLIP encoder + self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/video_content_safety_filter") + self.encoder = SigLIPEncoder(checkpoint_dir=self.checkpoint_dir, device=device, dtype=self.dtype) + + # Use ModelConfig directly for inference configuration + model_config = ModelConfig(input_size=1152, num_classes=7) + + # Load the multi-class classifier + self.model = VideoSafetyModel(model_config) + safety_filter_local_path = os.path.join(self.checkpoint_dir, "safety_filter.pt") + checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True) + self.model.load_state_dict(checkpoint["model"]) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def __infer(self, pil_image: Image.Image) -> int: + """Infer the class of the image.""" + image_embs = self.encoder.encode_image(pil_image) + logits = self.model.network(image_embs) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + predicted_class = torch.argmax(probabilities, dim=-1).item() + return predicted_class + + def is_safe_file(self, filepath: str) -> bool: + """Check if the video file is safe.""" + video_data = read_video(filepath) + + # Sample frames at 2 FPS + sample_rate = 2 # frames per second + frame_interval = int(video_data.fps / sample_rate) + frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval)) + + is_safe = True + frame_scores = [] + + for frame_number in frame_numbers: + try: + frame = video_data.frames[frame_number] + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark the video as unsafe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. 
Exception: {e}") + continue + + # Prepare data for JSON + video_data = { + "filepath": filepath, + "is_safe": is_safe, + "video_length": video_data.duration, + "fps": video_data.fps, + "frame_scores": frame_scores, + } + + log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.") + log.debug(f"Video data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe_frames(self, frames: Iterable) -> bool: + """Check if the video frames are safe.""" + is_safe = True + frame_scores = [] + + for frame_number, frame in enumerate(frames): + try: + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark as not safe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}") + continue + + video_data = { + "is_safe": is_safe, + "frame_scores": frame_scores, + } + + log.debug(f"Frames data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]: + if isinstance(input, str): + is_safe = self.is_safe_file(input) + return is_safe, "safe video detected" if is_safe else "unsafe video detected" + elif isinstance(input, Iterable): + is_safe = self.is_safe_frames(input) + return is_safe, "safe frames detected" if is_safe else "unsafe frames detected" + else: + raise ValueError(f"Input type {type(input)} not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Video Content Safety Filter checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe") + + for filepath in filepaths: + with misc.timer("video content safety filter"): + _ = runner.run_safety_check(filepath) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6c7b45422501fec96bd1e711509ead5efa019a --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from PIL import Image +from transformers import SiglipModel, SiglipProcessor + + +class SigLIPEncoder(torch.nn.Module): + def __init__( + self, + checkpoint_dir: str, + model_name: str = "google/siglip-so400m-patch14-384", + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, + ) -> None: + super().__init__() + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = dtype + self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def encode_image(self, input_img: Image.Image) -> torch.Tensor: + """Encode an image into a feature vector.""" + with torch.no_grad(): + inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype) + image_features = self.model.get_image_features(**inputs) + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features diff --git a/cosmos_predict1/auxiliary/t5_text_encoder.py b/cosmos_predict1/auxiliary/t5_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..61a56073c57cca6b2f1e6a52f4fb853c5eee585c --- /dev/null +++ b/cosmos_predict1/auxiliary/t5_text_encoder.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple, Union + +import torch +import transformers +from transformers import T5EncoderModel, T5TokenizerFast + +from cosmos_predict1.utils import log + +transformers.logging.set_verbosity_error() + + +class CosmosT5TextEncoder(torch.nn.Module): + """Handles T5 text encoding operations.""" + + def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"): + """Initializes the T5 tokenizer and encoder. + + Args: + model_name: The name of the T5 model to use. + device: The device to use for computations. + """ + super().__init__() + try: + self.tokenizer = T5TokenizerFast.from_pretrained(cache_dir, cache_dir=cache_dir) + self.text_encoder = T5EncoderModel.from_pretrained(cache_dir, cache_dir=cache_dir).to(device) + except Exception as e: + log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}") + self.tokenizer = T5TokenizerFast.from_pretrained(model_name) + self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device) + self.text_encoder.eval() + self.device = device + + @torch.inference_mode() + def encode_prompts( + self, prompts: Union[str, List[str]], max_length: int = 512 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encodes text prompts into hidden state representations using a T5 encoder. 
+ + This function tokenizes the input prompts, processes them through a T5 text encoder, + and returns the last hidden states. The encoded outputs beyond the actual sequence + length are zero-padded. All prompts in a batch are padded to max_length. + + Args: + prompts: Input text to encode. Can be a single string or a list of strings. + max_length: Maximum sequence length for tokenization and padding. Longer + sequences will be truncated. Defaults to 512. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Encoded text embeddings of shape (batch_size, max_length, hidden_size) + - Attention mask of shape (batch_size, max_length) + + Raises: + ValueError: If the input prompts list is empty. + + Example: + >>> encoder = CosmosT5TextEncoder() + >>> prompts = ["Hello world", "Another example"] + >>> embeddings, attn_mask = encoder.encode_prompts(prompts, max_length=128) + """ + if isinstance(prompts, str): + prompts = [prompts] + + if not prompts: + raise ValueError("The input prompt list is empty.") + + batch_encoding = self.tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + input_ids = batch_encoding.input_ids.to(self.device) + attn_mask = batch_encoding.attention_mask.to(self.device) + + outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask) + + encoded_text = outputs.last_hidden_state + lengths = attn_mask.sum(dim=1).cpu() + + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + return encoded_text, attn_mask + + +class DummyT5TextEncoder(torch.nn.Module): + def __init__(self, device: str = "cuda"): + super().__init__() + self.device = device + + @torch.inference_mode() + def encode_prompts( + self, prompts: Union[str, List[str]], max_length: int = 512 + ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(prompts, str): + prompts = [prompts] + + if not prompts: + raise ValueError("The input prompt list is empty.") + + batch_size = len(prompts) + + dummy_text_embedding = torch.zeros(batch_size, max_length, 1024, device=self.device) + dummy_text_mask = torch.zeros(batch_size, max_length, device=self.device, dtype=torch.bool) + dummy_text_mask[0] = True + + return dummy_text_embedding, dummy_text_mask diff --git a/cosmos_predict1/callbacks/every_n.py b/cosmos_predict1/callbacks/every_n.py new file mode 100644 index 0000000000000000000000000000000000000000..25cab309a58336867ed5fc58849e71db7611d0f3 --- /dev/null +++ b/cosmos_predict1/callbacks/every_n.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.callback import Callback +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class EveryN(Callback): + def __init__( + self, + every_n: Optional[int] = None, + step_size: int = 1, + barrier_after_run: bool = True, + run_at_start: bool = False, + ) -> None: + """Constructor for `EveryN`. + + Args: + every_n (int): Frequency with which callback is run during training. + step_size (int): Size of iteration step count. Default 1. + barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. + run_at_start (bool): Whether to run at the beginning of training. Default False. + """ + self.every_n = every_n + if self.every_n == 0: + log.warning( + f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once at the beginning of training. Calls from on_training_step_end will be skipped." + ) + + self.step_size = step_size + self.barrier_after_run = barrier_after_run + self.run_at_start = run_at_start + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # every_n = 0 is a special case: every_n_impl is called only once at the beginning of training + if self.every_n != 0: + trainer = self.trainer + global_step = iteration // self.step_size + should_run = (iteration == 1 and self.run_at_start) or ( + global_step % self.every_n == 0 + ) # (self.every_n - 1) + if should_run: + log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") + self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) + log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") + # add necessary barrier to avoid timeout + if self.barrier_after_run: + distributed.barrier() + + @abstractmethod + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + ... diff --git a/cosmos_predict1/callbacks/grad_clip.py b/cosmos_predict1/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4f320b6f79e1e289117d8190b5f6df52cf64ae --- /dev/null +++ b/cosmos_predict1/callbacks/grad_clip.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
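A minimal, hypothetical EveryN subclass showing how the callback above is meant to be extended; LogLossEveryN is illustrative only and not part of this change.

# Sketch: log the training loss every 100 iterations via the EveryN hook.
from cosmos_predict1.callbacks.every_n import EveryN
from cosmos_predict1.utils import log


class LogLossEveryN(EveryN):
    def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
        # Invoked by on_training_step_end whenever (iteration // step_size) % every_n == 0.
        log.info(f"iteration {iteration}: loss = {loss.item():.4f}")


callback = LogLossEveryN(every_n=100, barrier_after_run=False)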
+ +from typing import List, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callback import Callback + + +@torch.jit.script +def _fused_nan_to_num(params: List[torch.Tensor]): + for param in params: + torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) + + +class GradClip(Callback): + def __init__( + self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False + ): + self.clip_norm = clip_norm + self.force_finite = force_finite + self.model_key = model_key + self.fsdp_enabled = fsdp_enabled + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + + # select sub-network if specified + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + + if self.force_finite: + params = [] + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + # check if FSDP is used + # total_norm + if isinstance(model, FSDP) and self.fsdp_enabled: + model.clip_grad_norm_(self.clip_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) diff --git a/cosmos_predict1/checkpointer/__init__.py b/cosmos_predict1/checkpointer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/checkpointer/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/checkpointer/base.py b/cosmos_predict1/checkpointer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4721905f50c45eae13dac833754303e9923c3a33 --- /dev/null +++ b/cosmos_predict1/checkpointer/base.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import callback +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + + +class AbstractCheckpointer(ABC): + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + self.config_checkpoint = config_checkpoint + # Set the callback functions. + self.callbacks = callbacks + + # Set checkpoint directories for local paths + self._local_dirname = os.path.join(config_job.path_local, "checkpoints") + + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + self.verbose = config_checkpoint.verbose + self.keys_not_to_resume = config_checkpoint.keys_not_to_resume + self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem + + @abstractmethod + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + pass + + @abstractmethod + def load( + self, + model: Model, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + grad_scaler: Optional[torch.amp.GradScaler] = None, + ) -> int: + pass + + @property + def save_bucket(self): + """Get the bucket name for saving checkpoints.""" + return None + + @property + def load_bucket(self): + """Get the bucket name for loading checkpoints.""" + return None + + @property + def save_dirname(self): + return self._local_dirname + + @property + def load_dirname(self): + return self._local_dirname + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt") + if easy_io.exists(checkpoint_path): + checkpoint_file = easy_io.load(checkpoint_path).strip() + + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt") + easy_io.dump(content, checkpoint_path) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. 
+ """ + if not easy_io.exists(checkpoint_path): + raise FileNotFoundError(f"File not found: {checkpoint_path}") diff --git a/cosmos_predict1/checkpointer/ddp.py b/cosmos_predict1/checkpointer/ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..cee4cf147f6882731239d3925355de581d7e9f1f --- /dev/null +++ b/cosmos_predict1/checkpointer/ddp.py @@ -0,0 +1,437 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from collections import namedtuple +from typing import Any, Dict, Optional, Set, Tuple, Union + +import torch +import torch.distributed +from megatron.core import parallel_state +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.checkpointer.base import AbstractCheckpointer +from cosmos_predict1.checkpointer.safe_broadcast import broadcast_object +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + +StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) + + +class Checkpointer(AbstractCheckpointer): + """ + Checkpointer for DDP. + Note: This implementation only supports local filesystem. + """ + + KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] + KEYS_TO_POSTFIX = { + "model": "model", + "optim": "optim", + "scheduler": "scheduler", + "trainer": "", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + ep_world_size = parallel_state.get_expert_model_parallel_world_size() + assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." + assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." + self.mp_world_size = parallel_state.get_model_parallel_group().size() + if self.mp_world_size > 1 and self.__class__ == Checkpointer: + raise NotImplementedError( + "Model Parallelism (MP) is enabled - " + "you should use TensorParallel Checkpointer instead of DDP Checkpointer." + ) + # DDP rank (with context parallelism considered) + self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) + # Context parallelism rank + self.cp_rank = parallel_state.get_context_parallel_rank() + # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) + self.mp_rank = parallel_state.get_model_parallel_group().rank() + # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() + if self.broadcast_via_filesystem: + log.info("Broadcasting checkpoint data via the local filesystem.") + if not self.strict_resume: + log.warning("Strict resume mode is off. 
Some model parameters may not be loaded.") + + # collect ranks of all model parallel groups + all_ranks = [None for _ in range(distributed.get_world_size())] + torch.distributed.all_gather_object( + all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) + ) + all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) + for ranks in all_ranks: + group = torch.distributed.new_group(list(ranks), backend="gloo") + if distributed.get_rank() in ranks: + self.mp_gloo_pg = group + + self.print("Checkpointer Initialized.") + + def print(self, message: str): + """ + Print message to the console. Include the parallelism rank information when verbose is set to True. + """ + if self.verbose: + log.info( + f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", + rank0_only=False, + ) + else: + log.info(message, rank0_only=True) + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + del model + assert key in self.KEYS_TO_SAVE + post_fix = self.KEYS_TO_POSTFIX[key] + + if post_fix: + _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") + else: + _ckpt_path = checkpoint_path + return _ckpt_path + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + **ignore_kwargs, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = self.format_checkpoint_filename(model, iteration) + state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) + state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) + if state_dict: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: + new_dict = {} + for key, _state_dict in state_dict.items(): + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) + checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) + new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) + return new_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). + + Args: + state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. 
+ checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=True, # optional for fast backend, cpu heavy + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) + + def format_checkpoint_filename(self, model: Model, iteration: int) -> str: + """Generate the checkpoint file name. + + Args: + iteration (int): The current iteration number. + + Returns: + checkpoint_file (str): The checkpoint file name. + """ + del self, model + return f"iter_{iteration:09}.pt" + + @misc.timer("generate saving state dict") + def generate_save_state_dict( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> Optional[Dict[str, Any]]: + state_dict = {} + + if self.rank_dp_w_cp == 0: + trainer_state = dict( + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + model_state = model.state_dict() + optim_state = optimizer.state_dict() + scheduler_state = scheduler.state_dict() + self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) + + trainer_state, model_state, optim_state, scheduler_state = misc.to( + [trainer_state, model_state, optim_state, scheduler_state], device="cpu" + ) + + state_dict = { + "model": model_state, + "optim": optim_state, + "scheduler": scheduler_state, + } + if distributed.get_rank() == 0: # only rank 0 saves trainer state + state_dict["trainer"] = trainer_state + return state_dict + return state_dict + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast. + + The main steps are: + 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. + + This approach ensures that each MP rank loads its specific part of the model, which is + crucial for Model Parallelism where different parts of the model are distributed across + multiple GPUs. + + When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can + be set to True. This allows each rank to load its specific checkpoint from the local filesystem + instead of receiving it via network broadcast, which could be more efficient in some cases. 
+ + For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...") + _state_dict = easy_io.load(local_cache_path, fast_backend=True) + else: + self.print(f"Downloading checkpoint from: {_ckpt_path}") + _state_dict = easy_io.load(_ckpt_path, fast_backend=True) + if self.broadcast_via_filesystem: + # Save the checkpoint to the local filesystem + easy_io.dump(_state_dict, local_cache_path, fast_backend=True) + state_dict[key] = _state_dict + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) + else: + # Broadcast the checkpoint to all GPUs of the current DDP rank + group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) + min_rank = min(get_process_group_ranks(group)) + + _state_dict = broadcast_object( + state_dict[key] if self.rank_dp_w_cp == 0 else None, + min_rank, + group=group, + device=torch.device(torch.cuda.current_device()), + ) + if self.rank_dp_w_cp == 0: + self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') + else: + state_dict[key] = _state_dict + self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') + + return state_dict + + def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: + latest_checkpoint_file = self._read_latest_checkpoint_file() + + resume_keys = [] + + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) + resume_keys.extend(self.KEYS_TO_SAVE) + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. 
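+                #    (Typically a pre-trained checkpoint used for fine-tuning or inference;
+                #    the full training state is only restored when load_training_state is set.)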
+ checkpoint_path = self.load_path + if self.load_training_state: + resume_keys.extend(self.KEYS_TO_SAVE) + else: + resume_keys.append("model") + if self.only_load_scheduler_state: + resume_keys.append("scheduler") + else: + checkpoint_path = None + if len(self.keys_not_to_resume) > 0: + for key in self.keys_not_to_resume: + assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" + resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] + return set(resume_keys), checkpoint_path + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + resume_keys, checkpoint_path = self.keys_to_resume_during_load() + + iteration = 0 + + # Load checkpoint. 
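+        # The checkpoint is read from disk once per model-parallel rank (on DP/CP rank 0)
+        # and then shared with the remaining ranks of the DDP(+CP) group, either via the
+        # local filesystem or an object broadcast; see load_broadcast_state_dict().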
+ if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) + + if "trainer" in state_dict: + trainer_state = state_dict["trainer"] + log.critical(state_dict.keys(), rank0_only=False) + log.critical(trainer_state, rank0_only=False) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(trainer_state["grad_scaler"]) + self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) + iteration = trainer_state["iteration"] + if "optim" in state_dict: + assert optimizer + optimizer_state = state_dict["optim"] + log.info("- Loading the optimizer...") + optimizer.load_state_dict(optimizer_state) + if "scheduler" in state_dict: + assert scheduler + scheduler_state = state_dict["scheduler"] + log.info("- Loading the scheduler...") + scheduler.load_state_dict(scheduler_state) + scheduler.last_epoch = iteration + if "model" in state_dict: + model_state = state_dict["model"] + log.info("- Loading the model...") + # model.load_state_dict(model_state) + if self.strict_resume: + log.info("\t Strict resume mode is on.") + else: + log.info("\t Strict resume mode is off.") + model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) + log.info(f"\t {model_load_info}") + self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: + """Write json file to save number of seen samples and number of iterations. + + Args: + checkpoint_file (str): iteration number for the saved checkpoint + trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. + """ + # filename: iter_xxxxxxxxx_trained_data_record.json + checkpoint_path = os.path.join( + self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" + ) + easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos_predict1/checkpointer/peft_checkpointer.py b/cosmos_predict1/checkpointer/peft_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7a3341042d35d7da6a7e3c4955d400ce73dd6a --- /dev/null +++ b/cosmos_predict1/checkpointer/peft_checkpointer.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Set + +import torch + +from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.model import Model + + +class Checkpointer(DDPCheckpointer): + """ + Checkpointer class for PEFT in distributed training. 
This class is similar to the DDP checkpointer, + with the exception that the `broadcast_via_filesystem` functionality is not supported, and it supports + loading pre-trained model without any postfix. + + Note: + - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.broadcast_via_filesystem: + raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.") + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + """ + Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) + to load pre-trained model without any postfix. + """ + checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + checkpoint_path = checkpoint_path.replace("model_model.pt", "model.pt") + return checkpoint_path + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast for PEFT checkpointer. + + This function is identical to the `load_broadcast_state_dict` function of the base class (DDP checkpointer), + with the exception that the `broadcast_via_filesystem` functionality is not supported. + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. 
Loading...") + _state_dict = torch.load( + local_cache_path, map_location=lambda storage, loc: storage, weights_only=False + ) + else: + # Pre-trained model is not in local cache, so we need to load it from the checkpoint path + self.print(f"Loading checkpoint from: {_ckpt_path}") + _state_dict = torch.load(_ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) + state_dict[key] = _state_dict + + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = torch.load( + local_cache_path, map_location=lambda storage, loc: storage, weights_only=False + ) + else: + self.print(f"Loading checkpoint from: {_ckpt_path}") + state_dict[key] = torch.load( + _ckpt_path, map_location=lambda storage, loc: storage, weights_only=False + ) + + else: + raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.") + + return state_dict diff --git a/cosmos_predict1/checkpointer/safe_broadcast.py b/cosmos_predict1/checkpointer/safe_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..f914299c97f297cf43a5919b0bc8130686afa7c1 --- /dev/null +++ b/cosmos_predict1/checkpointer/safe_broadcast.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import io +import pickle +from typing import Any + +import torch +import torch.distributed as dist + + +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29 +def broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. 
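+
+    Example:
+        Illustrative usage sketch (assumes ``torch.distributed`` is already
+        initialized; the payload contents are placeholders)::
+
+            payload = {"iteration": 1000} if dist.get_rank() == 0 else None
+            received = broadcast_object(
+                payload,
+                src_rank=0,
+                group=dist.group.WORLD,
+                device=torch.device("cpu"),
+            )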
+ """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note: These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() + } + + return value diff --git a/cosmos_predict1/checkpointer/tp.py b/cosmos_predict1/checkpointer/tp.py new file mode 100644 index 0000000000000000000000000000000000000000..b97231a66aa13f5b26627fd7c302cbbb721492b0 --- /dev/null +++ b/cosmos_predict1/checkpointer/tp.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer +from cosmos_predict1.utils.model import Model + + +class Checkpointer(DDPCheckpointer): + """ + Checkpointer class for Tensor Parallelism (TP) in distributed training. + + This implementation supports the combination of Tensor Parallelism (TP) and Data Parallel Processing (DDP), with optional Context Parallelism (CP). + + Note: + - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. + - In principle, this implementation is also compatible with Pipeline Parallelism (PP) and Expert Parallelism (EP), which are other forms of model parallelism. However, PP and EP have not been tested yet. 
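+
+    Example (illustrative): with key ``"model"`` and TP rank 0, a base checkpoint name
+    ``iter_000001000.pt`` resolves to ``iter_000001000_model_mp_0.pt``; the ``"trainer"``
+    key keeps the un-suffixed path.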
+ """ + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + """ + Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) + to append the TP-rank postfix to the checkpoint path. + """ + checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + if key == "trainer": + return checkpoint_path + else: + checkpoint_path = checkpoint_path.replace(".pt", f"_mp_{self.mp_rank}.pt") + + return checkpoint_path diff --git a/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py b/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea13c45291ab346fc482725bb1b15b58f1d9618 --- /dev/null +++ b/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import attrs + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import CheckpointConfig as BaseCheckpointConfig +from cosmos_predict1.utils.config import make_freezable +from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig(BaseCheckpointConfig): + load_ema_to_reg: bool = False + + +class FSDPCheckpointer(BaseFSDPCheckpointer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not isinstance(self.config_checkpoint, CheckpointConfig): + warnings.warn( + "The 'config_checkpoint' is not an instance of 'CheckpointConfig'. " + "This behavior is deprecated and will not be supported in future versions. " + "Please update 'config_checkpoint' to be of type 'CheckpointConfig'.", + DeprecationWarning, + ) + + self.load_ema_to_reg = False + else: + self.load_ema_to_reg = self.config_checkpoint.load_ema_to_reg + + log.critical(f"load_ema_to_reg: {self.load_ema_to_reg}", rank0_only=False) + + def load_model_during_init(self, model, is_ema: bool = False, ema_id: int = 0): + if self.load_ema_to_reg and is_ema is False: + is_ema = True + ema_id = 0 + log.critical("Loading EMA model to regular model during initialization.", rank0_only=False) + super().load_model_during_init(model, is_ema, ema_id) diff --git a/cosmos_predict1/diffusion/conditioner.py b/cosmos_predict1/diffusion/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6deb003d806b464c7921b1b78056b850a6045d --- /dev/null +++ b/cosmos_predict1/diffusion/conditioner.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import instantiate + + +class BaseConditionEntry(nn.Module): + def __init__(self): + super().__init__() + + self._dropout_rate = None + self._input_key = None + self._return_dict = False + + @property + def dropout_rate(self) -> Union[float, torch.Tensor]: + return self._dropout_rate + + @property + def input_key(self) -> str: + return self._input_key + + @property + def is_return_dict(self) -> bool: + return self._return_dict + + @dropout_rate.setter + def dropout_rate(self, value: Union[float, torch.Tensor]): + self._dropout_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_return_dict.setter + def is_return_dict(self, value: bool): + self._return_dict = value + + @dropout_rate.deleter + def dropout_rate(self): + del self._dropout_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + @is_return_dict.deleter + def is_return_dict(self): + del self._return_dict + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + bernoulli = torch.bernoulli((1.0 - dropout_rate) * torch.ones(len(in_tensor))).type_as(in_tensor) + bernoulli_expand = bernoulli.view((-1,) + (1,) * (in_tensor.dim() - 1)) + return bernoulli_expand * in_tensor + + def summary(self) -> str: + pass + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + + +class TextAttr(BaseConditionEntry): + def __init__(self): + super().__init__() + + def forward(self, token: torch.Tensor, mask: torch.Tensor): + return {"crossattn_emb": token, "crossattn_mask": mask} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + if key is not None and "mask" in key: + return in_tensor + return super().random_dropout_input(in_tensor, dropout_rate, key) + + +@dataclass +class BaseVideoCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + data_type: DataType = DataType.VIDEO + padding_mask: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + image_size: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + frame_repeat: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +@dataclass +class VideoExtendCondition(BaseVideoCondition): + video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video + gt_latent: Optional[torch.Tensor] = None + condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region + + # 
condition_video_input_mask will concat to the input of network, along channel dim; + # Will be concat with the input tensor + condition_video_input_mask: Optional[torch.Tensor] = None + # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + condition_video_augment_sigma: Optional[torch.Tensor] = None + condition_video_pose: Optional[torch.Tensor] = None + + +class GeneralConditioner(nn.Module, ABC): + """ + An abstract module designed to handle various embedding models with conditional and + unconditional configurations. This abstract base class initializes and manages a collection + of embedders that can dynamically adjust their dropout rates based on conditioning. + + Attributes: + KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation. + embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and + configured based on the provided configurations. + + Parameters: + emb_models (Union[List, Any]): A dictionary where keys are embedder names and values + are configurations for initializing the embedders. + + """ + + KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1} + + def __init__(self, **emb_models: Union[List, Any]): + super().__init__() + self.embedders = nn.ModuleDict() + for n, (emb_name, embconfig) in enumerate(emb_models.items()): + embedder = instantiate(embconfig.obj) + assert isinstance( + embedder, BaseConditionEntry + ), f"embedder model {embedder.__class__.__name__} has to inherit from BaseConditionEntry" + embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0) + + if hasattr(embconfig, "input_key"): + embedder.input_key = embconfig.input_key + elif hasattr(embconfig, "input_keys"): + embedder.input_keys = embconfig.input_keys + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}") + self.embedders[emb_name] = embedder + + @abstractmethod + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Any: + """Should be implemented in subclasses to handle conditon datatype""" + raise NotImplementedError + + def _forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Dict: + """ + Processes the input batch through all configured embedders, applying conditional dropout rates if specified. + Output tensors for each key are concatenated along the dimensions specified in KEY2DIM. + + Parameters: + batch (Dict): The input data batch to process. + override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates + per embedder key. + + Returns: + Dict: A dictionary of output tensors concatenated by specified dimensions. + + Note: + In case the network code is sensitive to the order of concatenation, you can either control the order via \ + config file or make sure the embedders return a unique key for each output. 
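+
+        Example:
+            A minimal sketch with only the text embedder configured (key names
+            follow TextConfig; tensor shapes are placeholders)::
+
+                batch = {
+                    "t5_text_embeddings": torch.randn(2, 512, 1024),
+                    "t5_text_mask": torch.ones(2, 512),
+                }
+                output = conditioner._forward(batch)  # `conditioner`: a GeneralConditioner subclass instance
+                # output["crossattn_emb"] and output["crossattn_mask"] hold the outputs of all
+                # embedders that produced those keys, concatenated along dim 1 (see KEY2DIM).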
+ """ + output = defaultdict(list) + if override_dropout_rate is None: + override_dropout_rate = {} + + # make sure emb_name in override_dropout_rate is valid + for emb_name in override_dropout_rate.keys(): + assert emb_name in self.embedders, f"invalid name found {emb_name}" + + for emb_name, embedder in self.embedders.items(): + with torch.no_grad(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + emb_out = embedder( + embedder.random_dropout_input( + batch[embedder.input_key], override_dropout_rate.get(emb_name, None) + ) + ) + elif hasattr(embedder, "input_keys"): + emb_out = embedder( + *[ + embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k) + for k in embedder.input_keys + ] + ) + for k, v in emb_out.items(): + output[k].append(v) + # Concatenate the outputs + return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()} + + def get_condition_uncondition( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Processes the provided data batch to generate conditioned and unconditioned outputs. + + This method manipulates dropout rates to simulate two scenarios: + 1. All conditions applied (conditioned) + 2. Conditions removed/reduced to minimum (unconditioned) + + This method sets dropout rates to zero for the conditioned scenario to fully apply + embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is + insignificant) to minimize embedder influences. + + Parameters: + data_batch (Dict): Input data batch containing all necessary information for + embedding processing. + + Returns: + Tuple[Any, Any]: A tuple containing: + - Outputs with all embedders fully applied (conditioned) + - Outputs with embedders minimized/not applied (unconditioned) + """ + cond_dropout_rates, dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) + return condition, un_condition + + def get_condition_with_negative_prompt( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Similar functionality as get_condition_uncondition + But use negative prompts for unconditon + """ + cond_dropout_rates, uncond_dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + if isinstance(embedder, TextAttr): + uncond_dropout_rates[emb_name] = 0.0 + else: + uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + data_batch_neg_prompt = copy.deepcopy(data_batch) + if "neg_t5_text_embeddings" in data_batch_neg_prompt: + if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor): + data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"] + data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"] + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates) + + return condition, un_condition + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: 
getattr(self, f.name) for f in fields(self)} + + +class VideoConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoExtendConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) diff --git a/cosmos_predict1/diffusion/config/__init__.py b/cosmos_predict1/diffusion/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/base/__init__.py b/cosmos_predict1/diffusion/config/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/base/conditioner.py b/cosmos_predict1/diffusion/config/base/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..43e09aa039d2a71e30a4285970bcbe658e19df4d --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/conditioner.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
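+
+# Note: each config class below pairs a lazily-instantiated embedder (`obj`) with its
+# `dropout_rate` and the batch key(s) it reads (`input_key` or `input_keys`);
+# GeneralConditioner builds its `embedders` ModuleDict from these fields.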
+ +from typing import Dict, List, Optional + +import attrs +import torch + +from cosmos_predict1.diffusion.conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class TextConfig: + obj: LazyDict = L(TextAttr)() # No arguments + dropout_rate: float = 0.2 + input_keys: List[str] = attrs.field(factory=lambda: ["t5_text_embeddings", "t5_text_mask"]) + + +class BooleanFlag(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None): + super().__init__() + self.output_key = output_key + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + del args, kwargs + key = self.output_key if self.output_key else self.input_key + return {key: self.flag} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) + return in_tensor + + +class ReMapkey(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None, dtype: Optional[str] = None): + super().__init__() + self.output_key = output_key + self.dtype = { + None: None, + "float": torch.float32, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float16": torch.float16, + "int": torch.int32, + "long": torch.int64, + }[dtype] + + def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: + key = self.output_key if self.output_key else self.input_key + if isinstance(element, torch.Tensor): + element = element.to(dtype=self.dtype) + return {key: element} + + +class FrameRepeatAttr(BaseConditionEntry): + def __init__(self): + super().__init__() + + def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "frame_repeat": frame_repeat / 10.0, + } + + +@attrs.define(slots=False) +class FPSConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `fps`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="fps", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "fps" + + +@attrs.define(slots=False) +class PaddingMaskConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `padding_mask`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="padding_mask", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "padding_mask" + + +@attrs.define(slots=False) +class ImageSizeConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `image_size`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="image_size", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "image_size" + + +@attrs.define(slots=False) +class NumFramesConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `num_frames`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="num_frames", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "num_frames" + + +@attrs.define(slots=False) +class FrameRepeatConfig: + """ + Remap and process key from the input dictionary to the output dictionary. For `frame_repeat`. 
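+    FrameRepeatAttr scales the raw value by 1/10 (e.g. a frame_repeat of 5.0 becomes 0.5).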
+ """ + + obj: LazyDict = L(FrameRepeatAttr)() + dropout_rate: float = 0.0 + input_key: str = "frame_repeat" + + +@attrs.define(slots=False) +class VideoCondBoolConfig: + obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool") + dropout_rate: float = 0.2 + input_key: str = "fps" # This is a placeholder, we never use this value + # Config below are for long video generation only + compute_loss_for_condition_region: bool = False # Compute loss for condition region + + # How to sample condition region during training. "first_random_n" set the first n frames to be condition region, n is random, "random" set the condition region to be random, + condition_location: str = "first_random_n" + random_conditon_rate: float = 0.5 # The rate to sample the condition region randomly + first_random_n_num_condition_t_max: int = 4 # The maximum number of frames to sample as condition region, used when condition_location is "first_random_n" + first_random_n_num_condition_t_min: int = 0 # The minimum number of frames to sample as condition region, used when condition_location is "first_random_n" + + # How to dropout value of the conditional input frames + cfg_unconditional_type: str = "zero_condition_region_condition_mask" # Unconditional type. "zero_condition_region_condition_mask" set the input to zero for condition region, "noise_x_condition_region" set the input to x_t, same as the base model + + # How to corrupt the condition region + apply_corruption_to_condition_region: str = "noise_with_sigma" # Apply corruption to condition region, option: "gaussian_blur", "noise_with_sigma", "clean" (inference), "noise_with_sigma_fixed" (inference) + # Inference only option: list of sigma value for the corruption at different chunk id, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + apply_corruption_to_condition_region_sigma_value: list[float] = [0.001, 0.2] + [ + 0.5 + ] * 10 # Sigma value for the corruption, used when apply_corruption_to_condition_region is "noise_with_sigma_fixed" + + # Add augment_sigma condition to the network + condition_on_augment_sigma: bool = False + # The following arguments is to match with previous implementation where we use train sde to sample augment sigma (with adjust video noise turn on) + augment_sigma_sample_p_mean: float = 0.0 # Mean of the augment sigma + augment_sigma_sample_p_std: float = 1.0 # Std of the augment sigma + augment_sigma_sample_multiplier: float = 4.0 # Multipler of augment sigma + + # Add pose condition to the network + add_pose_condition: bool = False + + # Sample PPP... from IPPP... sequence + sample_tokens_start_from_p_or_i: bool = False + + # Normalize the input condition latent + normalize_condition_latent: bool = False + + +@attrs.define(slots=False) +class LatentConditionConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition" + + +@attrs.define(slots=False) +class LatentConditionSigmaConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. 
+ """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition_sigma", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition_sigma" + + +BaseVideoConditionerConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), +) + +VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), +) + +VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), +) + +VideoConditionerFpsSizePaddingFrameRepeatConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + frame_repeat=FrameRepeatConfig(), +) + +VideoExtendConditionerFrameRepeatConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), + frame_repeat=FrameRepeatConfig(), +) diff --git a/cosmos_predict1/diffusion/config/base/model.py b/cosmos_predict1/diffusion/config/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..819ddb4a3b5ea63887ba3697b085b70fee319874 --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/model.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +import attrs + +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class DefaultModelConfig: + tokenizer: LazyDict = None + conditioner: LazyDict = None + net: LazyDict = None + sigma_data: float = 0.5 + precision: str = "bfloat16" + input_data_key: str = "video" # key to fetch input data from data_batch + latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames + input_image_key: str = "images_1024" + adjust_video_noise: bool = False # Added field with default value + context_parallel_size: int = 1 # Added field with default value + # `num_latents_to_drop` is a flag that helps satisfy (1I,N*P,1I) latents setup. + # Since the tokenizer is causal and has the `T+1` input frames setup, it's + # challenging to encode arbitrary number of frames. To circumvent this, + # we sample as many frames, run the tokenizer twice, and discard the last + # chunk's P-latents, ensuring the requirement: I-latents for the input frames + # and P-latent for the-to-be-predicted in-between frames. + # By default, this flag does not have any effect. 
+ num_latents_to_drop: int = 0 # number of P-latents to discard after encoding + + sde: Optional[Dict] = None + vae: Optional[Dict] = None # Add this line to include the vae field + peft_control: LazyDict | None = None + frame_buffer_max: Optional[int] = 1 + + +@attrs.define(slots=False) +class LatentDiffusionDecoderModelConfig(DefaultModelConfig): + tokenizer_corruptor: LazyDict = None + latent_corruptor: LazyDict = None + pixel_corruptor: LazyDict = None + diffusion_decoder_cond_sigma_low: float = None + diffusion_decoder_cond_sigma_high: float = None + diffusion_decoder_corrupt_prob: float = None + condition_on_tokenizer_corruptor_token: bool = False + + +@attrs.define(slots=False) +class MultiviewModelConfig(DefaultModelConfig): + n_views: int = 4 diff --git a/cosmos_predict1/diffusion/config/base/net.py b/cosmos_predict1/diffusion/config/base/net.py new file mode 100644 index 0000000000000000000000000000000000000000..6272aa996abc1cf1b394dc1f9ac1b8d2555ca554 --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/net.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FADITV2Config: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, +) + + +FADITV2_14B_Config = copy.deepcopy(FADITV2Config) +FADITV2_14B_Config.model_channels = 5120 +FADITV2_14B_Config.num_heads = 40 +FADITV2_14B_Config.num_blocks = 36 + + +FADITV2_Multiview_Config: LazyDict = L(MultiviewGeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, + n_views=6, + view_condition_dim=6, + add_repeat_frame_embedding=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=1.0, +) diff --git a/cosmos_predict1/diffusion/config/base/tokenizer.py b/cosmos_predict1/diffusion/config/base/tokenizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..eb73d34ea58994feb9738d15b8738d495cea0b8a --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/tokenizer.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_predict1_tokenizer", + latent_ch=16, + ) diff --git a/cosmos_predict1/diffusion/config/config.py b/cosmos_predict1/diffusion/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..efb24aceb14b94ec6d47d8b0e539a56d3e957679 --- /dev/null +++ b/cosmos_predict1/diffusion/config/config.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
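+
+# Usage sketch (illustrative; the exact CLI entry point is defined elsewhere):
+#
+#   from cosmos_predict1.diffusion.config.config import make_config
+#   config = make_config()
+#   # Config groups registered below (net / conditioner / tokenizer / experiment)
+#   # can then be overridden through Hydra, e.g. experiment=GEN3C_Cosmos_7B.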
+ +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.config.base.model import DefaultModelConfig +from cosmos_predict1.diffusion.config.registry import register_configs +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"tokenizer": "tokenizer"}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig(), + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_diffusion" + c.job.group = "inference" + + # Call this function to register config groups for advanced overriding. + register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) + return c diff --git a/cosmos_predict1/diffusion/config/inference/__init__.py b/cosmos_predict1/diffusion/config/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py new file mode 100644 index 0000000000000000000000000000000000000000..e48c8c133494af1304d83ce9b9692025fc3b1b2a --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
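For orientation, a minimal sketch of how make_config() above is used before any Hydra overrides are applied; the printed values are the defaults assigned inside the function itself, and the net/experiment groups stay unset until a registered node is selected.

from cosmos_predict1.diffusion.config.config import make_config

config = make_config()        # registers config groups and imports the inference experiments
print(config.job.project)     # "cosmos_diffusion"
print(config.job.group)       # "inference"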
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +GEN3C_Cosmos_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + in_channels=16 + 16 * 4 + 1 # 16: video_latent, 16 * 4: (warped_frames + warped_frames_mask) * buffer 2, 1: mask + ), + frame_buffer_max=2, + ), + job=dict(group="Gen3c", name="GEN3C_Cosmos_7B"), + ) +) + +cs = ConfigStore.instance() +for _item in [ + GEN3C_Cosmos_7B, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..57eda7cd4d5dea477b7ac7175cebf9b96b1f2133 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
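Spelling out the arithmetic behind in_channels=16 + 16 * 4 + 1 in GEN3C_Cosmos_7B above, following its inline comment (the grouping of the middle term as two buffers of warped frames plus their masks is taken from that comment; anything beyond it is an assumption):

video_latent_ch = 16              # base video latent channels
warp_buffer_ch = (16 + 16) * 2    # (warped_frames + warped_frames_mask) for frame_buffer_max=2 buffers
mask_ch = 1                       # conditioning mask channel
in_channels = video_latent_ch + warp_buffer_ch + mask_ch
assert in_channels == 16 + 16 * 4 + 1 == 81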
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Text2World_7B_Multiview: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + {"override /net": "faditv2_multiview_7b"}, + {"override /conditioner": "add_fps_image_size_padding_mask_frame_repeat"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_7B_Multiview", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=57, + ) + ), + ), + ) +) + + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_Predict1_Text2World_7B_Multiview["job"]["name"], + node=Cosmos_Predict1_Text2World_7B_Multiview, +) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..13d709115df426479da991e3d827e1bb8c434fc7 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
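The ConfigStore registration at the end of the multiview file above is what makes the node selectable as an experiment. Below is a hypothetical derived experiment, following the same defaults-list inheritance pattern the Post_trained variants use later in these configs; the name My_Multiview_Variant is made up purely for illustration.

from hydra.core.config_store import ConfigStore

from cosmos_predict1.utils.lazy_config import LazyDict

My_Multiview_Variant: LazyDict = LazyDict(
    dict(
        defaults=[
            "/experiment/Cosmos_Predict1_Text2World_7B_Multiview",  # inherit everything from the node above
            "_self_",
        ],
        job=dict(name="My_Multiview_Variant"),  # override only what changes
    )
)

cs = ConfigStore.instance()
cs.store(group="experiment", package="_global_", name="My_Multiview_Variant", node=My_Multiview_Variant)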
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Text2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_7B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + ), + ), + ) +) + +Cosmos_Predict1_Text2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_14B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Text2World_14B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_14B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_14B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=121, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=33, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=17, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_lora: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B_Post_trained", + ], + job=dict( + 
name="Cosmos_Predict1_Text2World_7B_Post_trained_lora", + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=27, rank=8, scale=1), + ), + ) +) + +cs = ConfigStore.instance() + +for _item in [ + Cosmos_Predict1_Text2World_7B, + Cosmos_Predict1_Text2World_14B, + Cosmos_Predict1_Text2World_7B_Post_trained, + Cosmos_Predict1_Text2World_14B_Post_trained, + Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb, + Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb, + Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb, + Cosmos_Predict1_Text2World_7B_Post_trained_lora, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6cfdcee8b3d343a8897fd99836370b7c8a7dc4 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned_multiview import MultiviewVideoExtendGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Video2World_7B_Multiview: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B_Multiview", + {"override /conditioner": "video_cond_frame_repeat"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Video2World_7B_Multiview", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=L(MultiviewVideoExtendGeneralDIT)( + n_views=6, + view_condition_dim=6, + add_repeat_frame_embedding=True, + ), + conditioner=dict(video_cond_bool=dict()), + ), + ) +) + + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_Predict1_Video2World_7B_Multiview["job"]["name"], + node=Cosmos_Predict1_Video2World_7B_Multiview, +) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..820157e52d3dd731d8f024cdeb5437e03471a999 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Video2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + ), + ), + job=dict(group="Video2World", name="Cosmos_Predict1_Video2World_7B"), + ) +) + + +Cosmos_Predict1_Video2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + ), + ), + job=dict(group="Video2World", name="Cosmos_Predict1_Video2World_14B"), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=121, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=25, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # 
Latent temporal dim + 24, # Latent height dim + 24, # Latent width dim + ], + tokenizer=dict( + # video_vae=dict(pixel_chunk_duration=17, spatial_resolution="384"), + video_vae=dict(pixel_chunk_duration=25, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_14B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_14B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_14B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_lora: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B_Post_trained", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_lora", + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=27, rank=8, scale=1), + ), + ) +) + +cs = ConfigStore.instance() +for _item in [ + Cosmos_Predict1_Video2World_7B, + Cosmos_Predict1_Video2World_14B, + Cosmos_Predict1_Video2World_7B_Post_trained, + Cosmos_Predict1_Video2World_14B_Post_trained, + Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb, + Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb, + Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb, + Cosmos_Predict1_Video2World_7B_Post_trained_lora, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..3580f8ef8947241412551242d21da32e012a6cc2 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
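A small sanity check on how the latent_shape entries in the configs above relate to pixel resolution, using the tokenizer's 8x spatial and 8x temporal compression. The temporal formula (T_pixel - 1) / 8 + 1 is the usual causal-video convention and is an assumption here, and not every memory-constrained variant keeps the spatial relation exact.

spatial_compression = 8

# Base 7B experiments: 88 x 160 latents correspond to 704 x 1280 pixels.
assert (88 * spatial_compression, 160 * spatial_compression) == (704, 1280)

# "384" post-trained variants: 48 x 48 latents correspond to 384 x 384 pixels.
assert 48 * spatial_compression == 384

# Assumed temporal convention: a 121-frame pixel chunk maps to 16 latent frames,
# matching the latent temporal dim used with the t121 tokenizer.
assert (121 - 1) // 8 + 1 == 16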
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_WorldInterpolator_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + sde=L(EDMSDE)( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ), + input_image_key="images_1024", + latent_shape=[ + 16, + 4, + 88, + 160, + ], + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=9, + ) + ), + vae=dict( # Added VAE field + pixel_chunk_duration=9, + latent_ch=16, + ), + adjust_video_noise=True, + num_latents_to_drop=1, + context_parallel_size=1, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_and_last_1", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + apply_corruption_to_condition_region_sigma_value=[0.001], + ), + text=dict( + dropout_rate=0.5, + ), + ), + net=L(VideoExtendGeneralDIT)( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + job=dict(group="WorldInterpolator", name="Cosmos_Predict1_WorldInterpolator_7B"), + ) +) + +Cosmos_Predict1_WorldInterpolator_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_WorldInterpolator_7B", + ], + job=dict( + name="Cosmos_Predict1_WorldInterpolator_7B_Post_trained", + ), + ) +) + + +cs = ConfigStore.instance() +for _item in [ + Cosmos_Predict1_WorldInterpolator_7B, + Cosmos_Predict1_WorldInterpolator_7B_Post_trained, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/registry.py b/cosmos_predict1/diffusion/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8a39525d0b03116da3ac08c55fa729f85cdc7c64 --- /dev/null +++ b/cosmos_predict1/diffusion/config/registry.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
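Throughout these configs, L(SomeClass)(...) does not construct the object; it records the target and its keyword arguments so the node can be copied and edited before anything is instantiated, which is how net.py derives the 14B net from the 7B one. A minimal sketch of that pattern, reusing the EDMSDE node from the interpolator config above (the override value is illustrative only):

import copy

from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE
from cosmos_predict1.utils.lazy_config import LazyCall as L

sde_cfg = L(EDMSDE)(p_mean=0.0, p_std=1.0, sigma_max=80, sigma_min=0.0002)

# Copy-then-edit, mirroring FADITV2_14B_Config = copy.deepcopy(FADITV2Config) in net.py.
wide_sde_cfg = copy.deepcopy(sde_cfg)
wide_sde_cfg.sigma_max = 160  # illustrative override; no EDMSDE has been built yet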
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.config.base.conditioner import ( + BaseVideoConditionerConfig, + VideoConditionerFpsSizePaddingConfig, + VideoConditionerFpsSizePaddingFrameRepeatConfig, + VideoExtendConditionerConfig, + VideoExtendConditionerFrameRepeatConfig, +) +from cosmos_predict1.diffusion.config.base.net import FADITV2_14B_Config, FADITV2_Multiview_Config, FADITV2Config +from cosmos_predict1.diffusion.config.base.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 + + +def register_net(cs): + cs.store( + group="net", + package="model.net", + name="faditv2_7b", + node=FADITV2Config, + ) + cs.store( + group="net", + package="model.net", + name="faditv2_14b", + node=FADITV2_14B_Config, + ) + cs.store( + group="net", + package="model.net", + name="faditv2_multiview_7b", + node=FADITV2_Multiview_Config, + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="basic", + node=BaseVideoConditionerConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond", + node=VideoExtendConditionerConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask_frame_repeat", + node=VideoConditionerFpsSizePaddingFrameRepeatConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond_frame_repeat", + node=VideoExtendConditionerFrameRepeatConfig, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_net(cs) + register_conditioner(cs) + register_tokenizer(cs) diff --git a/cosmos_predict1/diffusion/functional/batch_ops.py b/cosmos_predict1/diffusion/functional/batch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a72b24097f7cc9e7e6a8b324919455131bf84d47 --- /dev/null +++ b/cosmos_predict1/diffusion/functional/batch_ops.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Functions for performing operations with broadcasting to the right axis +# +# Example +# input1: tensor of size (N1, N2) +# input2: tensor of size (N1, N2, N3, N4) +# batch_mul(input1, input2) = input1[:, :, None, None] * input2 +# +# If the common dimensions don't match, we raise an assertion error. 
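A tiny usage sketch of the broadcasting helpers implemented just below, mirroring the shapes from the comment above:

import torch

from cosmos_predict1.diffusion.functional.batch_ops import batch_mul

a = torch.randn(2, 3)        # (N1, N2)
b = torch.randn(2, 3, 4, 5)  # (N1, N2, N3, N4)

out = batch_mul(a, b)        # a is right-padded with singleton dims, then multiplied
assert out.shape == (2, 3, 4, 5)
assert torch.allclose(out, a[:, :, None, None] * b)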
+ +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x / y diff --git a/cosmos_predict1/diffusion/functional/multi_step.py b/cosmos_predict1/diffusion/functional/multi_step.py new file mode 100644 index 0000000000000000000000000000000000000000..76cb57aea441bfa36ddf7eeac9be75b40761cc5b --- /dev/null +++ b/cosmos_predict1/diffusion/functional/multi_step.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Impl of multistep methods to solve the ODE in the diffusion model. +""" + +from typing import Callable, List, Tuple + +import torch + +from cosmos_predict1.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step + + +def order2_fn( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + impl the second order multistep method in https://arxiv.org/pdf/2308.02157 + Adams Bashforth approach! + """ + if x0_preds: + x0_s1, s1 = x0_preds[0] + x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) + else: + x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] + return x_t, [(x0_s, s)] + + +# key: method name, value: method function +# key: order + algorithm name +MULTISTEP_FNs = { + "2ab": order2_fn, +} + + +def get_multi_step_fn(name: str) -> Callable: + if name in MULTISTEP_FNs: + return MULTISTEP_FNs[name] + methods = "\n\t".join(MULTISTEP_FNs.keys()) + raise RuntimeError("Only support multistep method\n" + methods) + + +def is_multi_step_fn_supported(name: str) -> bool: + """ + Check if the multistep method is supported. + """ + return name in MULTISTEP_FNs diff --git a/cosmos_predict1/diffusion/functional/runge_kutta.py b/cosmos_predict1/diffusion/functional/runge_kutta.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5841db9fb4ad463f9411206c953f405a7fad50 --- /dev/null +++ b/cosmos_predict1/diffusion/functional/runge_kutta.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Tuple + +import torch + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul + + +def phi1(t: torch.Tensor) -> torch.Tensor: + """ + Compute the first order phi function: (exp(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi1 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return (torch.expm1(t) / t).to(dtype=input_dtype) + + +def phi2(t: torch.Tensor) -> torch.Tensor: + """ + Compute the second order phi function: (phi1(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi2 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) + + +def res_x0_rk2_step( + x_s: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + x0_s: torch.Tensor, + s1: torch.Tensor, + x0_s1: torch.Tensor, +) -> torch.Tensor: + """ + Perform a residual-based 2nd order Runge-Kutta step. + + Args: + x_s: Current state tensor. + t: Target time tensor. + s: Current time tensor. + x0_s: Prediction at current time. + s1: Intermediate time tensor. + x0_s1: Prediction at intermediate time. + + Returns: + Tensor: Updated state tensor. + + Raises: + AssertionError: If step size is too small. + """ + s = -torch.log(s) + t = -torch.log(t) + m = -torch.log(s1) + + dt = t - s + assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + + c2 = (m - s) / dt + phi1_val, phi2_val = phi1(-dt), phi2(-dt) + + # Handle edge case where t = s = m + b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) + b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) + + return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) + + +def reg_x0_euler_step( + x_s: torch.Tensor, + s: torch.Tensor, + t: torch.Tensor, + x0_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on x0 prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_s: Prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current prediction. + """ + coef_x0 = (s - t) / s + coef_xs = t / s + return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s + + +def reg_eps_euler_step( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on epsilon prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + eps_s: Epsilon prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current x0 prediction. 
+ """ + return x_s + batch_mul(eps_s, t - s), x_s + batch_mul(eps_s, 0 - s) + + +def rk1_euler( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a first-order Runge-Kutta (Euler) step. + + Recommended for diffusion models with guidance or model undertrained + Usually more stable at the cost of a bit slower convergence. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + x0_s = x0_fn(x_s, s) + return reg_x0_euler_step(x_s, s, t, x0_s) + + +def rk2_mid_stable( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a stable second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, _ = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + return reg_x0_euler_step(x_s, s, t, x0_s1) + + +def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + + return res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1), x0_s1 + + +def rk_2heun_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + eps_s = batch_mul(1.0 / s, x_t - x0_s) + x0_t = x0_fn(x_t, t) + eps_t = batch_mul(1.0 / t, x_t - x0_t) + + avg_eps = (eps_s + eps_t) / 2 + + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +def rk_2heun_edm( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based no EDM second order Heun method + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + x0_t = x0_fn(x_t, t) + + avg_x0 = (x0_s + x0_t) / 2 + + return reg_x0_euler_step(x_s, s, t, avg_x0) + + +def rk_3kutta_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive third-order Runge-Kutta step. 
+ Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + c2, c3 = 0.5, 1.0 + a31, a32 = -1.0, 2.0 + b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6 + + delta = t - s + + s1 = c2 * delta + s + s2 = c3 * delta + s + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + eps_s = batch_mul(1.0 / s, x_s - x0_s) + x0_s1 = x0_fn(x_s1, s1) + eps_s1 = batch_mul(1.0 / s1, x_s1 - x0_s1) + + _eps = a31 * eps_s + a32 * eps_s1 + x_s2, _ = reg_eps_euler_step(x_s, s, s2, _eps) + + x0_s2 = x0_fn(x_s2, s2) + eps_s2 = batch_mul(1.0 / s2, x_s2 - x0_s2) + + avg_eps = b1 * eps_s + b2 * eps_s1 + b3 * eps_s2 + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +# key : order + name +RK_FNs = { + "1euler": rk1_euler, + "2mid": rk2_mid, + "2mid_stable": rk2_mid_stable, + "2heun_edm": rk_2heun_edm, + "2heun_naive": rk_2heun_naive, + "3kutta_naive": rk_3kutta_naive, +} + + +def get_runge_kutta_fn(name: str) -> Callable: + """ + Get the specified Runge-Kutta function. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + Callable: The specified Runge-Kutta function. + + Raises: + RuntimeError: If the specified method is not supported. + """ + if name in RK_FNs: + return RK_FNs[name] + methods = "\n\t".join(RK_FNs.keys()) + raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{methods}") + + +def is_runge_kutta_fn_supported(name: str) -> bool: + """ + Check if the specified Runge-Kutta function is supported. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + bool: True if the method is supported, False otherwise. + """ + return name in RK_FNs diff --git a/cosmos_predict1/diffusion/inference/cache_3d.py b/cosmos_predict1/diffusion/inference/cache_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d3697779cf22acebd3e4100f1254fdc7f33e33 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/cache_3d.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange + +from cosmos_predict1.diffusion.inference.forward_warp_utils_pytorch import ( + forward_warp, + reliable_depth_mask_range_batch, + unproject_points, +) +from cosmos_predict1.diffusion.inference.camera_utils import align_depth + +class Cache3D_Base: + def __init__( + self, + input_image, + input_depth, + input_w2c, + input_intrinsics, + input_mask=None, + input_format=None, + input_points=None, + weight_dtype=torch.float32, + is_depth=True, + device="cuda", + filter_points_threshold=1.0, + foreground_masking=False, + ): + """ + input_image: Tensor with varying dimensions. 
+ input_format: List of dimension labels corresponding to input_image's dimensions. + E.g., ['B', 'C', 'H', 'W'], ['B', 'F', 'C', 'H', 'W'], etc. + """ + self.weight_dtype = weight_dtype + self.is_depth = is_depth + self.device = device + self.filter_points_threshold = filter_points_threshold + self.foreground_masking = foreground_masking + if input_format is None: + assert input_image.dim() == 4 + input_format = ["B", "C", "H", "W"] + + # Map dimension names to their indices in input_image + format_to_indices = {dim: idx for idx, dim in enumerate(input_format)} + input_shape = input_image.shape + if input_mask is not None: + input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C")) + + # B (batch size), F (frame count), N dimensions: no aggregation during warping. + # Only broadcasting over F to match the target w2c. + # V: aggregate via concatenation or duster + B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1 # batch + F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1 # frame + N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1 # buffer + V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1 # view + H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None + W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None + + # Desired dimension order + desired_dims = ["B", "F", "N", "V", "C", "H", "W"] + + # Build permute order based on input_format + permute_order = [] + for dim in desired_dims: + idx = format_to_indices.get(dim) + if idx is not None: + permute_order.append(idx) + else: + # Placeholder for dimensions to be added later + permute_order.append(None) + + # Remove None values for permute operation + permute_indices = [idx for idx in permute_order if idx is not None] + input_image = input_image.permute(*permute_indices) + + # Insert dimensions of size 1 where necessary + for i, idx in enumerate(permute_order): + if idx is None: + input_image = input_image.unsqueeze(i) + + # Now input_image has the shape B x F x N x V x C x H x W + if input_mask is not None: + self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:] + self.input_mask = self.input_mask.to("cpu") + else: + self.input_mask = None + self.input_image = input_image + self.input_image = self.input_image.to(weight_dtype).to("cpu") + + if input_points is not None: + self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu") + self.input_depth = None + else: + input_depth = torch.nan_to_num(input_depth, nan=100) + input_depth = torch.clamp(input_depth, min=0, max=100) + if weight_dtype == torch.float16: + input_depth = torch.clamp(input_depth, max=70) + self.input_points = ( + self._compute_input_points( + input_depth.reshape(-1, 1, H, W), + input_w2c.reshape(-1, 4, 4), + input_intrinsics.reshape(-1, 3, 3), + ) + .to(weight_dtype) + .reshape(B, F, N, V, H, W, 3) + .to("cpu") + ) + self.input_depth = input_depth + + if self.filter_points_threshold < 1.0 and input_depth is not None: + input_depth = input_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(input_depth, ratio_thresh=self.filter_points_threshold).reshape(B, F, N, V, 1, H, W) + if self.input_mask is None: + self.input_mask = depth_mask.to("cpu") + else: + self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device) + self.boundary_mask = None + if foreground_masking: + input_depth 
= input_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(input_depth) + self.boundary_mask = (~depth_mask).reshape(B, F, N, V, 1, H, W).to("cpu") + + def _compute_input_points(self, input_depth, input_w2c, input_intrinsics): + input_points = unproject_points( + input_depth, + input_w2c, + input_intrinsics, + is_depth=self.is_depth, + ) + return input_points + + def update_cache(self): + raise NotImplementedError + + def input_frame_count(self) -> int: + return self.input_image.shape[1] + + def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): + bs, F_target, _, _ = target_w2cs.shape + + B, F, N, V, C, H, W = self.input_image.shape + assert bs == B + + target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4) + target_intrinsics = ( + target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3, 3).reshape(-1, 3, 3) + ) + + first_images = rearrange(self.input_image[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, C, H, W), "B F N V C H W-> (B F N) V C H W").to(self.device) + first_points = rearrange( + self.input_points[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, H, W, 3), "B F N V H W C-> (B F N) V H W C" + ).to(self.device) + first_masks = rearrange( + self.input_mask[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" + ).to(self.device) if self.input_mask is not None else None + boundary_masks = rearrange( + self.boundary_mask.expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" + ) if self.boundary_mask is not None else None + + if first_images.shape[1] == 1: + warp_chunk_size = 2 + rendered_warp_images = [] + rendered_warp_masks = [] + rendered_warp_depth = [] + rendered_warped_flows = [] + + first_images = first_images.squeeze(1) + first_points = first_points.squeeze(1) + first_masks = first_masks.squeeze(1) if first_masks is not None else None + for i in range(0, first_images.shape[0], warp_chunk_size): + ( + rendered_warp_images_chunk, + rendered_warp_masks_chunk, + rendered_warp_depth_chunk, + rendered_warped_flows_chunk, + ) = forward_warp( + first_images[i : i + warp_chunk_size], + mask1=first_masks[i : i + warp_chunk_size] if first_masks is not None else None, + depth1=None, + transformation1=None, + transformation2=target_w2cs[i : i + warp_chunk_size], + intrinsic1=target_intrinsics[i : i + warp_chunk_size], + intrinsic2=target_intrinsics[i : i + warp_chunk_size], + render_depth=render_depth, + world_points1=first_points[i : i + warp_chunk_size], + foreground_masking=self.foreground_masking, + boundary_mask=boundary_masks[i : i + warp_chunk_size, 0, 0] if boundary_masks is not None else None + ) + rendered_warp_images.append(rendered_warp_images_chunk) + rendered_warp_masks.append(rendered_warp_masks_chunk) + rendered_warp_depth.append(rendered_warp_depth_chunk) + rendered_warped_flows.append(rendered_warped_flows_chunk) + rendered_warp_images = torch.cat(rendered_warp_images, dim=0) + rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0) + if render_depth: + rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0) + rendered_warped_flows = torch.cat(rendered_warped_flows, dim=0) + + else: + raise NotImplementedError + + pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) + masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) + if 
render_depth: + pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N) + return pixels, masks + + +class Cache3D_Buffer(Cache3D_Base): + def __init__(self, frame_buffer_max=0, noise_aug_strength=0, generator=None, **kwargs): + super().__init__(**kwargs) + self.frame_buffer_max = frame_buffer_max + self.noise_aug_strength = noise_aug_strength + self.generator = generator + + def update_cache(self, new_image, new_depth, new_w2c, new_mask=None, new_intrinsics=None, depth_alignment=True, alignment_method="non_rigid"): # 3D cache + new_image = new_image.to(self.weight_dtype).to(self.device) + new_depth = new_depth.to(self.weight_dtype).to(self.device) + new_w2c = new_w2c.to(self.weight_dtype).to(self.device) + if new_intrinsics is not None: + new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device) + + new_depth = torch.nan_to_num(new_depth, nan=1e4) + new_depth = torch.clamp(new_depth, min=0, max=1e4) + + if depth_alignment: + target_depth, target_mask = self.render_cache( + new_w2c.unsqueeze(1), new_intrinsics.unsqueeze(1), render_depth=True + ) + target_depth, target_mask = target_depth[:, :, 0], target_mask[:, :, 0] + if alignment_method == "rigid": + new_depth = ( + align_depth( + new_depth.squeeze(), + target_depth.squeeze(), + target_mask.bool().squeeze(), + ) + .reshape_as(new_depth) + .detach() + ) + elif alignment_method == "non_rigid": + with torch.enable_grad(): + new_depth = ( + align_depth( + new_depth.squeeze(), + target_depth.squeeze(), + target_mask.bool().squeeze(), + k=new_intrinsics.squeeze(), + c2w=torch.inverse(new_w2c.squeeze()), + alignment_method="non_rigid", + num_iters=100, + lambda_arap=0.1, + smoothing_kernel_size=3, + ) + .reshape_as(new_depth) + .detach() + ) + else: + raise NotImplementedError + new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu() + new_image = new_image.cpu() + + if self.filter_points_threshold < 1.0: + B, F, N, V, C, H, W = self.input_image.shape + new_depth = new_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(new_depth, ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W) + if new_mask is None: + new_mask = depth_mask.to("cpu") + else: + new_mask = new_mask * depth_mask.to(new_mask.device) + if new_mask is not None: + new_mask = new_mask.cpu() + if self.frame_buffer_max > 1: # newest frame first + if self.input_image.shape[2] < self.frame_buffer_max: + self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2) + self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2) + if self.input_mask is not None: + self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2) + else: + self.input_image[:, :, 0] = new_image[:, None, None] + self.input_points[:, :, 0] = new_points[:, None, None] + if self.input_mask is not None: + self.input_mask[:, :, 0] = new_mask[:, None, None] + else: + self.input_image = new_image[:, None, None, None] + self.input_points = new_points[:, None, None, None] + + + def render_cache( + self, + target_w2cs, + target_intrinsics, + render_depth: bool = False, + start_frame_idx: int = 0, # For consistency with Cache4D + ): + assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3D_Buffer" + + output_device = target_w2cs.device + target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device) + target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device) + pixels, masks = super().render_cache( + 
target_w2cs, target_intrinsics, render_depth + ) + if not render_depth: + noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype) + per_buffer_noise = ( + torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) + * self.noise_aug_strength + ) + pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1) # B, F, N, C, H, W + return pixels.to(output_device), masks.to(output_device) + + +class Cache4D(Cache3D_Base): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def update_cache(self, **kwargs): + raise NotImplementedError + + def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): + rendered_warp_images, rendered_warp_masks = super().render_cache(target_w2cs, target_intrinsics, render_depth, start_frame_idx) + return rendered_warp_images, rendered_warp_masks diff --git a/cosmos_predict1/diffusion/inference/camera_utils.py b/cosmos_predict1/diffusion/inference/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa370636e2067e95d1332a7001b73e3583d35448 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/camera_utils.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
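The noise augmentation at the end of Cache3D_Buffer.render_cache above adds Gaussian noise scaled per buffer slot. A small numeric sketch of that scale vector for the GEN3C setting frame_buffer_max=2 (the strength value is illustrative; slot 0 holds the newest frame per the ordering in update_cache):

import torch

noise_aug_strength = 0.1  # illustrative value; the real one is chosen by the caller
frame_buffer_max = 2      # as in GEN3C_Cosmos_7B

# Same expression as in render_cache: one scale per buffer slot.
per_buffer_noise = torch.arange(start=frame_buffer_max - 1, end=-1, step=-1) * noise_aug_strength
print(per_buffer_noise)   # tensor([0.1000, 0.0000]) -> the newest slot gets the stronger augmentation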
+ +import torch +import math +import torch.nn.functional as F +from .forward_warp_utils_pytorch import unproject_points + +def apply_transformation(Bx4x4, another_matrix): + B = Bx4x4.shape[0] + if another_matrix.dim() == 2: + another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1) # Make another_matrix compatible with batch size + transformed_matrix = torch.bmm(Bx4x4, another_matrix) # Shape: (B, 4, 4) + + return transformed_matrix + + +def look_at_matrix(camera_pos, target, invert_pos=True): + """Creates a 4x4 look-at matrix, keeping the camera pointing towards a target.""" + forward = (target - camera_pos).float() + forward = forward / torch.norm(forward) + + up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device) # assuming Y-up coordinate system + right = torch.cross(up, forward) + right = right / torch.norm(right) + up = torch.cross(forward, right) + + look_at = torch.eye(4, device=camera_pos.device) + look_at[0, :3] = right + look_at[1, :3] = up + look_at[2, :3] = forward + look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos + + return look_at + +def create_horizontal_trajectory( + world_to_camera_matrix, center_depth, positive=True, n_steps=13, distance=0.1, device="cuda", axis="x", camera_rotation="center_facing" +): + look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) + # Spiral motion key points + trajectory = [] + translation_positions = [] + initial_camera_pos = torch.tensor([0, 0, 0], device=device) + + for i in range(n_steps): + if axis == "x": # pos - right + x = i * distance * center_depth / n_steps * (1 if positive else -1) + y = 0 + z = 0 + elif axis == "y": # pos - down + x = 0 + y = i * distance * center_depth / n_steps * (1 if positive else -1) + z = 0 + elif axis == "z": # pos - in + x = 0 + y = 0 + z = i * distance * center_depth / n_steps * (1 if positive else -1) + else: + raise ValueError("Axis should be x, y or z") + + translation_positions.append(torch.tensor([x, y, z], device=device)) + + for pos in translation_positions: + camera_pos = initial_camera_pos + pos + if camera_rotation == "trajectory_aligned": + _look_at = look_at + pos * 2 + elif camera_rotation == "center_facing": + _look_at = look_at + elif camera_rotation == "no_rotation": + _look_at = look_at + pos + else: + raise ValueError("Camera rotation should be center_facing or trajectory_aligned") + view_matrix = look_at_matrix(camera_pos, _look_at) + trajectory.append(view_matrix) + trajectory = torch.stack(trajectory) + return apply_transformation(trajectory, world_to_camera_matrix) + + +def create_spiral_trajectory( + world_to_camera_matrix, + center_depth, + radius_x=0.03, + radius_y=0.02, + radius_z=0.0, + positive=True, + camera_rotation="center_facing", + n_steps=13, + device="cuda", + start_from_zero=True, + num_circles=1, +): + + look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) + + # Spiral motion key points + trajectory = [] + spiral_positions = [] + initial_camera_pos = torch.tensor([0, 0, 0], device=device) # world_to_camera_matrix[:3, 3].clone() + + example_scale = 1.0 + + theta_max = 2 * math.pi * num_circles + + for i in range(n_steps): + # theta = 2 * math.pi * i / (n_steps-1) # angle for each point + theta = theta_max * i / (n_steps - 1) # angle for each point + if start_from_zero: + x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * (center_depth / example_scale) + else: + x = radius_x * (math.cos(theta)) * (center_depth / example_scale) + + y = radius_y * math.sin(theta) * (center_depth / example_scale) + z = radius_z * 
math.sin(theta) * (center_depth / example_scale) + spiral_positions.append(torch.tensor([x, y, z], device=device)) + + for pos in spiral_positions: + if camera_rotation == "center_facing": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at) + elif camera_rotation == "trajectory_aligned": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos * 2) + elif camera_rotation == "no_rotation": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos) + else: + raise ValueError("Camera rotation should be center_facing, trajectory_aligned or no_rotation") + trajectory.append(view_matrix) + trajectory = torch.stack(trajectory) + return apply_transformation(trajectory, world_to_camera_matrix) + + +def generate_camera_trajectory( + trajectory_type: str, + initial_w2c: torch.Tensor, # Shape: (4, 4) + initial_intrinsics: torch.Tensor, # Shape: (3, 3) + num_frames: int, + movement_distance: float, + camera_rotation: str, + center_depth: float = 1.0, + device: str = "cuda", +): + """ + Generates a sequence of camera poses (world-to-camera matrices) and intrinsics + for a specified trajectory type. + + Args: + trajectory_type: Type of trajectory (e.g., "left", "right", "up", "down", "zoom_in", "zoom_out"). + initial_w2c: Initial world-to-camera matrix (4x4 tensor or num_framesx4x4 tensor). + initial_intrinsics: Camera intrinsics matrix (3x3 tensor or num_framesx3x3 tensor). + num_frames: Number of frames (steps) in the trajectory. + movement_distance: Distance factor for the camera movement. + camera_rotation: Type of camera rotation ('center_facing', 'no_rotation', 'trajectory_aligned'). + center_depth: Depth of the center point the camera might focus on. + device: Computation device ("cuda" or "cpu"). + + Returns: + A tuple (generated_w2cs, generated_intrinsics): + - generated_w2cs: Batch of world-to-camera matrices for the trajectory (1, num_frames, 4, 4 tensor). + - generated_intrinsics: Batch of camera intrinsics for the trajectory (1, num_frames, 3, 3 tensor). 
+ """ + if trajectory_type in ["clockwise", "counterclockwise"]: + new_w2cs_seq = create_spiral_trajectory( + world_to_camera_matrix=initial_w2c, + center_depth=center_depth, + n_steps=num_frames, + positive=trajectory_type == "clockwise", + device=device, + camera_rotation=camera_rotation, + radius_x=movement_distance, + radius_y=movement_distance, + ) + else: + if trajectory_type == "left": + positive = False + axis = "x" + elif trajectory_type == "right": + positive = True + axis = "x" + elif trajectory_type == "up": + positive = False # Assuming 'up' means camera moves in negative y direction if y points down + axis = "y" + elif trajectory_type == "down": + positive = True # Assuming 'down' means camera moves in positive y direction if y points down + axis = "y" + elif trajectory_type == "zoom_in": + positive = True # Assuming 'zoom_in' means camera moves in positive z direction (forward) + axis = "z" + elif trajectory_type == "zoom_out": + positive = False # Assuming 'zoom_out' means camera moves in negative z direction (backward) + axis = "z" + else: + raise ValueError(f"Unsupported trajectory type: {trajectory_type}") + + # Generate world-to-camera matrices using create_horizontal_trajectory + new_w2cs_seq = create_horizontal_trajectory( + world_to_camera_matrix=initial_w2c, + center_depth=center_depth, + n_steps=num_frames, + positive=positive, + axis=axis, + distance=movement_distance, + device=device, + camera_rotation=camera_rotation, + ) + + generated_w2cs = new_w2cs_seq.unsqueeze(0) # Shape: [1, num_frames, 4, 4] + if initial_intrinsics.dim() == 2: + generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1) + else: + generated_intrinsics = initial_intrinsics.unsqueeze(0) + + return generated_w2cs, generated_intrinsics + + +def _align_inv_depth_to_depth( + source_inv_depth: torch.Tensor, + target_depth: torch.Tensor, + target_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Apply affine transformation to align source inverse depth to target depth. + + Args: + source_inv_depth: Inverse depth map to be aligned. Shape: (H, W). + target_depth: Target depth map. Shape: (H, W). + target_mask: Mask of valid target pixels. Shape: (H, W). + + Returns: + Aligned Depth map. Shape: (H, W). 
+ """ + target_inv_depth = 1.0 / target_depth + source_mask = source_inv_depth > 0 + target_depth_mask = target_depth > 0 + + if target_mask is None: + target_mask = target_depth_mask + else: + target_mask = torch.logical_and(target_mask > 0, target_depth_mask) + + # Remove outliers + outlier_quantiles = torch.tensor([0.1, 0.9], device=source_inv_depth.device) + + source_data_low, source_data_high = torch.quantile(source_inv_depth[source_mask], outlier_quantiles) + target_data_low, target_data_high = torch.quantile(target_inv_depth[target_mask], outlier_quantiles) + source_mask = (source_inv_depth > source_data_low) & (source_inv_depth < source_data_high) + target_mask = (target_inv_depth > target_data_low) & (target_inv_depth < target_data_high) + + mask = torch.logical_and(source_mask, target_mask) + + source_data = source_inv_depth[mask].view(-1, 1) + target_data = target_inv_depth[mask].view(-1, 1) + + ones = torch.ones((source_data.shape[0], 1), device=source_data.device) + source_data_h = torch.cat([source_data, ones], dim=1) + transform_matrix = torch.linalg.lstsq(source_data_h, target_data).solution + + scale, bias = transform_matrix[0, 0], transform_matrix[1, 0] + aligned_inv_depth = source_inv_depth * scale + bias + + return 1.0 / aligned_inv_depth + + +def align_depth( + source_depth: torch.Tensor, + target_depth: torch.Tensor, + target_mask: torch.Tensor, + k: torch.Tensor = None, + c2w: torch.Tensor = None, + alignment_method: str = "rigid", + num_iters: int = 100, + lambda_arap: float = 0.1, + smoothing_kernel_size: int = 3, +) -> torch.Tensor: + if alignment_method == "rigid": + source_inv_depth = 1.0 / source_depth + source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) + return source_depth + elif alignment_method == "non_rigid": + if k is None or c2w is None: + raise ValueError("Camera intrinsics (k) and camera-to-world matrix (c2w) are required for non-rigid alignment") + + source_inv_depth = 1.0 / source_depth + source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) + + # Initialize scale map + sc_map = torch.ones_like(source_depth).float().to(source_depth.device).requires_grad_(True) + optimizer = torch.optim.Adam(params=[sc_map], lr=0.001) + + # Unproject target depth + target_unprojected = unproject_points( + target_depth.unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions + c2w.unsqueeze(0), # Add batch dimension + k.unsqueeze(0), # Add batch dimension + is_depth=True, + mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + ).squeeze(0) # Remove batch dimension + + # Create smoothing kernel + smoothing_kernel = torch.ones( + (1, 1, smoothing_kernel_size, smoothing_kernel_size), + device=source_depth.device + ) / (smoothing_kernel_size**2) + + for _ in range(num_iters): + # Unproject scaled source depth + source_unprojected = unproject_points( + (source_depth * sc_map).unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions + c2w.unsqueeze(0), # Add batch dimension + k.unsqueeze(0), # Add batch dimension + is_depth=True, + mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + ).squeeze(0) # Remove batch dimension + + # Data loss + data_loss = torch.abs(source_unprojected[target_mask] - target_unprojected[target_mask]).mean() + + # Apply smoothing filter to sc_map + sc_map_reshaped = sc_map.unsqueeze(0).unsqueeze(0) + sc_map_smoothed = F.conv2d( + sc_map_reshaped, + smoothing_kernel, + padding=smoothing_kernel_size // 2 + 
).squeeze(0).squeeze(0) + + # ARAP loss + arap_loss = torch.abs(sc_map_smoothed - sc_map).mean() + + # Total loss + loss = data_loss + lambda_arap * arap_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return source_depth * sc_map + else: + raise ValueError(f"Unsupported alignment method: {alignment_method}") diff --git a/cosmos_predict1/diffusion/inference/data_loader_utils.py b/cosmos_predict1/diffusion/inference/data_loader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c47d0573fa548ce76819ff43ac9fbf5ebe9e98f1 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/data_loader_utils.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data loading utilities for the distributed format: +- RGB from mp4 +- Depth from float16 numpy +- Camera data from float32 numpy +""" + +import os +import numpy as np +import torch +import cv2 +from pathlib import Path + + +def load_rgb_from_mp4(video_path): + """ + Load RGB video from mp4 file and convert to tensor. + + Args: + video_path: str, path to the mp4 file + + Returns: + torch.Tensor: RGB tensor of shape [T, C, H, W] with range [-1, 1] + """ + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + raise RuntimeError(f"Failed to open video file: {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame_rgb) + + cap.release() + + if not frames: + raise ValueError(f"No frames found in video: {video_path}") + + # Convert to numpy array and then tensor + frames_np = np.stack(frames, axis=0) # [T, H, W, C] + frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float() # [T, C, H, W] + + # Convert from [0, 255] to [-1, 1] + frames_tensor = (frames_tensor / 127.5) - 1.0 + + return frames_tensor + + +def load_depth_from_numpy(depth_path): + """ + Load depth data from compressed NPZ file. + + Args: + depth_path: str, path to the NPZ file + + Returns: + torch.Tensor: Depth tensor of shape [T, 1, H, W] + """ + data = np.load(depth_path) + depth_np = data['depth'] # [T, H, W] + depth_tensor = torch.from_numpy(depth_np.astype(np.float32)) + + # Add channel dimension: [T, H, W] -> [T, 1, H, W] + depth_tensor = depth_tensor.unsqueeze(1) + + return depth_tensor + + +def load_mask_from_numpy(mask_path): + """ + Load mask data from compressed NPZ file. 
+ + Args: + mask_path: str, path to the NPZ file + + Returns: + torch.Tensor: Mask tensor of shape [T, 1, H, W] + """ + data = np.load(mask_path) + mask_np = data['mask'] # [T, H, W] as bool + mask_tensor = torch.from_numpy(mask_np.astype(np.float32)) # Convert bool to float32 + + # Add channel dimension: [T, H, W] -> [T, 1, H, W] + mask_tensor = mask_tensor.unsqueeze(1) + + return mask_tensor + + +def load_camera_from_numpy(data_dir): + """ + Load camera parameters from compressed NPZ file. + + Args: + data_dir: str, directory containing camera.npz + + Returns: + tuple: (w2c_tensor, intrinsics_tensor) + - w2c_tensor: torch.Tensor of shape [T, 4, 4] + - intrinsics_tensor: torch.Tensor of shape [T, 3, 3] + """ + camera_path = os.path.join(data_dir, "camera.npz") + + if not os.path.exists(camera_path): + raise FileNotFoundError(f"camera file not found: {camera_path}") + + data = np.load(camera_path) + w2c_np = data['w2c'] + intrinsics_np = data['intrinsics'] + + w2c_tensor = torch.from_numpy(w2c_np) + intrinsics_tensor = torch.from_numpy(intrinsics_np) + + return w2c_tensor, intrinsics_tensor + + +def load_data_distributed_format(data_dir): + """Load data from distributed format (mp4 + numpy files)""" + data_path = Path(data_dir) + + # Load RGB from mp4 + cap = cv2.VideoCapture(str(data_path / "rgb.mp4")) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + + frames_np = np.stack(frames, axis=0) + image_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float() + image_tensor = (image_tensor / 127.5) - 1.0 # [0,255] -> [-1,1] + + # Load depth and mask + depth_tensor = torch.from_numpy(np.load(data_path / "depth.npz")['depth'].astype(np.float32)).unsqueeze(1) + mask_tensor = torch.from_numpy(np.load(data_path / "mask.npz")['mask'].astype(np.float32)).unsqueeze(1) + + # Load camera data + camera_data = np.load(data_path / "camera.npz") + w2c_tensor = torch.from_numpy(camera_data['w2c']) + intrinsics_tensor = torch.from_numpy(camera_data['intrinsics']) + + return image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor + + +def load_data_packaged_format(pt_path): + """ + Load data from the packaged pt format for backward compatibility. + + Args: + pt_path: str, path to the pt file + + Returns: + tuple: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor) + """ + data = torch.load(pt_path) + + if len(data) != 5: + raise ValueError(f"Expected 5 tensors in pt file, got {len(data)}") + + return data + + +def load_data_auto_detect(input_path): + """Auto-detect format and load data""" + input_path = Path(input_path) + + if input_path.is_file() and input_path.suffix == '.pt': + return load_data_packaged_format(input_path) + elif input_path.is_dir(): + return load_data_distributed_format(input_path) + else: + raise ValueError(f"Invalid input path: {input_path}") \ No newline at end of file diff --git a/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py b/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..3bba954a1175492d2496b1dda16526e09cc98170 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py @@ -0,0 +1,721 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple +import numpy as np +import torch +import os +import torch.nn.functional as F +try: + import warp as wp +except ImportError: + raise ImportError("NVIDIA Warp is required for ray-triangle intersection") + +_warp_initialized = False +_ray_triangle_intersection_func = None + +def _init_warp(): + global _warp_initialized, _ray_triangle_intersection_func + + if not _warp_initialized: + print(f"Initializing Warp library (local_rank {os.getenv('LOCAL_RANK')})...") + wp.init() + _warp_initialized = True + print(f"Warp library initialized successfully (local_rank {os.getenv('LOCAL_RANK')})") + + if _ray_triangle_intersection_func is None: + try: + from .ray_triangle_intersection_warp import ray_triangle_intersection_warp + _ray_triangle_intersection_func = ray_triangle_intersection_warp + print(f"Warp: ray_triangle_intersection_warp kernel loaded (local_rank {os.getenv('LOCAL_RANK')})") + except ImportError: + from ray_triangle_intersection_warp import ray_triangle_intersection_warp + _ray_triangle_intersection_func = ray_triangle_intersection_warp + print(f"Warp: ray_triangle_intersection_warp kernel loaded (local_rank {os.getenv('LOCAL_RANK')})") + + +def points_to_mesh(points, mask, resolution=None): + """ + Convert a grid of 3D points to a triangle mesh based on mask. 
+ + Args: + points: Tensor of shape [H, W, 3] containing 3D points + mask: Tensor of shape [H, W] containing binary mask + resolution: Optional tuple (new_H, new_W) to resize to + + Returns: + vertices: Tensor of shape [N, 3] containing unique vertices + faces: Tensor of shape [M, 3] containing triangle indices + """ + H, W = points.shape[:2] + + # Resize if resolution is provided + if resolution is not None: + new_H, new_W = resolution + # Resize points using bilinear interpolation + points = points.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W] + points = F.interpolate(points, size=(new_H, new_W), mode='bilinear', align_corners=False) + points = points.squeeze(0).permute(1, 2, 0) # [new_H, new_W, 3] + + # Resize mask using nearest neighbor + mask = mask.unsqueeze(0).unsqueeze(0).float() # [1, 1, H, W] + mask = F.interpolate(mask, size=(new_H, new_W), mode='nearest') + mask = mask.squeeze(0).squeeze(0).bool() # [new_H, new_W] + + H, W = new_H, new_W + + # Create vertex indices grid + vertex_indices = torch.arange(H * W, device=points.device).reshape(H, W) + + # Find 2x2 patches where at least one vertex is in the mask + # Create shifted views for efficient neighbor checking + mask_tl = mask[:-1, :-1] # top-left + mask_tr = mask[:-1, 1:] # top-right + mask_bl = mask[1:, :-1] # bottom-left + mask_br = mask[1:, 1:] # bottom-right + + # A patch is valid if any of its 4 vertices is in the mask + valid_patches = mask_tl | mask_tr | mask_bl | mask_br # [H-1, W-1] + + # Get indices of valid patches + valid_h, valid_w = torch.where(valid_patches) + + # For each valid patch, create two triangles + # Triangle 1: (u,v), (u,v+1), (u+1,v) + # Triangle 2: (u,v+1), (u+1,v+1), (u+1,v) + n_valid = len(valid_h) + + if n_valid == 0: + # No valid patches, return empty mesh + return torch.empty((0, 3), device=points.device), torch.empty((0, 3), dtype=torch.long, device=points.device) + + # Vectorized triangle creation + idx_tl = vertex_indices[valid_h, valid_w] # top-left + idx_tr = vertex_indices[valid_h, valid_w + 1] # top-right + idx_bl = vertex_indices[valid_h + 1, valid_w] # bottom-left + idx_br = vertex_indices[valid_h + 1, valid_w + 1] # bottom-right + + # Create faces (2 triangles per patch) + faces1 = torch.stack([idx_tl, idx_tr, idx_bl], dim=1) # [n_valid, 3] + faces2 = torch.stack([idx_tr, idx_br, idx_bl], dim=1) # [n_valid, 3] + faces = torch.cat([faces1, faces2], dim=0) # [2*n_valid, 3] + + # Flatten points to get vertices + vertices = points.reshape(-1, 3) # [H*W, 3] + + # Optional: Remove unused vertices and remap faces + # First, find which vertices are actually used + used_vertices = torch.unique(faces.flatten()) + + # Create a mapping from old indices to new indices + new_idx_map = torch.full((H * W,), -1, dtype=torch.long, device=points.device) + new_idx_map[used_vertices] = torch.arange(len(used_vertices), device=points.device) + + # Extract only used vertices + vertices = vertices[used_vertices] + + # Remap face indices + faces = new_idx_map[faces.flatten()].reshape(-1, 3) + + return vertices, faces + +def get_max_exponent_for_dtype(dtype): + # Set the maximum exponent based on dtype + if dtype == torch.bfloat16: + return 80.0 # Safe maximum exponent for bfloat16 + elif dtype == torch.float16: + return 10.0 # Safe maximum exponent for float16 + elif dtype == torch.float32: + return 80.0 # Safe maximum exponent for float32 + elif dtype == torch.float64: + return 700.0 # Safe maximum exponent for float64 + else: + return 80.0 # Default safe value + +def inverse_with_conversion(mtx): + return 
torch.linalg.inv(mtx.to(torch.float32)).to(mtx.dtype) + + +def get_camera_rays(h, w, intrinsic: np.ndarray) -> np.ndarray: + """Backproject 2D pixels into 3D rays.""" + device = intrinsic.device + x1d = torch.arange(0, w, device=device, dtype=intrinsic.dtype)[None] + y1d = torch.arange(0, h, device=device, dtype=intrinsic.dtype)[:, None] + x2d = x1d.repeat([h, 1]) # .to(intrinsic) # (h, w) + y2d = y1d.repeat([1, w]) # .to(intrinsic) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device, dtype=intrinsic.dtype) # .to(intrinsic) # (h, w) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + # Normalize the rays + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo).squeeze(-1) + # Normalize the rays + norm = torch.norm(unnormalized_pos, dim=-1, keepdim=True) + norm[norm == 0] = 1 + return unnormalized_pos / norm + + +def forward_warp( + frame1: torch.Tensor, + mask1: Optional[torch.Tensor], + depth1: Optional[torch.Tensor], + transformation1: Optional[torch.Tensor], + transformation2: torch.Tensor, + intrinsic1: Optional[torch.Tensor], + intrinsic2: Optional[torch.Tensor], + is_image=True, + conditioned_normal1=None, + cameraray_filtering=False, + is_depth=True, + render_depth=False, + world_points1=None, + foreground_masking=False, + boundary_mask=None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using + bilinear splatting. + All arrays should be torch tensors with batch dimension and channel first + :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling + bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting() + method accordingly. + :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional + :param depth1: (b, 1, h, w) + :param transformation1: (b, 4, 4) extrinsic transformation matrix (camera-to-world pose) of first view. Required if depth1 is not None, or if cleaning is enabled. + :param transformation2: (b, 4, 4) extrinsic transformation matrix (camera-to-world pose) of second view. + :param intrinsic1: (b, 3, 3) camera intrinsic matrix. Required if depth1 is not None. + :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional (defaults to intrinsic1 if provided). + :param is_image: bool, whether frame1 represents image data (affects clipping and fill value). + :param conditioned_normal1: Optional (b, 3, h, w) normals for filtering. + :param cameraray_filtering: bool, use camera rays for filtering instead of normals. + :param is_depth: bool, whether depth1 represents depth along Z or distance to camera center. Used only if depth1 is not None. + :param render_depth: bool, whether to also render and return the warped depth map. + :param world_points1: Optional (b, h, w, 3) world points. Required if depth1 is None. + :param foreground_masking: bool, enable foreground occlusion masking using mesh rendering. + :param boundary_mask: Optional (b, h, w) mask for mesh generation, required if foreground_masking is True. 
+ """ + device = frame1.device + b, c, h, w = frame1.shape + dtype = frame1.dtype + if mask1 is None: + mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=frame1.dtype) + if intrinsic2 is None: + assert intrinsic1 is not None, "intrinsic2 cannot be derived if intrinsic1 is None and intrinsic2 is None" + intrinsic2 = intrinsic1.clone() + + if depth1 is None: + assert world_points1.shape == (b, h, w, 3) + if foreground_masking: + trans_points1, cam_points_target = project_points(world_points1, transformation2, intrinsic2, return_cam_points=True) + else: + trans_points1 = project_points(world_points1, transformation2, intrinsic2) + else: + # assert frame1.shape == (b, 3, h, w) + assert mask1.shape == (b, 1, h, w) + assert depth1.shape == (b, 1, h, w) + assert transformation1.shape == (b, 4, 4) + assert transformation2.shape == (b, 4, 4) + assert intrinsic1.shape == (b, 3, 3) + assert intrinsic2.shape == (b, 3, 3) + + depth1 = torch.nan_to_num(depth1, nan=1e4) + depth1 = torch.clamp(depth1, min=0, max=1e4) + if foreground_masking: + trans_points1, cam_points_target = compute_transformed_points( + depth1, transformation1, transformation2, intrinsic1, is_depth, intrinsic2, return_cam_points=True + ) + else: + trans_points1 = compute_transformed_points( + depth1, transformation1, transformation2, intrinsic1, is_depth, intrinsic2 + ) + mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0) + trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7) + trans_coordinates = trans_coordinates.permute(0, 3, 1, 2) # b, 2, h, w + trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1) + + grid = create_grid(b, h, w, device=device, dtype=dtype) # .to(trans_coordinates) + flow12 = trans_coordinates - grid + if conditioned_normal1 is not None or cameraray_filtering: + camera_rays = get_camera_rays(h, w, intrinsic1) # b, h, w, 3 + transformation = torch.bmm(transformation2, inverse_with_conversion(transformation1)) + transformation[:, :3, 3] = 0 + trans_4d = transformation[:, None, None] + if cameraray_filtering: # use normal for filtering + conditioned_normal1 = camera_rays + inversion_vector = torch.tensor([-1, -1, -1], dtype=camera_rays.dtype, device=camera_rays.device).view( + 1, 1, 1, 3, 1 + ) + else: # use normal for filtering + assert conditioned_normal1.shape == (b, 3, h, w) + inversion_vector = torch.tensor([-1, 1, 1], dtype=camera_rays.dtype, device=camera_rays.device).view( + 1, 1, 1, 3, 1 + ) + conditioned_normal1 = conditioned_normal1.permute(0, 2, 3, 1) + # rotate normal into target camera spaces + normal_4d = conditioned_normal1.unsqueeze(-1) + b, _, h, w = depth1.shape + ones_2d = torch.ones(size=(h, w), device=device, dtype=dtype) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) + normal_4d_homo = torch.cat([normal_4d * inversion_vector, ones_4d], dim=3) + + trans_normal = torch.matmul(trans_4d, normal_4d_homo).squeeze(-1)[..., :3] # b, h, w, 3 + dot_product = torch.sum(trans_normal * camera_rays, dim=-1) + + # Create binary mask for angles < 90 degrees + binary_mask = dot_product > 0 + # import ipdb;ipdb.set_trace() + mask1 *= binary_mask.unsqueeze(1) + warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image) + warped_depth2 = None + if render_depth or foreground_masking: + warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0][:, 0] + if foreground_masking: + for batch_idx in range(b): + assert 
boundary_mask is not None + mesh_mask = boundary_mask[batch_idx] + + mesh_downsample_factor = 4 + vertices_masked, faces_masked = points_to_mesh( + cam_points_target[batch_idx], + mesh_mask, + resolution=(h // mesh_downsample_factor, w // mesh_downsample_factor) + ) + + if vertices_masked.shape[0] == 0 or faces_masked.shape[0] == 0: + continue + + ray_scale_factor = 1 + ray_downsampled_h = h // ray_scale_factor + ray_downsampled_w = w // ray_scale_factor + current_intrinsic_batch = intrinsic2[batch_idx:batch_idx+1] + scaled_intrinsic = current_intrinsic_batch.clone() + + scaled_intrinsic[0, 0, 0] /= ray_scale_factor # fx + scaled_intrinsic[0, 1, 1] /= ray_scale_factor # fy + scaled_intrinsic[0, 0, 2] /= ray_scale_factor # cx + scaled_intrinsic[0, 1, 2] /= ray_scale_factor # cy + + camera_rays = get_camera_rays(ray_downsampled_h, ray_downsampled_w, scaled_intrinsic) # (1, h_ds, w_ds, 3) + camera_rays = camera_rays[0] # (h_ds, w_ds, 3) + + ray_origins = torch.zeros((ray_downsampled_h, ray_downsampled_w, 3), device=device, dtype=dtype) + + mesh_depth = ray_triangle_intersection( + ray_origins, + camera_rays, + vertices_masked, + faces_masked, + device + ) + ray_z = camera_rays[:, :, 2] # (h, w) + mesh_z_depth = mesh_depth * ray_z # Convert to z-depth + mesh_z_depth = F.interpolate(mesh_z_depth.unsqueeze(0).unsqueeze(0), size=(h, w), mode='bilinear').squeeze(0).squeeze(0) + + warped_depth_batch = warped_depth2[batch_idx] # (h, w) + + + mesh_valid = mesh_z_depth > 0 + mesh_closer = ((mesh_z_depth + 0.02) < warped_depth_batch) & mesh_valid + + mask2[batch_idx, 0] = mask2[batch_idx, 0] * (~mesh_closer).float() + warped_frame2[batch_idx] = (warped_frame2[batch_idx] + 1) * (~mesh_closer.unsqueeze(0)).float() - 1 + warped_depth2[batch_idx] = warped_depth2[batch_idx] * (~mesh_closer.unsqueeze(0)).float() + return warped_frame2, mask2, warped_depth2, flow12 + +def reliable_depth_mask_range_batch(depth, window_size=5, ratio_thresh=0.05, eps=1e-6): + assert window_size % 2 == 1, "Window size must be odd." + if depth.dim() == 3: # Input shape: (B, H, W) + depth_unsq = depth.unsqueeze(1) + elif depth.dim() == 4: # Already has shape (B, 1, H, W) + depth_unsq = depth + else: + raise ValueError("depth tensor must be of shape (B, H, W) or (B, 1, H, W)") + + local_max = torch.nn.functional.max_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + local_min = -torch.nn.functional.max_pool2d(-depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + local_mean = torch.nn.functional.avg_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + ratio = (local_max - local_min) / (local_mean + eps) + reliable_mask = (ratio < ratio_thresh) & (depth_unsq > 0) + + return reliable_mask + +def double_forward_warp( + frame1: torch.Tensor, + mask1: torch.Tensor, + depth1: torch.Tensor, + intrinsic1: torch.Tensor, + double_proj_w2cs: torch.Tensor, +): + """ + Double projection using forward warping with your APIs. + + 1. Warps frame1 from the original view (identity transformation) + to the target view defined by double_proj_w2cs. + 2. Computes a warped flow field and then warps the intermediate result + back to the original view using the original depth. + + :param frame1: (b, 3, h, w) original image. + :param mask1: (b, 1, h, w) valid mask. + :param depth1: (b, 1, h, w) depth map. + :param intrinsic1: (b, 3, 3) intrinsic matrix. + :param double_proj_w2cs: (b, 4, 4) target view transformation. 
+ :return: twice_warped_frame1, warped_frame2, None, None + """ + b, c, h, w = frame1.shape + device, dtype = frame1.device, frame1.dtype + + if mask1 is None: + mask1 = torch.ones((b, 1, h, w), device=device, dtype=dtype) + + # Use identity transformation for the original view. + identity = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(b, 1, 1) + + trans_points = compute_transformed_points( + depth1, identity, double_proj_w2cs, intrinsic1, is_depth=True, intrinsic2=intrinsic1 + ) + trans_coordinates = trans_points[:, :, :, :2, 0] / (trans_points[:, :, :, 2:3, 0] + 1e-7) + trans_depth = trans_points[:, :, :, 2, 0] + + grid = create_grid(b, h, w, device=device, dtype=dtype) + flow12 = trans_coordinates.permute(0, 3, 1, 2) - grid + + warped_frame2, mask2 = bilinear_splatting( + frame1, mask1, trans_depth.unsqueeze(1), flow12, None, is_image=True, n_views=1, depth_weight_scale=50 + ) + + warped_flow, _ = bilinear_splatting( + flow12, mask1, trans_depth.unsqueeze(1), flow12, None, is_image=False, n_views=1, depth_weight_scale=50 + ) + + twice_warped_frame1, twice_warped_mask1 = bilinear_splatting( + warped_frame2, mask2, depth1, -warped_flow, None, is_image=True, n_views=1, depth_weight_scale=50 + ) + + return twice_warped_frame1, twice_warped_mask1, warped_frame2, mask2 + + +def unproject_points(depth: torch.Tensor, + w2c: torch.Tensor, + intrinsic: torch.Tensor, + is_depth: bool = True, + mask: Optional[torch.Tensor] = None): + + b, _, h, w = depth.shape + device = depth.device + dtype = depth.dtype + if mask is None: + mask = depth > 0 + if mask.dim() == depth.dim() and mask.shape[1] == 1: + mask = mask[:, 0] + + idx = torch.nonzero(mask) + if idx.numel() == 0: + return torch.zeros((b, h, w, 3), device=device, dtype=dtype) + + b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2] + + + intrinsic_inv = inverse_with_conversion(intrinsic) # (b, 3, 3) + + x_valid = x_idx.to(dtype) + y_valid = y_idx.to(dtype) + ones = torch.ones_like(x_valid) + pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1) # (N, 3, 1) + + intrinsic_inv_valid = intrinsic_inv[b_idx] # (N, 3, 3) + unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos) # (N, 3, 1) + + depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1) + if is_depth: + world_points_cam = depth_valid * unnormalized_pos + else: + norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True) + direction = unnormalized_pos / (norm_val + 1e-8) + world_points_cam = depth_valid * direction + + ones_h = torch.ones((world_points_cam.shape[0], 1, 1), + device=device, dtype=dtype) + world_points_homo = torch.cat([world_points_cam, ones_h], dim=1) # (N, 4, 1) + + trans = inverse_with_conversion(w2c) # (b, 4, 4) + trans_valid = trans[b_idx] # (N, 4, 4) + world_points_transformed = torch.matmul(trans_valid, world_points_homo) # (N, 4, 1) + sparse_points = world_points_transformed[:, :3, 0] # (N, 3) + + out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype) + out_points[b_idx, y_idx, x_idx, :] = sparse_points + return out_points + +def project_points(world_points: torch.Tensor, w2c: torch.Tensor, intrinsic: torch.Tensor, return_cam_points: bool = False): + """ + Projects 3D world points back into 2D pixel space. 
+ """ + world_points = world_points.unsqueeze(-1) # (b, h, w, 3) -> # (b, h, w, 3, 1) + b, h, w, _, _ = world_points.shape + + ones_4d = torch.ones((b, h, w, 1, 1), device=world_points.device, dtype=world_points.dtype) # (b, h, w, 1, 1) + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + + # Apply transformation2 to convert world points to camera space + trans_4d = w2c[:, None, None] # (b, 1, 1, 4, 4) + camera_points_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + + # Remove homogeneous coordinate and project to image plane + camera_points = camera_points_homo[:, :, :, :3] # (b, h, w, 3, 1) + intrinsic_4d = intrinsic[:, None, None] # (b, 1, 1, 3, 3) + projected_points = torch.matmul(intrinsic_4d, camera_points) # (b, h, w, 3, 1) + + if return_cam_points: + # Return both projected points and camera space points + cam_points_3d = camera_points.squeeze(-1) # (b, h, w, 3) + return projected_points, cam_points_3d + else: + return projected_points + + +def unproject_depth_torch( + depth1: torch.Tensor, + transformation1: torch.Tensor, + intrinsic1: torch.Tensor, +) -> torch.Tensor: + b, c, h, w = depth1.shape + assert depth1.shape == (b, 1, h, w) + assert transformation1.shape == (b, 4, 4) + assert intrinsic1.shape == (b, 3, 3) + device = depth1.device + x1d = torch.arange(0, w, device=device)[None] + y1d = torch.arange(0, h, device=device)[:, None] + x2d = x1d.repeat([h, 1]) # .to(depth1) # (h, w) + y2d = y1d.repeat([1, w]) # .to(depth1) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic1) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + + depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1) + + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo) # (b, h, w, 3, 1) + world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1) + + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + trans_4d = transformation1[:, None, None] # (b, 1, 1, 4, 4) + trans_world_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1) + trans_world = trans_world.squeeze(dim=-1) + return trans_world + + +def compute_transformed_points( + depth1: torch.Tensor, + transformation1: torch.Tensor, + transformation2: torch.Tensor, + intrinsic1: torch.Tensor, + is_depth: bool = True, + intrinsic2: Optional[torch.Tensor] = None, + return_cam_points: bool = False, +): + """ + Computes transformed position for each pixel location + """ + b, _, h, w = depth1.shape + if intrinsic2 is None: + intrinsic2 = intrinsic1.clone() + transformation = torch.bmm( + transformation2, inverse_with_conversion(transformation1) + ) # (b, 4, 4) transformation is w2c + device = depth1.device + x1d = torch.arange(0, w, device=device, dtype=depth1.dtype)[None] + y1d = torch.arange(0, h, device=device, dtype=depth1.dtype)[:, None] + x2d = x1d.repeat([h, 1]) # .to(depth1) # (h, w) + y2d = y1d.repeat([1, w]) # .to(depth1) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device, dtype=depth1.dtype) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, 
:, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic1) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + intrinsic2_4d = intrinsic2[:, None, None] # (b, 1, 1, 3, 3) + depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1) + trans_4d = transformation[:, None, None] # (b, 1, 1, 4, 4) + + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo) # (b, h, w, 3, 1) + if is_depth: + world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1) + else: # if 'depth' is defined as distance to camera center + direction_vectors = unnormalized_pos / torch.norm(unnormalized_pos, dim=-2, keepdim=True) # (b, h, w, 3, 1) + world_points = depth_4d * direction_vectors # (b, h, w, 3, 1) + + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + trans_world_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1) + trans_norm_points = torch.matmul(intrinsic2_4d, trans_world) # (b, h, w, 3, 1) + + if return_cam_points: + # Return both projected points and camera space points + cam_points = trans_world.squeeze(-1) # (b, h, w, 3) + return trans_norm_points, cam_points + else: + return trans_norm_points + + +def bilinear_splatting( + frame1: torch.Tensor, + mask1: Optional[torch.Tensor], + depth1: torch.Tensor, + flow12: torch.Tensor, + flow12_mask: Optional[torch.Tensor], + is_image: bool = False, + n_views=1, + depth_weight_scale=50, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Bilinear splatting + :param frame1: (b,c,h,w) + :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional + :param depth1: (b,1,h,w) + :param flow12: (b,2,h,w) + :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. 
Optional + :param is_image: if true, output will be clipped to (-1,1) range + :return: warped_frame2: (b,c,h,w) + mask2: (b,1,h,w): 1 for known and 0 for unknown + """ + b, c, h, w = frame1.shape + device = frame1.device + dtype = frame1.dtype + if mask1 is None: + mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(frame1) + if flow12_mask is None: + flow12_mask = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(flow12) + grid = create_grid(b, h, w, device=device, dtype=dtype).to(dtype) # .to(frame1) + trans_pos = flow12 + grid + + trans_pos_offset = trans_pos + 1 + trans_pos_floor = torch.floor(trans_pos_offset).long() + trans_pos_ceil = torch.ceil(trans_pos_offset).long() + trans_pos_offset = torch.stack( + [torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], + dim=1, + ) + trans_pos_floor = torch.stack( + [torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], + dim=1, + ) + trans_pos_ceil = torch.stack( + [torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], + dim=1, + ) + + prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( + 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) + ) + prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( + 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) + ) + prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( + 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) + ) + prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( + 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) + ) + + # Calculate depth weights, preventing overflow and removing saturation + # Clamp depth to be non-negative before log1p + clamped_depth1 = torch.clamp(depth1, min=0) + log_depth1 = torch.log1p(clamped_depth1) # Use log1p for better precision near 0 + # Normalize and scale log depth + exponent = log_depth1 / (log_depth1.max() + 1e-7) * depth_weight_scale + # Clamp exponent before exp to prevent overflow + max_exponent = get_max_exponent_for_dtype(depth1.dtype) + clamped_exponent = torch.clamp(exponent, max=max_exponent) + # Compute depth weights with added epsilon for stability when dividing later + depth_weights = torch.exp(clamped_exponent) + 1e-7 + + + weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + + warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=dtype, device=device) # .to(frame1) + warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=dtype, device=device) # .to(frame1) + + frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2]) + batch_indices = torch.arange(b, device=device, dtype=torch.long)[:, None, None] # .to(frame1.device) + warped_frame.index_put_( + (batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_nw, accumulate=True + ) + warped_frame.index_put_( + (batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_sw, accumulate=True + ) + warped_frame.index_put_( + 
(batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_ne, accumulate=True + ) + warped_frame.index_put_( + (batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_se, accumulate=True + ) + + warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), weight_nw, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), weight_sw, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), weight_ne, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), weight_se, accumulate=True) + if n_views > 1: + warped_frame = warped_frame.reshape(b // n_views, n_views, h + 2, w + 2, c).sum(1) + warped_weights = warped_weights.reshape(b // n_views, n_views, h + 2, w + 2, 1).sum(1) + + warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1]) + warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1]) + cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1] + cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1] + cropped_weights = torch.nan_to_num(cropped_weights, nan=1000.0) + + mask = cropped_weights > 0 + zero_value = -1 if is_image else 0 + zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device) + warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor) + mask2 = mask.to(frame1) + if is_image: + # assert warped_frame2.min() >= -1.1 # Allow for rounding errors + # assert warped_frame2.max() <= 1.1 + warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1) + return warped_frame2, mask2 + +def create_grid(b: int, h: int, w: int, device="cpu", dtype=torch.float) -> torch.Tensor: + """ + Create a dense grid of (x,y) coordinates of shape (b, 2, h, w). + """ + x = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w) + y = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w) + return torch.cat([x, y], dim=1) + +def ray_triangle_intersection( + ray_origins: torch.Tensor, # (H, W, 3) + ray_directions: torch.Tensor, # (H, W, 3) + vertices: torch.Tensor, # (N, 3) + faces: torch.Tensor, # (M, 3) + device: torch.device +) -> torch.Tensor: + """ + Compute ray-triangle intersections for all rays and triangles. + Returns depth map of shape (H, W) with intersection distances. + + Uses NVIDIA Warp acceleration for fast performance. + """ + _init_warp() + return _ray_triangle_intersection_func( + ray_origins, ray_directions, vertices, faces, device + ) \ No newline at end of file diff --git a/cosmos_predict1/diffusion/inference/gen3c_dynamic.py b/cosmos_predict1/diffusion/inference/gen3c_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..d803afdca409b2fd4a9c15bc992a011469371ae9 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_dynamic.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, +) +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache4D +from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory +from cosmos_predict1.diffusion.inference.data_loader_utils import load_data_auto_detect +import torch.nn.functional as F +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) # TODO: do we need this? + parser.add_argument( + "--input_image_path", + type=str, + help="Input image path for generating a single video", + ) + parser.add_argument( + "--trajectory", + type=str, + choices=[ + "left", + "right", + "up", + "down", + "zoom_in", + "zoom_out", + "clockwise", + "counterclockwise", + ], + default="left", + help="Select a trajectory type from the available options (default: original)", + ) + parser.add_argument( + "--camera_rotation", + type=str, + choices=["center_facing", "no_rotation", "trajectory_aligned"], + default="center_facing", + help="Controls camera rotation during movement: center_facing (rotate to look at center), no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)", + ) + parser.add_argument( + "--movement_distance", + type=float, + default=0.3, + help="Distance of the camera from the center of the scene", + ) + parser.add_argument( + "--save_buffer", + action="store_true", + help="If set, save the warped images (buffer) side by side with the output video.", + ) + parser.add_argument( + "--filter_points_threshold", + type=float, + default=0.05, + help="If set, filter the points continuity of the warped images.", + ) + parser.add_argument( + "--foreground_masking", + action="store_true", + help="If set, use foreground masking for the warped images.", + ) + return parser.parse_args() + +def validate_args(args): + assert args.num_video_frames is not None, "num_video_frames must be provided" + assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)" + + +def demo(args): + """Run video-to-world generation demo. 
+ + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + disable_prompt_encoder=args.disable_prompt_encoder, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=121, + seed=args.seed, + ) + + sample_n_frames = pipeline.model.chunk_size + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}] + + os.makedirs(os.path.dirname(args.video_save_folder), exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_video_path = input_dict.get("visual_input", None) + if current_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Load data using the new auto-detect loader (supports both old pt and new format) + try: + ( + image_bchw_float, + depth_b1hw, + mask_b1hw, + initial_w2c_b44, + intrinsics_b33, + ) = load_data_auto_detect(current_video_path) + except Exception as e: + log.critical(f"Failed to load visual input from {current_video_path}: {e}") + continue + + image_bchw_float = image_bchw_float.to(device) + 
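+        # Note on shapes (a descriptive comment; see load_data_auto_detect /
+        # load_data_distributed_format above): despite the "_b*"-style names, the
+        # leading dimension here is the number of loaded frames F, i.e.
+        # image_bchw_float is [F, C, H, W] in [-1, 1], depth_b1hw and mask_b1hw are
+        # [F, 1, H, W], initial_w2c_b44 is [F, 4, 4] and intrinsics_b33 is [F, 3, 3],
+        # matching input_format=["F", "C", "H", "W"] passed to Cache4D below.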
depth_b1hw = depth_b1hw.to(device) + mask_b1hw = mask_b1hw.to(device) + initial_w2c_b44 = initial_w2c_b44.to(device) + intrinsics_b33 = intrinsics_b33.to(device) + + cache = Cache4D( + input_image=image_bchw_float.clone(), # [B, C, H, W] + input_depth=depth_b1hw, # [B, 1, H, W] + input_mask=mask_b1hw, # [B, 1, H, W] + input_w2c=initial_w2c_b44, # [B, 4, 4] + input_intrinsics=intrinsics_b33,# [B, 3, 3] + filter_points_threshold=args.filter_points_threshold, + input_format=["F", "C", "H", "W"], + foreground_masking=args.foreground_masking, + ) + + initial_cam_w2c_for_traj = initial_w2c_b44 + initial_cam_intrinsics_for_traj = intrinsics_b33 + + # Generate camera trajectory using the new utility function + try: + generated_w2cs, generated_intrinsics = generate_camera_trajectory( + trajectory_type=args.trajectory, + initial_w2c=initial_cam_w2c_for_traj, + initial_intrinsics=initial_cam_intrinsics_for_traj, + num_frames=args.num_video_frames, + movement_distance=args.movement_distance, + camera_rotation=args.camera_rotation, + center_depth=1.0, + device=device.type, + ) + except (ValueError, NotImplementedError) as e: + log.critical(f"Failed to generate trajectory: {e}") + continue + + log.info(f"Generating 0 - {sample_n_frames} frames") + + rendered_warp_images, rendered_warp_masks = cache.render_cache( + generated_w2cs[:, 0:sample_n_frames], + generated_intrinsics[:, 0:sample_n_frames], + start_frame_idx=0, + ) + + all_rendered_warps = [] + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=image_bchw_float[0].unsqueeze(0).unsqueeze(2), + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1) + for num_iter in range(1, num_ar_iterations): + start_frame_idx = num_iter * (sample_n_frames - 1) # Overlap by 1 frame + end_frame_idx = start_frame_idx + sample_n_frames + + log.info(f"Generating {start_frame_idx} - {end_frame_idx} frames") + + last_frame_hwc_0_255 = torch.tensor(video[-1], device=device) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx] + rendered_warp_images, rendered_warp_masks = cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + start_frame_idx=start_frame_idx, + ) + + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu()) + + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, prompt = generated_output + video = np.concatenate([video, video_new[1:]], axis=0) + + # Final video processing + final_video_to_save = video + final_width = args.width + + if args.save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is 
(T_chunk, n_i, C, H, W) + + if squeezed_warps: + n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}") + else: + log.info("No warp buffers to save.") + + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + + os.makedirs(os.path.dirname(video_save_path), exist_ok=True) + + # Save video + save_video( + video=final_video_to_save, + fps=args.fps, + H=args.height, + W=final_width, + video_save_quality=5, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + if args.prompt is None: + args.prompt = "" + args.disable_guardrail = True + args.disable_prompt_upsampler = True + demo(args) diff --git a/cosmos_predict1/diffusion/inference/gen3c_persistent.py b/cosmos_predict1/diffusion/inference/gen3c_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..8113d53f306781282fad2dafc700a79ae60cf266 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_persistent.py @@ -0,0 +1,569 @@ +import argparse +import os +import time + +from moge.model.v1 import MoGeModel +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.diffusion.inference.gen3c_single_image import ( + create_parser as create_parser_base, + validate_args as validate_args_base, + _predict_moge_depth, + _predict_moge_depth_from_tensor +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.distributed import device_with_rank, is_rank0, get_rank +from cosmos_predict1.utils.io import save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer, Cache4D +import torch.nn.functional as F + + +def create_parser(): + return create_parser_base() + + +def validate_args(args: argparse.Namespace): + validate_args_base(args) + assert args.batch_input_path is None, "Unsupported in persistent mode" + assert args.prompt is not None, "Prompt is required in persistent mode (but it can be the empty string)" + assert args.input_image_path is None, "Image should be provided directly by value in persistent mode" + 
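+    # In persistent mode the seed image(s), optional depths and the camera trajectory
+    # are supplied in memory per request (see Gen3cPersistentModel.seed_model_from_values),
+    # so the corresponding CLI arguments must stay unset.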
assert args.trajectory in (None, 'none'), "Trajectory should be provided directly by value in persistent mode, set --trajectory=none" + assert not args.video_save_name, f"Video saving name will be set automatically for each inference request. Found string: \"{args.video_save_name}\"" + + +def resize_intrinsics(intrinsics: np.ndarray | torch.Tensor, + old_size: tuple[int, int], new_size: tuple[int, int], + crop_size: tuple[int, int] | None = None) -> np.ndarray | torch.Tensor: + # intrinsics: (3, 3) + # old_size: (h1, w1) + # new_size: (h2, w2) + if isinstance(intrinsics, np.ndarray): + intrinsics_copy = np.copy(intrinsics) + elif isinstance(intrinsics, torch.Tensor): + intrinsics_copy = intrinsics.clone() + else: + raise ValueError(f"Invalid intrinsics type: {type(intrinsics)}") + intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1] + intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0] + if crop_size is not None: + intrinsics_copy[:, 0, -1] = intrinsics_copy[:, 0, -1] - (new_size[1] - crop_size[1]) / 2 + intrinsics_copy[:, 1, -1] = intrinsics_copy[:, 1, -1] - (new_size[0] - crop_size[0]) / 2 + return intrinsics_copy + + +class Gen3cPersistentModel(): + """Helper class to run Gen3C image-to-video or video-to-video inference. + + This class loads the models only once and can be reused for multiple inputs. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + """ + + @torch.no_grad() + def __init__(self, args: argparse.Namespace): + misc.set_random_seed(args.seed) + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + self.frames_per_batch = 121 + self.inference_overlap_frames = 1 + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type="video2world", + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=self.frames_per_batch, + seed=args.seed, + ) + if args.num_gpus > 1: + 
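+            # Context parallelism splits the (long) token sequence of the DiT
+            # across the ranks of `process_group`; each rank keeps the full
+            # weights and processes its share of the sequence.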
pipeline.model.net.enable_context_parallel(process_group) + + self.args = args + self.frame_buffer_max = pipeline.model.frame_buffer_max + self.generator = torch.Generator(device=device).manual_seed(args.seed) + self.sample_n_frames = pipeline.model.chunk_size + self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + self.pipeline = pipeline + self.device = device + self.device_with_rank = device_with_rank(self.device) + + self.cache: Cache3D_Buffer | Cache4D | None = None + self.model_was_seeded = False + # User-provided seeding image, after pre-processing. + # Shape [B, C, T, H, W], type float, range [-1, 1]. + self.seeding_image: torch.Tensor | None = None + + + @torch.no_grad() + def seed_model_from_values(self, + images_np: np.ndarray, + depths_np: np.ndarray | None, + world_to_cameras_np: np.ndarray, + focal_lengths_np: np.ndarray, + principal_point_rel_np: np.ndarray, + resolutions: np.ndarray, + masks_np: np.ndarray | None = None): + import torchvision.transforms.functional as transforms_F + + # Check inputs + n = images_np.shape[0] + assert images_np.shape[-1] == 3 + assert world_to_cameras_np.shape == (n, 4, 4) + assert focal_lengths_np.shape == (n, 2) + assert principal_point_rel_np.shape == (n, 2) + assert resolutions.shape == (n, 2) + assert (depths_np is None) or (depths_np.shape == images_np.shape[:-1]) + assert (masks_np is None) or (masks_np.shape == images_np.shape[:-1]) + + + if n == 1: + # TODO: allow user to provide depths, extrinsics and intrinsics + assert depths_np is None, "Not supported yet: directly providing pre-estimated depth values along with a single image." + + # Note: image is received as 0..1 float, but MoGE expects 0..255 uint8. + input_image_np = images_np[0, ...] * 255.0 + del images_np + + # Predict depth and initialize 3D cache. + # Note: even though internally MoGE may use a different resolution, all of the outputs + # are properly resized & adapted to our desired (self.args.height, self.args.width) resolution, + # including the intrinsics. + ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) = _predict_moge_depth( + input_image_np, self.args.height, self.args.width, self.device_with_rank, self.moge_model + ) + + # TODO: MoGE provides camera params, is it okay to just ignore the user-provided ones? + input_image = moge_image_b1chw_float[:, 0].clone() + self.cache = Cache3D_Buffer( + frame_buffer_max=self.frame_buffer_max, + generator=self.generator, + noise_aug_strength=self.args.noise_aug_strength, + input_image=input_image, # [B, C, H, W] + input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W] + # input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W] + input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4] + input_intrinsics=moge_intrinsics_b133[:, 0], # [B, 3, 3] + filter_points_threshold=self.args.filter_points_threshold, + foreground_masking=self.args.foreground_masking, + ) + + seeding_image = input_image_np.transpose(2, 0, 1)[None, ...] / 128.0 - 1.0 + seeding_image = torch.from_numpy(seeding_image).to(device_with_rank(self.device_with_rank)) + + # Return the estimated extrinsics and intrinsics in the same format as the input + estimated_w2c_b44_np = moge_initial_w2c_b144.cpu().numpy()[:, 0, ...] 
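+            # The returned intrinsics follow the usual
+            #     K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
+            # layout, so the focal lengths are read off the diagonal and the
+            # principal point from the last column below.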
+ moge_intrinsics_b133_np = moge_intrinsics_b133.cpu().numpy() + estimated_focal_lengths_b2_np = np.stack([moge_intrinsics_b133_np[:, 0, 0, 0], + moge_intrinsics_b133_np[:, 0, 1, 1]], axis=1) + estimated_principal_point_rel_b2_np = moge_intrinsics_b133_np[:, 0, :2, 2] + + else: + if depths_np is None: + raise NotImplementedError("Seeding from multiple frames requires providing depth values.") + if masks_np is None: + raise NotImplementedError("Seeding from multiple frames requires providing mask values.") + + # RGB: [B, H, W, C] to [B, C, H, W] + image_bchw_float = torch.from_numpy(images_np.transpose(0, 3, 1, 2).astype(np.float32)).to(self.device_with_rank) + # Images are received as 0..1 float32, we convert to -1..1 range. + image_bchw_float = (image_bchw_float * 2.0) - 1.0 + del images_np + + # Depth: [B, H, W] to [B, 1, H, W] + depth_b1hw = torch.from_numpy(depths_np[:, None, ...].astype(np.float32)).to(self.device_with_rank) + # Mask: [B, H, W] to [B, 1, H, W] + mask_b1hw = torch.from_numpy(masks_np[:, None, ...].astype(np.float32)).to(self.device_with_rank) + # World-to-camera: [B, 4, 4] + initial_w2c_b44 = torch.from_numpy(world_to_cameras_np).to(self.device_with_rank) + # Intrinsics: [B, 3, 3] + intrinsics_b33_np = np.zeros((n, 3, 3), dtype=np.float32) + intrinsics_b33_np[:, 0, 0] = focal_lengths_np[:, 0] + intrinsics_b33_np[:, 1, 1] = focal_lengths_np[:, 1] + intrinsics_b33_np[:, 0, 2] = principal_point_rel_np[:, 0] * self.args.width + intrinsics_b33_np[:, 1, 2] = principal_point_rel_np[:, 1] * self.args.height + intrinsics_b33_np[:, 2, 2] = 1.0 + intrinsics_b33 = torch.from_numpy(intrinsics_b33_np).to(self.device_with_rank) + + self.cache = Cache4D( + input_image=image_bchw_float.clone(), # [B, C, H, W] + input_depth=depth_b1hw, # [B, 1, H, W] + input_mask=mask_b1hw, # [B, 1, H, W] + input_w2c=initial_w2c_b44, # [B, 4, 4] + input_intrinsics=intrinsics_b33, # [B, 3, 3] + filter_points_threshold=self.args.filter_points_threshold, + foreground_masking=self.args.foreground_masking, + input_format=["F", "C", "H", "W"], + ) + + # Return the given extrinsics and intrinsics in the same format as the input + seeding_image = image_bchw_float + estimated_w2c_b44_np = world_to_cameras_np + estimated_focal_lengths_b2_np = focal_lengths_np + estimated_principal_point_rel_b2_np = principal_point_rel_np + + # Resize seeding image to match the desired resolution. + if (seeding_image.shape[2] != self.H) or (seeding_image.shape[3] != self.W): + # TODO: would it be better to crop if aspect ratio is off? + seeding_image = transforms_F.resize( + seeding_image, + size=(self.H, self.W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + # Switch from [B, C, H, W] to [B, C, T, H, W]. + self.seeding_image = seeding_image[:, :, None, ...] + + working_resolutions_b2_np = np.tile([[self.args.width, self.args.height]], (n, 1)) + return ( + estimated_w2c_b44_np, + estimated_focal_lengths_b2_np, + estimated_principal_point_rel_b2_np, + working_resolutions_b2_np + ) + + + @torch.no_grad() + def inference_on_cameras(self, view_cameras_w2cs: np.ndarray, view_camera_intrinsics: np.ndarray, + fps: int | float, + overlap_frames:int = 1, + return_estimated_depths: bool = False, + video_save_quality: int = 5, + save_buffer: bool | None = None) -> dict | None: + + # TODO: this is not safe if multiple inference requests are served in parallel. + # TODO: also, it's not 100% clear whether it is correct to override this request + # after initialization of the pipeline. 
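+        # Expected camera inputs (see `prepare_camera_for_inference` below):
+        #   view_cameras_w2cs:      (n_frames, 4, 4) world-to-camera matrices
+        #                           (a leading batch dimension of 1 is added if missing),
+        #   view_camera_intrinsics: (n_frames, 3, 3) pinhole intrinsics, already expressed
+        #                           at the working resolution (self.H, self.W).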
+ self.pipeline.fps = int(fps) + del fps + save_buffer = save_buffer if (save_buffer is not None) else self.args.save_buffer + + video_save_name = self.args.video_save_name + if not video_save_name: + video_save_name = f"video_{time.strftime('%Y-%m-%d_%H-%M-%S')}" + video_save_path = os.path.join(self.args.video_save_folder, f"{video_save_name}.mp4") + os.makedirs(self.args.video_save_folder, exist_ok=True) + + cache_is_multiframe = isinstance(self.cache, Cache4D) + + # Note: the inference server already adjusted intrinsics to match our + # inference resolution (self.W, self.H), so this call is just to make sure + # that all tensors have the right shape, etc. + view_cameras_w2cs, view_camera_intrinsics = self.prepare_camera_for_inference( + view_cameras_w2cs, view_camera_intrinsics, + old_size=(self.H, self.W), new_size=(self.H, self.W) + ) + + n_frames_total = view_cameras_w2cs.shape[1] + num_ar_iterations = (n_frames_total - overlap_frames) // (self.sample_n_frames - overlap_frames) + log.info(f"Generating {n_frames_total} frames will take {num_ar_iterations} auto-regressive iterations") + + # Note: camera trajectory is given by the user, no need to generate it. + log.info(f"Generating frames 0 - {self.sample_n_frames} (out of {n_frames_total} total)...") + rendered_warp_images, rendered_warp_masks = self.cache.render_cache( + view_cameras_w2cs[:, 0:self.sample_n_frames], + view_camera_intrinsics[:, 0:self.sample_n_frames], + start_frame_idx=0, + ) + + all_rendered_warps = [] + all_predicted_depth = [] + if save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + + current_prompt = self.args.prompt + if current_prompt is None and self.args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + return + + + # Generate video + starting_frame = self.seeding_image + if cache_is_multiframe: + starting_frame = starting_frame[0].unsqueeze(0) + + generated_output = self.pipeline.generate( + prompt=current_prompt, + image_path=starting_frame, + negative_prompt=self.args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + return + video, _ = generated_output + + + def depth_for_frame(frame: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + last_frame_hwc_0_255 = torch.tensor(frame, device=self.device_with_rank) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + pred_depth, pred_mask = _predict_moge_depth_from_tensor( + pred_image_for_depth_chw_0_1, self.moge_model + ) + return pred_depth, pred_mask, pred_image_for_depth_chw_0_1 + + + # We predict depth either if we need it (multi-round generation without depth in the cache), + # or if the user requested it explicitly. + need_depth_of_latest_frame = return_estimated_depths or (num_ar_iterations > 1 and not cache_is_multiframe) + if need_depth_of_latest_frame: + pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video[-1]) + + if return_estimated_depths: + # For easier indexing, we include entries even for the frames for which we don't predict + # depth. Since the results will be transmitted in compressed format, this hopefully + # shouldn't take up any additional bandwidth. + depths_batch_0 = np.full((video.shape[0], 1, self.H, self.W), fill_value=np.nan, + dtype=np.float32) + depths_batch_0[-1, ...] 
= pred_depth.cpu().numpy() + all_predicted_depth.append(depths_batch_0) + del depths_batch_0 + + + # Autoregressive generation (if needed) + for num_iter in range(1, num_ar_iterations): + # Overlap by `overlap_frames` frames + start_frame_idx = num_iter * (self.sample_n_frames - overlap_frames) + end_frame_idx = start_frame_idx + self.sample_n_frames + log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...") + + if cache_is_multiframe: + # Nothing much to do, we assume that depth is alraedy provided and + # all frames of the seeding video are already in the cache. + pred_image_for_depth_chw_0_1 = torch.tensor( + video[-1], device=self.device_with_rank + ).permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + else: + self.cache.update_cache( + new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] + new_depth=pred_depth, # (1,1,H,W) + # new_mask=pred_mask, # (1,1,H,W) + new_w2c=view_cameras_w2cs[:, start_frame_idx], + new_intrinsics=view_camera_intrinsics[:, start_frame_idx], + ) + + current_segment_w2cs = view_cameras_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = view_camera_intrinsics[:, start_frame_idx:end_frame_idx] + + cache_start_frame_idx = 0 + if cache_is_multiframe: + # If requesting more frames than are available in the cache, + # freeze (hold) on the last batch of frames. + cache_start_frame_idx = min( + start_frame_idx, + self.cache.input_frame_count() - (end_frame_idx - start_frame_idx) + ) + + rendered_warp_images, rendered_warp_masks = self.cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + start_frame_idx=cache_start_frame_idx, + ) + + if save_buffer: + all_rendered_warps.append(rendered_warp_images[:, overlap_frames:].clone().cpu()) + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = self.pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=self.args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, _ = generated_output + + video = np.concatenate([video, video_new[overlap_frames:]], axis=0) + + # Prepare depth prediction for the next AR iteration. + need_depth_of_latest_frame = return_estimated_depths or ((num_iter < num_ar_iterations - 1) and not cache_is_multiframe) + if need_depth_of_latest_frame: + # Either we don't have depth (e.g. single-image seeding), or the user requested + # depth to be returned explicitly. + pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video_new[-1]) + if return_estimated_depths: + depths_batch_i = np.full((video_new.shape[0] - overlap_frames, 1, self.H, self.W), + fill_value=np.nan, dtype=np.float32) + depths_batch_i[-1, ...] 
= pred_depth.cpu().numpy() + all_predicted_depth.append(depths_batch_i) + del depths_batch_i + + + if is_rank0(): + # Final video processing + final_video_to_save = video + final_width = self.args.width + + if save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) + + if squeezed_warps: + n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = self.args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}") + + else: + log.info("No warp buffers to save.") + + # Save video + save_video( + video=final_video_to_save, + fps=self.pipeline.fps, + H=self.args.height, + W=final_width, + video_save_quality=video_save_quality, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + + if return_estimated_depths: + predicted_depth = np.concatenate(all_predicted_depth, axis=0) + else: + predicted_depth = None + + + # Currently `video` is [n_frames, height, width, channels]. + # Return as [1, n_frames, channels, height, width] for consistency with other codebases. + video = video.transpose(0, 3, 1, 2)[None, ...] + # Depth is returned as [n_frames, channels, height, width]. 
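+        # Note: `rendered_warp_images` below only covers the last auto-regressive
+        # segment (it is overwritten on every iteration); the warps of all segments
+        # are only kept when `save_buffer` is set, via the side-by-side buffer video.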
+ + # TODO: handle overlap + rendered_warp_images_no_overlap = rendered_warp_images + video_no_overlap = video + return { + "rendered_warp_images": rendered_warp_images, + "video": video, + "rendered_warp_images_no_overlap": rendered_warp_images_no_overlap, + "video_no_overlap": video_no_overlap, + "predicted_depth": predicted_depth, + "video_save_path": video_save_path, + } + + # -------------------- + + def prepare_camera_for_inference(self, view_cameras: np.ndarray, view_camera_intrinsics: np.ndarray, + old_size: tuple[int, int], new_size: tuple[int, int]): + """Old and new sizes should be given as (height, width).""" + if isinstance(view_cameras, np.ndarray): + view_cameras = torch.from_numpy(view_cameras).float().contiguous() + if view_cameras.ndim == 3: + view_cameras = view_cameras.unsqueeze(dim=0) + + if isinstance(view_camera_intrinsics, np.ndarray): + view_camera_intrinsics = torch.from_numpy(view_camera_intrinsics).float().contiguous() + + view_camera_intrinsics = resize_intrinsics(view_camera_intrinsics, old_size, new_size) + view_camera_intrinsics = view_camera_intrinsics.unsqueeze(dim=0) + assert view_camera_intrinsics.ndim == 4 + + return view_cameras.to(device_with_rank(self.device_with_rank)), \ + view_camera_intrinsics.to(device_with_rank(self.device_with_rank)) + + + def get_cache_input_depths(self) -> torch.Tensor | None: + if self.cache is None: + return None + return self.cache.input_depth + + @property + def W(self) -> int: + return self.args.width + + @property + def H(self) -> int: + return self.args.height + + + def clear_cache(self) -> None: + self.cache = None + self.model_was_seeded = False + + + def cleanup(self) -> None: + if self.args.num_gpus > 1: + rank = get_rank() + log.info(f"Model cleanup: destroying model parallel group on rank={rank}.", + rank0_only=False) + from megatron.core import parallel_state + parallel_state.destroy_model_parallel() + + import torch.distributed as dist + dist.destroy_process_group() + + log.info(f"Destroyed model parallel group on rank={rank}.", rank0_only=False) + else: + log.info("Model cleanup: nothing to do (no parallelism).", rank0_only=False) diff --git a/cosmos_predict1/diffusion/inference/gen3c_pipeline.py b/cosmos_predict1/diffusion/inference/gen3c_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d3fe1742927aac89ebf00ff773ced0131e2db2 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_pipeline.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
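+"""Gen3C generation pipeline built on top of the Cosmos video2world diffusion pipeline.
+
+Illustrative usage (a minimal sketch, not a tested end-to-end script; the image and
+warp tensors below are placeholders produced elsewhere, e.g. by the 3D cache):
+
+    pipeline = Gen3cPipeline(
+        inference_type="video2world",
+        checkpoint_dir="checkpoints",
+        checkpoint_name="Gen3C-Cosmos-7B",
+        height=704, width=1280, fps=24, num_video_frames=121,
+    )
+    result = pipeline.generate(
+        prompt="a slow camera move through the scene",
+        image_path=conditioning_image,      # path, or image/video tensor
+        rendered_warp_images=warp_images,   # warps rendered from the 3D cache
+        rendered_warp_masks=warp_masks,
+        negative_prompt=None,
+    )
+    if result is not None:                  # None means the guardrail blocked the request
+        video, used_prompt = result
+"""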
+ +from typing import Any, Optional + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + generate_world_from_video, + get_video_batch, + load_model_by_config, +) +from cosmos_predict1.diffusion.model.model_gen3c import DiffusionGen3CModel +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline +from cosmos_predict1.utils import log + +class Gen3cPipeline(DiffusionVideo2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + disable_prompt_encoder: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + disable_prompt_encoder: Whether to disable prompt encoder + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + """ + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + disable_prompt_encoder=disable_prompt_encoder, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + num_input_frames=1, + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionGen3CModel, + ) + + def generate( + self, + prompt: str, + image_path: str, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt: Optional[str] = None, + ) -> Any: + """Generate video from text prompt and optional image. 
+ + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_ path: Path to conditioning image + rendered_warp_images: Rendered warp images + rendered_warp_masks: Rendered warp masks + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + if type(image_path) == str: + log.info(f"Run with image path: {image_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + log.info(f"Run with prompt: {prompt}") + if not self.disable_guardrail: + log.info(f"Run guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical(f"Input {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt is not safe") + return None + log.info(f"Pass guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + else: + log.info("Not running guardrail") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_path, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> Any: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+ + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + condition_latent = self._run_tokenizer_encoding(image_or_video_path) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, rendered_warp_images, rendered_warp_masks, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + ) -> Any: + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + data_batch["condition_state"] = rendered_warp_images + data_batch["condition_state_mask"] = rendered_warp_masks + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=self.model.state_shape, + is_negative_prompt=True, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video diff --git a/cosmos_predict1/diffusion/inference/gen3c_single_image.py b/cosmos_predict1/diffusion/inference/gen3c_single_image.py new file mode 100644 index 0000000000000000000000000000000000000000..53b215c34834189975a2576155b37915151977d5 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_single_image.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
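+"""Single-image Gen3C inference: depth is estimated with MoGe, a 3D cache is built
+from the input image, and a video is generated along a preset camera trajectory.
+
+Illustrative invocation (a sketch; paths are placeholders, adjust to your setup):
+
+    python cosmos_predict1/diffusion/inference/gen3c_single_image.py \
+        --checkpoint_dir checkpoints \
+        --input_image_path assets/example_image.png \
+        --trajectory left \
+        --num_video_frames 121 \
+        --video_save_folder outputs/ \
+        --save_buffer
+"""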
+ +import argparse +import os +import cv2 +from moge.model.v1 import MoGeModel +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + validate_args, +) +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer +from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory +import torch.nn.functional as F +torch.enable_grad(False) + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) # TODO: do we need this? + parser.add_argument( + "--input_image_path", + type=str, + help="Input image path for generating a single video", + ) + parser.add_argument( + "--trajectory", + type=str, + choices=[ + "left", + "right", + "up", + "down", + "zoom_in", + "zoom_out", + "clockwise", + "counterclockwise", + "none", + ], + default="left", + help="Select a trajectory type from the available options (default: original)", + ) + parser.add_argument( + "--camera_rotation", + type=str, + choices=["center_facing", "no_rotation", "trajectory_aligned"], + default="center_facing", + help="Controls camera rotation during movement: center_facing (rotate to look at center), no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)", + ) + parser.add_argument( + "--movement_distance", + type=float, + default=0.3, + help="Distance of the camera from the center of the scene", + ) + parser.add_argument( + "--noise_aug_strength", + type=float, + default=0.0, + help="Strength of noise augmentation on warped frames", + ) + parser.add_argument( + "--save_buffer", + action="store_true", + help="If set, save the warped images (buffer) side by side with the output video.", + ) + parser.add_argument( + "--filter_points_threshold", + type=float, + default=0.05, + help="If set, filter the points continuity of the warped images.", + ) + parser.add_argument( + "--foreground_masking", + action="store_true", + help="If set, use foreground masking for the warped images.", + ) + return parser + +def parse_arguments() -> argparse.Namespace: + parser = create_parser() + return parser.parse_args() + + +def validate_args(args): + assert args.num_video_frames is not None, "num_video_frames must be provided" + assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)" + +def _predict_moge_depth(current_image_path: str | np.ndarray, + target_h: int, target_w: int, + device: torch.device, moge_model: MoGeModel): + """Handles MoGe depth prediction for a single image. + + If the image is directly provided as a NumPy array, it should have shape [H, W, C], + where the channels are RGB and the pixel values are in [0..255]. 
+ """ + + if isinstance(current_image_path, str): + input_image_bgr = cv2.imread(current_image_path) + if input_image_bgr is None: + raise FileNotFoundError(f"Input image not found: {current_image_path}") + input_image_rgb = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB) + else: + input_image_rgb = current_image_path + del current_image_path + + depth_pred_h, depth_pred_w = 720, 1280 + + input_image_for_depth_resized = cv2.resize(input_image_rgb, (depth_pred_w, depth_pred_h)) + input_image_for_depth_tensor_chw = torch.tensor(input_image_for_depth_resized / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1) + moge_output_full = moge_model.infer(input_image_for_depth_tensor_chw) + moge_depth_hw_full = moge_output_full["depth"] + moge_intrinsics_33_full_normalized = moge_output_full["intrinsics"] + moge_mask_hw_full = moge_output_full["mask"] + + moge_depth_hw_full = torch.where(moge_mask_hw_full==0, torch.tensor(1000.0, device=moge_depth_hw_full.device), moge_depth_hw_full) + moge_intrinsics_33_full_pixel = moge_intrinsics_33_full_normalized.clone() + moge_intrinsics_33_full_pixel[0, 0] *= depth_pred_w + moge_intrinsics_33_full_pixel[1, 1] *= depth_pred_h + moge_intrinsics_33_full_pixel[0, 2] *= depth_pred_w + moge_intrinsics_33_full_pixel[1, 2] *= depth_pred_h + + # Calculate scaling factor for height + height_scale_factor = target_h / depth_pred_h + width_scale_factor = target_w / depth_pred_w + + # Resize depth map, mask, and image tensor + # Resizing depth: (H, W) -> (1, 1, H, W) for interpolate, then squeeze + moge_depth_hw = F.interpolate( + moge_depth_hw_full.unsqueeze(0).unsqueeze(0), + size=(target_h, target_w), + mode='bilinear', + align_corners=False + ).squeeze(0).squeeze(0) + + # Resizing mask: (H, W) -> (1, 1, H, W) for interpolate, then squeeze + moge_mask_hw = F.interpolate( + moge_mask_hw_full.unsqueeze(0).unsqueeze(0).to(torch.float32), + size=(target_h, target_w), + mode='nearest', # Using nearest neighbor for binary mask + ).squeeze(0).squeeze(0).to(torch.bool) + + # Resizing image tensor: (C, H, W) -> (1, C, H, W) for interpolate, then squeeze + input_image_tensor_chw_target_res = F.interpolate( + input_image_for_depth_tensor_chw.unsqueeze(0), + size=(target_h, target_w), + mode='bilinear', + align_corners=False + ).squeeze(0) + + moge_image_b1chw_float = input_image_tensor_chw_target_res.unsqueeze(0).unsqueeze(1) * 2 - 1 + + moge_intrinsics_33 = moge_intrinsics_33_full_pixel.clone() + # Adjust intrinsics for resized height + moge_intrinsics_33[1, 1] *= height_scale_factor # fy + moge_intrinsics_33[1, 2] *= height_scale_factor # cy + moge_intrinsics_33[0, 0] *= width_scale_factor # fx + moge_intrinsics_33[0, 2] *= width_scale_factor # cx + + moge_depth_b11hw = moge_depth_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) + moge_depth_b11hw = torch.nan_to_num(moge_depth_b11hw, nan=1e4) + moge_depth_b11hw = torch.clamp(moge_depth_b11hw, min=0, max=1e4) + moge_mask_b11hw = moge_mask_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # Prepare initial intrinsics [B, 1, 3, 3] + moge_intrinsics_b133 = moge_intrinsics_33.unsqueeze(0).unsqueeze(0) + initial_w2c_44 = torch.eye(4, dtype=torch.float32, device=device) + moge_initial_w2c_b144 = initial_w2c_44.unsqueeze(0).unsqueeze(0) + + return ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) + +def _predict_moge_depth_from_tensor( + image_tensor_chw_0_1: torch.Tensor, # Shape (C, H_input, W_input), range [0,1] + moge_model: MoGeModel +): + """Handles MoGe depth 
prediction from an image tensor.""" + moge_output_full = moge_model.infer(image_tensor_chw_0_1) + moge_depth_hw_full = moge_output_full["depth"] # (moge_inf_h, moge_inf_w) + moge_mask_hw_full = moge_output_full["mask"] # (moge_inf_h, moge_inf_w) + + moge_depth_11hw = moge_depth_hw_full.unsqueeze(0).unsqueeze(0) + moge_depth_11hw = torch.nan_to_num(moge_depth_11hw, nan=1e4) + moge_depth_11hw = torch.clamp(moge_depth_11hw, min=0, max=1e4) + moge_mask_11hw = moge_mask_hw_full.unsqueeze(0).unsqueeze(0) + moge_depth_11hw = torch.where(moge_mask_11hw==0, torch.tensor(1000.0, device=moge_depth_11hw.device), moge_depth_11hw) + + return moge_depth_11hw, moge_mask_11hw + +def demo(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + disable_prompt_encoder=args.disable_prompt_encoder, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=121, + seed=args.seed, + ) + + frame_buffer_max = pipeline.model.frame_buffer_max + generator = torch.Generator(device=device).manual_seed(args.seed) + sample_n_frames = pipeline.model.chunk_size + + # Load the model and assign it to pipeline.model.model + pipeline.model.model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + + #if args.num_gpus > 1: + # Now pipeline.model.model should be the loaded MoGeModel instance + #pipeline.model.model.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: 
{args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}] + + os.makedirs(os.path.dirname(args.video_save_folder), exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_path = input_dict.get("visual_input", None) + if current_image_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_path, 1): + print(f"Input image {current_image_path} is not valid, skipping.") + continue + + # load image, predict depth and initialize 3D cache + ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) = _predict_moge_depth( + current_image_path, args.height, args.width, device, pipeline.model.model + ) + + cache = Cache3D_Buffer( + frame_buffer_max=frame_buffer_max, + generator=generator, + noise_aug_strength=args.noise_aug_strength, + input_image=moge_image_b1chw_float[:, 0].clone(), # [B, C, H, W] + input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W] + # input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W] + input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4] + input_intrinsics=moge_intrinsics_b133[:, 0],# [B, 3, 3] + filter_points_threshold=args.filter_points_threshold, + foreground_masking=args.foreground_masking, + ) + + initial_cam_w2c_for_traj = moge_initial_w2c_b144[0, 0] + initial_cam_intrinsics_for_traj = moge_intrinsics_b133[0, 0] + + # Generate camera trajectory using the new utility function + try: + generated_w2cs, generated_intrinsics = generate_camera_trajectory( + trajectory_type=args.trajectory, + initial_w2c=initial_cam_w2c_for_traj, + initial_intrinsics=initial_cam_intrinsics_for_traj, + num_frames=args.num_video_frames, + movement_distance=args.movement_distance, + camera_rotation=args.camera_rotation, + center_depth=1.0, + device=device.type, + ) + except (ValueError, NotImplementedError) as e: + log.critical(f"Failed to generate trajectory: {e}") + continue + + log.info(f"Generating 0 - {sample_n_frames} frames") + rendered_warp_images, rendered_warp_masks = cache.render_cache( + generated_w2cs[:, 0:sample_n_frames], + generated_intrinsics[:, 0:sample_n_frames], + ) + + all_rendered_warps = [] + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=current_image_path, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1) + for num_iter in range(1, num_ar_iterations): + start_frame_idx = num_iter * (sample_n_frames - 1) # Overlap by 1 frame + end_frame_idx = start_frame_idx + sample_n_frames + + log.info(f"Generating {start_frame_idx} - {end_frame_idx} frames") + + last_frame_hwc_0_255 = torch.tensor(video[-1], device=device) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + pred_depth, pred_mask = 
_predict_moge_depth_from_tensor( + pred_image_for_depth_chw_0_1, pipeline.model.model + ) + + cache.update_cache( + new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] + new_depth=pred_depth, # (1,1,H,W) + # new_mask=pred_mask, # (1,1,H,W) + new_w2c=generated_w2cs[:, start_frame_idx], + new_intrinsics=generated_intrinsics[:, start_frame_idx], + ) + current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx] + rendered_warp_images, rendered_warp_masks = cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + ) + + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu()) + + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, prompt = generated_output + video = np.concatenate([video, video_new[1:]], axis=0) + + # Final video processing + final_video_to_save = video + final_width = args.width + + if args.save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) + + if squeezed_warps: + n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. 
Final video width will be {final_width}") + else: + log.info("No warp buffers to save.") + + + video_save_path = os.path.join( + args.video_save_folder, + f"{i if args.batch_input_path else args.video_save_name}.mp4" + ) + + os.makedirs(os.path.dirname(video_save_path), exist_ok=True) + + # Save video + save_video( + video=final_video_to_save, + fps=args.fps, + H=args.height, + W=final_width, + video_save_quality=5, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + if args.prompt is None: + args.prompt = "" + args.disable_guardrail = True + args.disable_prompt_upsampler = True + demo(args) diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d470d492a61330194901abaad2720130b650560 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/inference_utils.py @@ -0,0 +1,941 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +from contextlib import contextmanager +from typing import List, NamedTuple, Optional, Tuple + +import einops +import imageio +import numpy as np +import omegaconf.errors +import torch +import torchvision.transforms.functional as transforms_F +from omegaconf import OmegaConf + +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.model.model_v2w_multiview import DiffusionMultiviewV2WModel +from cosmos_predict1.diffusion.model.model_world_interpolator import DiffusionWorldInterpolatorWModel +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.io import load_from_fileobj + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + +DEFAULT_AUGMENT_SIGMA = 0.001 + + +def add_common_arguments(parser): + """Add common command line arguments for text2world and video2world generation. 
+ + Args: + parser (ArgumentParser): Argument parser to add arguments to + + The arguments include: + - checkpoint_dir: Base directory containing model weights + - tokenizer_dir: Directory containing tokenizer weights + - video_save_name: Output video filename for single video generation + - video_save_folder: Output directory for batch video generation + - prompt: Text prompt for single video generation + - batch_input_path: Path to JSONL file with input prompts for batch video generation + - negative_prompt: Text prompt describing undesired attributes + - num_steps: Number of diffusion sampling steps + - guidance: Classifier-free guidance scale + - num_video_frames: Number of frames to generate + - height/width: Output video dimensions + - fps: Output video frame rate + - seed: Random seed for reproducibility + - Various model offloading flags + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--tokenizer_dir", + type=str, + default="Cosmos-Tokenize1-CV8x8x8-720p", + help="Tokenizer weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument( + "--video_save_folder", + type=str, + default="outputs/", + help="Output folder for generating a batch of videos", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Path to a JSONL file of input prompts for generating a batch of videos", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special " + "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and " + "flickering. 
Overall, the video is of poor quality.", + help="Negative prompt for the video", + ) + parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps") + parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value") + parser.add_argument( + "--num_video_frames", + type=int, + default=121, + # choices=[8 * n + 1 for n in range(16)] + [10, 117], + help="Number of video frames to sample", + ) + parser.add_argument("--height", type=int, default=704, help="Height of video to sample") + parser.add_argument("--width", type=int, default=1280, help="Width of video to sample") + parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") + parser.add_argument( + "--disable_prompt_upsampler", + action="store_true", + help="Disable prompt upsampling", + ) + parser.add_argument( + "--offload_diffusion_transformer", + action="store_true", + help="Offload DiT after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload tokenizer after inference", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload text encoder model after inference", + ) + parser.add_argument( + "--offload_prompt_upsampler", + action="store_true", + help="Offload prompt upsampler after inference", + ) + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + parser.add_argument( + "--disable_guardrail", + action="store_true", + help="Disable guardrail models", + ) + parser.add_argument( + "--disable_prompt_encoder", + action="store_true", + help="Disable prompt encoder to save memory, returns dummy embeddings instead", + ) + + +# Function to fully remove an argument +def remove_argument(parser, arg_name): + # Get a list of actions to remove + actions_to_remove = [action for action in parser._actions if action.dest == arg_name] + + for action in actions_to_remove: + # Remove action from parser._actions + parser._actions.remove(action) + + # Remove option strings + for option_string in action.option_strings: + parser._option_string_actions.pop(option_string, None) + + +def validate_args(args: argparse.Namespace, inference_type: str) -> None: + """Validate command line arguments for text2world and video2world generation.""" + assert inference_type in [ + "text2world", + "video2world", + "world_interpolator", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + # Validate prompt/image/video args for single or batch generation + if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler): + assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided." + if (inference_type == "video2world" or inference_type == "world_interpolator") and not args.batch_input_path: + assert ( + args.input_image_or_video_path + ), "--input_image_or_video_path must be provided for single video generation." 
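+# Illustrative wiring of the helpers above (a sketch, not part of the public API):
+#
+#     parser = argparse.ArgumentParser(description="text2world generation")
+#     add_common_arguments(parser)
+#     args = parser.parse_args(["--prompt", "a scenic mountain lake"])
+#     validate_args(args, inference_type="text2world")
+#
+# `remove_argument(parser, "num_video_frames")` lets a caller drop one of the
+# common options again before parsing.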
+ + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + """Load a model checkpoint with non-strict matching, handling shape mismatches. + + Args: + model (torch.nn.Module): Model to load weights into + checkpoint_state_dict (dict): State dict from checkpoint + + Returns: + _IncompatibleKeys: Named tuple containing: + - missing_keys: Keys present in model but missing from checkpoint + - unexpected_keys: Keys present in checkpoint but not in model + - incorrect_shapes: Keys with mismatched tensor shapes + + The function handles special cases like: + - Uninitialized parameters + - Quantization observers + - TransformerEngine FP8 states + """ + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. 
+ continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) + + +@contextmanager +def skip_init_linear(): + # skip init of nn.Linear + orig_reset_parameters = torch.nn.Linear.reset_parameters + torch.nn.Linear.reset_parameters = lambda x: x + xavier_uniform_ = torch.nn.init.xavier_uniform_ + torch.nn.init.xavier_uniform_ = lambda x: x + yield + torch.nn.Linear.reset_parameters = orig_reset_parameters + torch.nn.init.xavier_uniform_ = xavier_uniform_ + + +def load_model_by_config( + config_job_name, + config_file="projects/cosmos_video/config/config.py", + model_class=DiffusionT2WModel, +): + config_module = get_config_module(config_file) + config = importlib.import_module(config_module).make_config() + + config = override(config, ["--", f"experiment={config_job_name}"]) + + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + # Initialize model + with skip_init_linear(): + model = model_class(config.model) + return model + + +def load_network_model(model: DiffusionT2WModel, ckpt_path: str): + with skip_init_linear(): + model.set_up_model() + try: + net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + except Exception: + # Posttrained models can be loaded with weights_only=False + net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if "model" in net_state_dict: + model_state_dict = net_state_dict["model"] + if "ema" in net_state_dict and model.config.peft_control and model.config.peft_control.enabled: + ema_state_dict = net_state_dict["ema"] + # Convert ema state_dict to model state_dict by replacing "-" with "." + ema_state_dict = {k.replace("-", "."): v for k, v in ema_state_dict.items()} + model_state_dict.update(ema_state_dict) + net_state_dict = model_state_dict + else: + net_state_dict = model_state_dict + + log.debug(non_strict_load_model(model.model, net_state_dict)) + model.cuda() + + +def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str): + with skip_init_linear(): + model.set_up_tokenizer(tokenizer_dir) + model.cuda() + + +def prepare_data_batch( + height: int, + width: int, + num_frames: int, + fps: int, + prompt_embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, +): + """Prepare input batch tensors for video generation. 
+ + Args: + height (int): Height of video frames + width (int): Width of video frames + num_frames (int): Number of frames to generate + fps (int): Frames per second + prompt_embedding (torch.Tensor): Encoded text prompt embeddings + negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings + + Returns: + dict: Batch dictionary containing: + - video: Zero tensor of target video shape + - t5_text_mask: Attention mask for text embeddings + - image_size: Target frame dimensions + - fps: Target frame rate + - num_frames: Number of frames + - padding_mask: Frame padding mask + - t5_text_embeddings: Prompt embeddings + - neg_t5_text_embeddings: Negative prompt embeddings (if provided) + - neg_t5_text_mask: Mask for negative embeddings (if provided) + """ + # Create base data batch + data_batch = { + "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(), + "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), + } + + # Handle text embeddings + + t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["t5_text_embeddings"] = t5_embed + + if negative_prompt_embedding is not None: + neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["neg_t5_text_embeddings"] = neg_t5_embed + data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() + + return data_batch + + +def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames): + """Prepare complete input batch for video generation including latent dimensions. 
+ + Args: + model: Diffusion model instance + prompt_embedding (torch.Tensor): Text prompt embeddings + negative_prompt_embedding (torch.Tensor): Negative prompt embeddings + height (int): Output video height + width (int): Output video width + fps (int): Output video frame rate + num_video_frames (int): Number of frames to generate + + Returns: + tuple: + - data_batch (dict): Complete model input batch + - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression + """ + raw_video_batch = prepare_data_batch( + height=height, + width=width, + num_frames=num_video_frames, + fps=fps, + prompt_embedding=prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + ) + try: + condition_location = model.config.conditioner.video_cond_bool.condition_location + except omegaconf.errors.ConfigAttributeError: + condition_location = None + + # Use condition_location in your logic + if condition_location == "first_and_last_1": + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(num_video_frames - 1) + 1, # +1 for the last frame + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + else: + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(num_video_frames), + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + + return raw_video_batch, state_shape + + +def get_video_batch_for_multiview_model( + model, prompt_embedding, height, width, fps, num_video_frames, frame_repeat_negative_condition +): + """Prepare complete input batch for video generation including latent dimensions. + + Args: + model: Diffusion model instance + prompt_embedding (torch.Tensor): Text prompt embeddings + height (int): Output video height + width (int): Output video width + fps (int): Output video frame rate + num_video_frames (int): Number of frames to generate + frame_repeat_negative_condition (int): Number of frames to generate + + Returns: + tuple: + - data_batch (dict): Complete model input batch + - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression + """ + n_views = len(prompt_embedding) + prompt_embedding = einops.rearrange(torch.cat(prompt_embedding), "n t d -> (n t) d").unsqueeze(0) + raw_video_batch = prepare_data_batch( + height=height, + width=width, + num_frames=num_video_frames, + fps=fps, + prompt_embedding=prompt_embedding, + ) + if frame_repeat_negative_condition != -1: + frame_repeat = torch.zeros(n_views) + frame_repeat[-1] = frame_repeat_negative_condition + frame_repeat[-2] = frame_repeat_negative_condition + raw_video_batch["frame_repeat"] = frame_repeat.unsqueeze(0).to(dtype=torch.bfloat16).cuda() + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(int(num_video_frames / n_views)) * n_views, + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + return raw_video_batch, state_shape + + +def generate_world_from_text( + model: DiffusionT2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, +): + """Generate video from text prompt using diffusion model. 
+ + Args: + model (DiffusionT2WModel): Text-to-video diffusion model + state_shape (list[int]): Latent state dimensions [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Model input batch with embeddings + guidance (float): Classifier-free guidance scale + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for reproducibility + + Returns: + np.ndarray: Generated video frames [T,H,W,C], range [0,255] + + The function: + 1. Initializes random latent with maximum noise + 2. Performs guided diffusion sampling + 3. Decodes latents to pixel space + """ + + # Generate video + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + ) + + return sample + + +def generate_world_from_video( + model: DiffusionV2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, + condition_latent: torch.Tensor, + num_input_frames: int, +) -> Tuple[np.array, list, list]: + """Generate video using a conditioning video/image input. + + Args: + model (DiffusionV2WModel): The diffusion model instance + state_shape (list[int]): Shape of the latent state [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Batch containing model inputs including text embeddings + guidance (float): Classifier-free guidance scale for sampling + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for generation + condition_latent (torch.Tensor): Latent tensor from conditioning video/image file + num_input_frames (int): Number of input frames + + Returns: + np.array: Generated video frames in shape [T,H,W,C], range [0,255] + """ + assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported" + augment_sigma = DEFAULT_AUGMENT_SIGMA + + if condition_latent.shape[2] < state_shape[1]: + # Padding condition latent to state shape + b, c, t, h, w = condition_latent.shape + condition_latent = torch.cat( + [ + condition_latent, + condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), + ], + dim=2, + ).contiguous() + num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) + + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + condition_latent=condition_latent, + num_condition_t=num_of_latent_condition, + condition_augment_sigma=augment_sigma, + ) + return sample + + +def read_video_or_image_into_frames_BCTHW( + input_path: str, + input_path_format: str = "mp4", + H: int = None, + W: int = None, + normalize: bool = True, + max_frames: int = -1, + also_return_fps: bool = False, +) -> torch.Tensor: + """Read video or image file and convert to tensor format. 
+ + Args: + input_path (str): Path to input video/image file + input_path_format (str): Format of input file (default: "mp4") + H (int, optional): Height to resize frames to + W (int, optional): Width to resize frames to + normalize (bool): Whether to normalize pixel values to [-1,1] (default: True) + max_frames (int): Maximum number of frames to read (-1 for all frames) + also_return_fps (bool): Whether to return fps along with frames + + Returns: + torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested + """ + log.debug(f"Reading video from {input_path}") + + loaded_data = load_from_fileobj(input_path, format=input_path_format) + frames, meta_data = loaded_data + if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): + frames = np.array(frames[0]) # HWC, [0,255] + if frames.shape[-1] > 3: # RGBA, set the transparent to white + # Separate the RGB and Alpha channels + rgb_channels = frames[..., :3] + alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] + + # Create a white background + white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB + + # Blend the RGB channels with the white background based on the alpha channel + frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( + np.uint8 + ) + frames = [frames] + fps = 0 + else: + fps = int(meta_data.get("fps")) + if max_frames != -1: + frames = frames[:max_frames] + input_tensor = np.stack(frames, axis=0) + input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") + if normalize: + input_tensor = input_tensor / 128.0 - 1.0 + input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW + log.debug(f"Raw data shape: {input_tensor.shape}") + if H is not None and W is not None: + input_tensor = transforms_F.resize( + input_tensor, + size=(H, W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) + if normalize: + input_tensor = input_tensor.to("cuda") + log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") + if also_return_fps: + return input_tensor, fps + return input_tensor + + +def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int: + """This function computes the number of latent frames given the number of input frames. 
+ Args: + model (DiffusionV2WModel): video generation model + num_input_frames (int): number of input frames + downsample_factor (int): downsample factor for temporal reduce + Returns: + int: number of latent frames + """ + # First find how many vae chunks are contained with in num_input_frames + num_latent_frames = ( + num_input_frames + // model.tokenizer.video_vae.pixel_chunk_duration + * model.tokenizer.video_vae.latent_chunk_duration + ) + # Then handle the remainder + if num_input_frames % model.tokenizer.video_vae.latent_chunk_duration == 1: + num_latent_frames += 1 + elif num_input_frames % model.tokenizer.video_vae.latent_chunk_duration > 1: + assert ( + num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 + ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}" + num_latent_frames += ( + 1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor + ) + + return num_latent_frames + + +def create_condition_latent_from_input_frames( + model: DiffusionV2WModel, + input_frames: torch.Tensor, + num_frames_condition: int = 25, +): + """Create condition latent for video generation from input frames. + + Takes the last num_frames_condition frames from input as conditioning. + + Args: + model (DiffusionV2WModel): Video generation model + input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1] + num_frames_condition (int): Number of frames to use for conditioning + + Returns: + tuple: (condition_latent, encode_input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - encode_input_frames (torch.Tensor): Padded input frames used for encoding + """ + B, C, T, H, W = input_frames.shape + num_frames_encode = ( + model.tokenizer.pixel_chunk_duration + ) # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1 + log.debug( + f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" + ) + + log.debug( + f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" + ) + + assert ( + input_frames.shape[2] >= num_frames_condition + ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}" + assert ( + num_frames_encode >= num_frames_condition + ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}" + + # Put the conditioal frames to the begining of the video, and pad the end with zero + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + condition_frames_first = input_frames[:, :, :num_frames_condition] + condition_frames_last = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) + else: + condition_frames = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) + + log.debug( + f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" + ) + if 
hasattr(model, "n_views"): + encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW + latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) + latent = torch.cat([latent1, latent2], dim=2) # BCTHW + else: + latent = model.encode(encode_input_frames) + return latent, encode_input_frames + + +def compute_num_frames_condition(model: DiffusionV2WModel, num_of_latent_overlap: int, downsample_factor=8) -> int: + """This function computes the number of condition pixel frames given the number of latent frames to overlap. + Args: + model (ExtendDiffusionModel): video generation model + num_of_latent_overlap (int): number of latent frames to overlap + downsample_factor (int): downsample factor for temporal reduce + Returns: + int: number of condition frames in output space + """ + if getattr(model.tokenizer.video_vae, "is_casual", True): + # For casual model + num_frames_condition = ( + num_of_latent_overlap + // model.tokenizer.video_vae.latent_chunk_duration + * model.tokenizer.video_vae.pixel_chunk_duration + ) + if num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration == 1: + num_frames_condition += 1 + elif num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration > 1: + num_frames_condition += ( + 1 + (num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration - 1) * downsample_factor + ) + else: + num_frames_condition = num_of_latent_overlap * downsample_factor + + return num_frames_condition + + +def get_condition_latent( + model: DiffusionV2WModel, + input_image_or_video_path: str, + num_input_frames: int = 1, + state_shape: list[int] = None, + frame_index: int = 0, + frame_stride: int = 1, +): + """Get condition latent from input image/video file. 
+ + Args: + model (DiffusionV2WModel): Video generation model + input_image_or_video_path (str): Path to conditioning image/video + num_input_frames (int): Number of input frames for video2world prediction + + Returns: + tuple: (condition_latent, input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] + """ + if state_shape is None: + state_shape = model.state_shape + assert num_input_frames > 0, "num_input_frames must be greater than 0" + + H, W = ( + state_shape[-2] * model.tokenizer.spatial_compression_factor, + state_shape[-1] * model.tokenizer.spatial_compression_factor, + ) + if type(input_image_or_video_path) == str: + input_path_format = input_image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + input_image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + else: + input_frames = input_image_or_video_path + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + start_frame = frame_index * frame_stride + end_frame = (frame_index + 1) * frame_stride + curr_input_frames = torch.cat( + [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 + ).contiguous() # BCTHW + num_of_latent_condition = 1 + num_frames_condition = compute_num_frames_condition( + model, num_of_latent_condition, downsample_factor=model.tokenizer.temporal_compression_factor + ) + + condition_latent, _ = create_condition_latent_from_input_frames(model, curr_input_frames, num_frames_condition) + condition_latent = condition_latent.to(torch.bfloat16) + return condition_latent + + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames) + condition_latent = condition_latent.to(torch.bfloat16) + + return condition_latent + + +def get_condition_latent_multiview( + model: DiffusionMultiviewV2WModel, + input_image_or_video_path: str, + num_input_frames: int = 1, + state_shape: list[int] = None, +): + """Get condition latent from input image/video file. This is the function for the multi-view model where each view has one latent condition frame. 
+ + Args: + model (DiffusionMultiviewV2WModel): Video generation model + input_image_or_video_path (str): Path to conditioning image/video + num_input_frames (int): Number of input frames for video2world prediction + + Returns: + tuple: (condition_latent, input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] + """ + if state_shape is None: + state_shape = model.state_shape + assert num_input_frames > 0, "num_input_frames must be greater than 0" + + H, W = ( + state_shape[-2] * model.tokenizer.spatial_compression_factor, + state_shape[-1] * model.tokenizer.spatial_compression_factor, + ) + input_path_format = input_image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + input_image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + input_frames = einops.rearrange(input_frames, "B C (V T) H W -> (B V) C T H W", V=model.n_views) + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames) + condition_latent = condition_latent.to(torch.bfloat16) + + return condition_latent, einops.rearrange(input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)[0] + + +def check_input_frames(input_path: str, required_frames: int) -> bool: + """Check if input video/image has sufficient frames. + + Args: + input_path: Path to input video or image + required_frames: Number of required frames + + Returns: + bool: True if input has sufficient frames, False otherwise + """ + if input_path.endswith((".jpg", ".jpeg", ".png")): + if required_frames > 1: + log.error(f"Input ({input_path}) is an image but {required_frames} frames are required") + return False + return True # Let the pipeline handle image loading + # For video input + try: + vid = imageio.get_reader(input_path, "ffmpeg") + frame_count = vid.count_frames() + + if frame_count < required_frames: + log.error(f"Input video has {frame_count} frames but {required_frames} frames are required") + return False + else: + return True + except Exception as e: + log.error(f"Error reading video file {input_path}: {e}") + return False + + +def get_input_sizes(input_path: str) -> tuple[int, int]: + """Get the height and width of input video or image. + + Args: + input_path: Path to input video or image file + + Returns: + tuple: (height, width) dimensions of the input + """ + if input_path.endswith((".jpg", ".jpeg", ".png")): + # For image input + try: + img = imageio.imread(input_path) + return img.shape[0], img.shape[1] + except Exception as e: + log.error(f"Error reading image file {input_path}: {e}") + raise + else: + # For video input + try: + vid = imageio.get_reader(input_path, "ffmpeg") + first_frame = vid.get_data(0) + return first_frame.shape[0], first_frame.shape[1] + except Exception as e: + log.error(f"Error reading video file {input_path}: {e}") + raise diff --git a/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py b/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..e11589b0a9a3395db4b14293f3461a5572035d49 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import warp as wp +import numpy as np + +# Initialize Warp with CUDA +wp.init() + +@wp.kernel +def ray_triangle_intersection_kernel( + ray_origins: wp.array2d(dtype=wp.float32), # (H*W, 3) + ray_directions: wp.array2d(dtype=wp.float32), # (H*W, 3) + vertices: wp.array2d(dtype=wp.float32), # (N, 3) + faces: wp.array2d(dtype=wp.int32), # (M, 3) + depth_map: wp.array(dtype=wp.float32), # (H*W,) + num_triangles: wp.int32, + epsilon: wp.float32 +): + """ + Warp kernel for ray-triangle intersection using Möller–Trumbore algorithm. + Each thread processes one ray against all triangles. + """ + # Get thread index (ray index) + ray_idx = wp.tid() + + # Get ray origin and direction + ray_origin = wp.vec3( + ray_origins[ray_idx, 0], + ray_origins[ray_idx, 1], + ray_origins[ray_idx, 2] + ) + ray_dir = wp.vec3( + ray_directions[ray_idx, 0], + ray_directions[ray_idx, 1], + ray_directions[ray_idx, 2] + ) + + # Initialize minimum distance + min_t = wp.float32(1e10) + + # Iterate through all triangles + for tri_idx in range(num_triangles): + # Get triangle vertex indices + i0 = faces[tri_idx, 0] + i1 = faces[tri_idx, 1] + i2 = faces[tri_idx, 2] + + # Get triangle vertices + v0 = wp.vec3(vertices[i0, 0], vertices[i0, 1], vertices[i0, 2]) + v1 = wp.vec3(vertices[i1, 0], vertices[i1, 1], vertices[i1, 2]) + v2 = wp.vec3(vertices[i2, 0], vertices[i2, 1], vertices[i2, 2]) + + # Compute edges + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Möller–Trumbore algorithm + h = wp.cross(ray_dir, edge2) + a = wp.dot(edge1, h) + + # Check if ray is parallel to triangle + if wp.abs(a) < epsilon: + continue + + f = 1.0 / a + s = ray_origin - v0 + u = f * wp.dot(s, h) + + # Check if intersection is within triangle (u >= 0 and u <= 1) + if u < 0.0 or u > 1.0: + continue + + q = wp.cross(s, edge1) + v = f * wp.dot(ray_dir, q) + + # Check if intersection is within triangle (v >= 0 and u + v <= 1) + if v < 0.0 or (u + v) > 1.0: + continue + + # Compute t (distance along ray) + t = f * wp.dot(edge2, q) + + # Only consider intersections in front of camera (t > 0) + if t > epsilon and t < min_t: + min_t = t + + # Write result + if min_t < 1e10: + depth_map[ray_idx] = min_t + else: + depth_map[ray_idx] = 0.0 + + +@wp.kernel +def ray_triangle_intersection_tiled_kernel( + ray_origins: wp.array2d(dtype=wp.float32), # (H*W, 3) + ray_directions: wp.array2d(dtype=wp.float32), # (H*W, 3) + vertices: wp.array2d(dtype=wp.float32), # (N, 3) + faces: wp.array2d(dtype=wp.int32), # (M, 3) + depth_map: wp.array(dtype=wp.float32), # (H*W,) + tri_start: wp.int32, # Start triangle index for this tile + tri_end: wp.int32, # End triangle index for this tile + epsilon: wp.float32 +): + """ + Tiled version of ray-triangle intersection kernel. + Processes a subset of triangles to improve memory access patterns. 
+ """ + # Get thread index (ray index) + ray_idx = wp.tid() + + # Get ray origin and direction + ray_origin = wp.vec3( + ray_origins[ray_idx, 0], + ray_origins[ray_idx, 1], + ray_origins[ray_idx, 2] + ) + ray_dir = wp.vec3( + ray_directions[ray_idx, 0], + ray_directions[ray_idx, 1], + ray_directions[ray_idx, 2] + ) + + # Get current minimum distance + min_t = depth_map[ray_idx] + if min_t == 0.0: + min_t = wp.float32(1e10) + + # Process triangles in this tile + for tri_idx in range(tri_start, tri_end): + # Get triangle vertex indices + i0 = faces[tri_idx, 0] + i1 = faces[tri_idx, 1] + i2 = faces[tri_idx, 2] + + # Get triangle vertices + v0 = wp.vec3(vertices[i0, 0], vertices[i0, 1], vertices[i0, 2]) + v1 = wp.vec3(vertices[i1, 0], vertices[i1, 1], vertices[i1, 2]) + v2 = wp.vec3(vertices[i2, 0], vertices[i2, 1], vertices[i2, 2]) + + # Compute edges + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Möller–Trumbore algorithm + h = wp.cross(ray_dir, edge2) + a = wp.dot(edge1, h) + + # Check if ray is parallel to triangle + if wp.abs(a) < epsilon: + continue + + f = 1.0 / a + s = ray_origin - v0 + u = f * wp.dot(s, h) + + # Check if intersection is within triangle (u >= 0 and u <= 1) + if u < 0.0 or u > 1.0: + continue + + q = wp.cross(s, edge1) + v = f * wp.dot(ray_dir, q) + + # Check if intersection is within triangle (v >= 0 and u + v <= 1) + if v < 0.0 or (u + v) > 1.0: + continue + + # Compute t (distance along ray) + t = f * wp.dot(edge2, q) + + # Only consider intersections in front of camera (t > 0) + if t > epsilon and t < min_t: + min_t = t + + # Write result using atomic min to handle concurrent updates + if min_t < 1e10: + wp.atomic_min(depth_map, ray_idx, min_t) + + +def ray_triangle_intersection_warp( + ray_origins: torch.Tensor, # (H, W, 3) + ray_directions: torch.Tensor, # (H, W, 3) + vertices: torch.Tensor, # (N, 3) + faces: torch.Tensor, # (M, 3) + device: torch.device +) -> torch.Tensor: + """ + Compute ray-triangle intersections using NVIDIA Warp for maximum GPU acceleration. + + This implementation uses Warp kernels to achieve the best possible performance + on NVIDIA GPUs by: + 1. Using native CUDA kernels through Warp + 2. Tiling triangles for better memory access patterns + 3. Using atomic operations for concurrent updates + 4. 
Minimizing memory transfers + + Args: + ray_origins: (H, W, 3) ray origins in camera space + ray_directions: (H, W, 3) ray directions (should be normalized) + vertices: (N, 3) mesh vertices + faces: (M, 3) triangle face indices + device: torch device (must be CUDA) + + Returns: + depth_map: (H, W) depth values, 0 where no intersection + """ + H, W = ray_origins.shape[:2] + num_rays = H * W + num_triangles = faces.shape[0] + + # Reshape rays to 2D arrays + ray_origins_flat = ray_origins.reshape(-1, 3).contiguous() + ray_directions_flat = ray_directions.reshape(-1, 3).contiguous() + + # Convert PyTorch tensors to Warp arrays (as float arrays, not vec3) + wp_ray_origins = wp.from_torch(ray_origins_flat, dtype=wp.float32) + wp_ray_directions = wp.from_torch(ray_directions_flat, dtype=wp.float32) + wp_vertices = wp.from_torch(vertices.contiguous(), dtype=wp.float32) + wp_faces = wp.from_torch(faces.int().contiguous(), dtype=wp.int32) + + # Create output depth map + depth_map_flat = torch.zeros(num_rays, device=device, dtype=torch.float32) + wp_depth_map = wp.from_torch(depth_map_flat, dtype=wp.float32) + + # Choose implementation based on problem size + if num_triangles < 10000: + # For smaller meshes, use simple kernel + wp.launch( + kernel=ray_triangle_intersection_kernel, + dim=num_rays, + inputs=[ + wp_ray_origins, + wp_ray_directions, + wp_vertices, + wp_faces, + wp_depth_map, + num_triangles, + 1e-8 # epsilon + ], + device=f"cuda:{device.index}" if device.index is not None else "cuda:0" + ) + else: + # For larger meshes, use tiled approach for better memory access + triangle_tile_size = 10000 # Process triangles in tiles + + # Initialize depth map to infinity + depth_map_flat.fill_(float('inf')) + + # Process triangles in tiles + for tri_start in range(0, num_triangles, triangle_tile_size): + tri_end = min(tri_start + triangle_tile_size, num_triangles) + + wp.launch( + kernel=ray_triangle_intersection_tiled_kernel, + dim=num_rays, + inputs=[ + wp_ray_origins, + wp_ray_directions, + wp_vertices, + wp_faces, + wp_depth_map, + tri_start, + tri_end, + 1e-8 # epsilon + ], + device=f"cuda:{device.index}" if device.index is not None else "cuda:0" + ) + + # Convert infinity back to 0 + depth_map_flat[depth_map_flat == float('inf')] = 0.0 + + # Synchronize to ensure kernel completion + wp.synchronize() + + # Reshape back to 2D + depth_map = depth_map_flat.reshape(H, W) + + return depth_map diff --git a/cosmos_predict1/diffusion/inference/text2world.py b/cosmos_predict1/diffusion/inference/text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac81213e0fea581f69de3bf5008820b718e45c0 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/text2world.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
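A minimal usage sketch for the ray_triangle_intersection_warp helper defined in ray_triangle_intersection_warp.py above (editor's illustration, not part of the diff): it assumes a CUDA device, the warp-lang package, and a single placeholder triangle; tensor shapes follow the function's docstring.

# Editor's sketch: call ray_triangle_intersection_warp on toy geometry.
import torch

device = torch.device("cuda:0")
H, W = 8, 8

# Rays start at z = 2 and point down the -z axis (illustrative values only).
ys, xs = torch.meshgrid(
    torch.linspace(-0.5, 0.5, H), torch.linspace(-0.5, 0.5, W), indexing="ij"
)
ray_origins = torch.stack([xs, ys, torch.full_like(xs, 2.0)], dim=-1).to(device)
ray_directions = torch.zeros(H, W, 3, device=device)
ray_directions[..., 2] = -1.0

vertices = torch.tensor(
    [[-1.0, -1.0, 0.0], [1.0, -1.0, 0.0], [0.0, 1.0, 0.0]], device=device
)
faces = torch.tensor([[0, 1, 2]], device=device, dtype=torch.int32)

depth = ray_triangle_intersection_warp(ray_origins, ray_directions, vertices, faces, device)
# depth has shape (H, W); hits hold the distance along the ray (2.0 here), misses are 0.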
+ +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Text to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add text2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Text2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Text2World", + "Cosmos-Predict1-14B-Text2World", + "Cosmos-Predict1-7B-Text2World_post-trained", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_80gb", + "Cosmos-Predict1-7B-Text2World_post-trained-8gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-lora", + "Cosmos-Predict1-14B-Text2World_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-UpsamplePrompt1-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + + parser.add_argument( + "--word_limit_to_skip_upsampler", + type=int, + default=250, + help="Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value", + ) + + return parser.parse_args() + + +def demo(args): + """Run text-to-world generation demo. + + This function handles the main text-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts from input + - Generating videos from text prompts + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + misc.set_random_seed(args.seed) + inference_type = "text2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize text2world generation model pipeline + pipeline = DiffusionText2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + seed=args.seed, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None: + log.critical("Prompt is missing, skipping world generation.") + continue + + # Generate video + generated_output = pipeline.generate(current_prompt, args.negative_prompt, args.word_limit_to_skip_upsampler) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + video, prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/text2world_multiview.py b/cosmos_predict1/diffusion/inference/text2world_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..717a7a9a4e267cc49d5ec9df31798f3121cff67f --- /dev/null +++ b/cosmos_predict1/diffusion/inference/text2world_multiview.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, remove_argument, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionText2WorldMultiviewGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Text to world generation demo script") + # Add common arguments + add_common_arguments(parser) + remove_argument(parser, "width") + remove_argument(parser, "height") + remove_argument(parser, "num_video_frames") + parser.add_argument("--height", type=int, default=480, help="Height of video to sample") + parser.add_argument("--width", type=int, default=848, help="Width of video to sample") + parser.add_argument( + "--num_video_frames", + type=int, + default=57, + choices=[57], + help="Number of video frames to sample, this is per-camera frame number.", + ) + # Add text2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview", + ], + ) + parser.add_argument( + "--prompt_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the left. ", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the right.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--prompt_back", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing backwards.", + help="Text prompt for generating rear camera view video", + ) + parser.add_argument( + "--prompt_back_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_back_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--frame_repeat_negative_condition", + type=float, + default=10.0, + help="frame_repeat number to be used as negative condition", + ) + + return parser.parse_args() + + +def demo(args): + """Run multi-view text-to-world generation demo. 
+ + This function handles the main text-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts from input + - Generating videos from text prompts + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "text2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize text2world generation model pipeline + pipeline = DiffusionText2WorldMultiviewGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + frame_repeat_negative_condition=args.frame_repeat_negative_condition, + seed=args.seed, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [ + { + "prompt": args.prompt, + "prompt_left": args.prompt_left, + "prompt_right": args.prompt_right, + "prompt_back": args.prompt_back, + "prompt_back_left": args.prompt_back_left, + "prompt_back_right": args.prompt_back_right, + } + ] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, current_prompt in enumerate(prompts): + # Generate video + generated_output = pipeline.generate(current_prompt) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + [video_grid, video], prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{i}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, 
+ W=args.width, + video_save_quality=10, + video_save_path=video_save_path, + ) + + save_video( + video=video_grid, + fps=args.fps, + H=args.height * 2, + W=args.width * 3, + video_save_quality=5, + video_save_path=video_grid_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + for key, value in prompt.items(): + f.write(value.encode("utf-8")) + f.write("\n".encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/video2world.py b/cosmos_predict1/diffusion/inference/video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..04acd69075831641a026174674d801299cd26191 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/video2world.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + get_input_sizes, + validate_args, +) +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Video2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Video2World", + "Cosmos-Predict1-14B-Video2World", + "Cosmos-Predict1-7B-Video2World_post-trained", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_80gb", + "Cosmos-Predict1-7B-Video2World_post-trained-8gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-lora", + "Cosmos-Predict1-14B-Video2World_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=1, + help="Number of input frames for video2world prediction", + choices=[1, 9], + ) + + return parser.parse_args() 
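An editor's sketch (not part of the diff) of driving video2world generation programmatically rather than from the shell, assuming parse_arguments and demo from this script are in scope; "--checkpoint_dir" and "--prompt" are taken to be the common flags registered by add_common_arguments, and the input path is a placeholder. Running demo() still requires the downloaded checkpoints and a GPU.

# Editor's sketch: invoke the video2world demo from Python.
import sys

sys.argv = [
    "video2world.py",
    "--checkpoint_dir", "checkpoints",                           # assumed common flag
    "--diffusion_transformer_dir", "Cosmos-Predict1-7B-Video2World",
    "--input_image_or_video_path", "assets/example_input.mp4",   # placeholder path
    "--num_input_frames", "9",
    "--disable_prompt_upsampler",
    "--prompt", "A robot arm stacks colored blocks on a workbench.",  # assumed common flag
]
args = parse_arguments()
demo(args)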
+ + +def demo(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = DiffusionVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + seed=args.seed, + num_input_frames=args.num_input_frames, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_or_video_path}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_or_video_path = input_dict.get("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + log.warning("Visual input is provided, overriding --height and --width arguments.") + args.height, args.width = get_input_sizes(current_image_or_video_path) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + 
image_or_video_path=current_image_or_video_path, + negative_prompt=args.negative_prompt, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/video2world_multiview.py b/cosmos_predict1/diffusion/inference/video2world_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..9bca0721e8a0c57476e73206dc6838aaaab3234b --- /dev/null +++ b/cosmos_predict1/diffusion/inference/video2world_multiview.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
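An editor's sketch (not part of the diff) of the per-sample dict the multiview demo below consumes: the single-prompt path builds one such dict per generation, and batch entries returned by read_prompts_from_file are iterated the same way, so they need the same keys (the on-disk batch format itself is not shown here). The five side and rear texts are the documented flag defaults; the front-view text and the path are placeholders.

# Editor's sketch: one multiview sample as consumed by demo() below.
sample = {
    "prompt": "<front-view description>",  # supplied via --prompt
    "prompt_left": "The video is captured from a camera mounted on a car. The camera is facing to the left. ",
    "prompt_right": "The video is captured from a camera mounted on a car. The camera is facing to the right.",
    "prompt_back": "The video is captured from a camera mounted on a car. The camera is facing backwards.",
    "prompt_back_left": "The video is captured from a camera mounted on a car. The camera is facing the rear left side.",
    "prompt_back_right": "The video is captured from a camera mounted on a car. The camera is facing the rear right side.",
    "visual_input": "assets/example_multiview_input.mp4",  # placeholder path
}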
+ +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + get_input_sizes, + remove_argument, + validate_args, +) +from cosmos_predict1.diffusion.inference.world_generation_pipeline import ( + DiffusionVideo2WorldMultiviewGenerationPipeline, +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + remove_argument(parser, "width") + remove_argument(parser, "height") + remove_argument(parser, "num_video_frames") + parser.add_argument("--height", type=int, default=480, help="Height of video to sample") + parser.add_argument("--width", type=int, default=848, help="Width of video to sample") + + parser.add_argument( + "--num_video_frames", + type=int, + default=57, + choices=[57], + help="Number of video frames to sample, this is per-camera frame number.", + ) + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview", + ], + ) + parser.add_argument( + "--prompt_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the left. ", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the right.", + help="Text prompt for generating right camera view video", + ) + + parser.add_argument( + "--prompt_back", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing backwards.", + help="Text prompt for generating rear camera view video", + ) + parser.add_argument( + "--prompt_back_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_back_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--frame_repeat_negative_condition", + type=float, + default=10.0, + help="frame_repeat number to be used as negative condition", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=1, + help="Number of input frames for video2world prediction", + choices=[1, 9], + ) + + return parser.parse_args() + + +def demo(args): + """Run multi-view video-to-world generation demo. 
+ + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = DiffusionVideo2WorldMultiviewGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + frame_repeat_negative_condition=args.frame_repeat_negative_condition, + seed=args.seed, + num_input_frames=args.num_input_frames, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [ + { + "prompt": args.prompt, + "prompt_left": args.prompt_left, + "prompt_right": args.prompt_right, + "prompt_back": args.prompt_back, + "prompt_back_left": args.prompt_back_left, + "prompt_back_right": args.prompt_back_right, + "visual_input": args.input_image_or_video_path, + } + ] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_image_or_video_path = input_dict.pop("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + current_prompt = input_dict + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + log.warning("Visual input is provided, overriding --height and --width arguments.") + args.height, args.width = get_input_sizes(current_image_or_video_path) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_or_video_path=current_image_or_video_path, + ) + if generated_output is None: + log.critical("Guardrail 
blocked video2world generation.") + continue + [video_grid, video], prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{i}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=10, + video_save_path=video_save_path, + ) + save_video( + video=video_grid, + fps=args.fps, + H=args.height * 2, + W=args.width * 3, + video_save_quality=5, + video_save_path=video_grid_save_path, + ) + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + for key, value in prompt.items(): + f.write(value.encode("utf-8")) + f.write("\n".encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e835cf02efecd566aeaec7b9b1d5c99eba55d89b --- /dev/null +++ b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py @@ -0,0 +1,1470 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
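# A minimal, standalone sketch of how the six camera views end up in the
# 2H x 3W "_grid.mp4" saved by the multiview demo above. The [1, 0, 2, 4, 3, 5]
# index order mirrors the one used in the multiview pipeline's
# _run_tokenizer_decoding below; mapping those indices to camera names assumes
# the view order matches the prompt order (front, left, right, back, back_left,
# back_right) passed to generate(). Shapes use the demo script's defaults.
# Illustrative only.
import einops
import numpy as np

views = [np.zeros((57, 480, 848, 3), dtype=np.uint8) for _ in range(6)]  # per-camera [T, H, W, C]
ordered = np.stack([views[i] for i in [1, 0, 2, 4, 3, 5]], axis=0)  # top row: left, front, right
grid = einops.rearrange(ordered, "(gh gw) t h w c -> t (gh h) (gw w) c", gh=2, gw=3)
assert grid.shape == (57, 480 * 2, 848 * 3, 3)  # matches H * 2, W * 3 in save_video above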
+ +import gc +import os +from typing import Any, Optional + +import einops +import numpy as np +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + generate_world_from_text, + generate_world_from_video, + get_condition_latent, + get_condition_latent_multiview, + get_video_batch, + get_video_batch_for_multiview_model, + load_model_by_config, + load_network_model, + load_tokenizer_model, + read_video_or_image_into_frames_BCTHW, +) +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.model.model_t2w_multiview import DiffusionMultiviewT2WModel +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.model.model_v2w_multiview import DiffusionMultiviewV2WModel +from cosmos_predict1.diffusion.model.model_world_interpolator import DiffusionWorldInterpolatorWModel +from cosmos_predict1.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import ( + create_prompt_upsampler, + run_chat_completion, +) +from cosmos_predict1.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import ( + create_vlm_prompt_upsampler, + prepare_dialog, +) +from cosmos_predict1.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import ( + run_chat_completion as run_chat_completion_vlm, +) +from cosmos_predict1.diffusion.training.utils.inference_long_video import generate_video_from_batch_with_loop +from cosmos_predict1.utils import log +from cosmos_predict1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline + +MODEL_NAME_DICT = { + # text2world + "Cosmos-Predict1-7B-Text2World": "Cosmos_Predict1_Text2World_7B", + "Cosmos-Predict1-14B-Text2World": "Cosmos_Predict1_Text2World_14B", + "Cosmos-Predict1-7B-Text2World_post-trained": "Cosmos_Predict1_Text2World_7B_Post_trained", + "Cosmos-Predict1-14B-Text2World_post-trained": "Cosmos_Predict1_Text2World_14B_Post_trained", + # text2world low-memory + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_80gb": "Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb", + "Cosmos-Predict1-7B-Text2World_post-trained-8gpu_40gb": "Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_40gb": "Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb", + # text2world lora + "Cosmos-Predict1-7B-Text2World_post-trained-lora": "Cosmos_Predict1_Text2World_7B_Post_trained_lora", + # video2world + "Cosmos-Predict1-7B-Video2World": "Cosmos_Predict1_Video2World_7B", + "Cosmos-Predict1-14B-Video2World": "Cosmos_Predict1_Video2World_14B", + "Cosmos-Predict1-7B-Video2World_post-trained": "Cosmos_Predict1_Video2World_7B_Post_trained", + "Cosmos-Predict1-14B-Video2World_post-trained": "Cosmos_Predict1_Video2World_14B_Post_trained", + # video2world low-memory + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_80gb": "Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb", + "Cosmos-Predict1-7B-Video2World_post-trained-8gpu_40gb": "Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_40gb": "Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb", + # video2world lora + "Cosmos-Predict1-7B-Video2World_post-trained-lora": "Cosmos_Predict1_Video2World_7B_Post_trained_lora", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview": "Cosmos_Predict1_Text2World_7B_Multiview", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview": "Cosmos_Predict1_Video2World_7B_Multiview", + "Cosmos-Predict1-7B-WorldInterpolator": 
"Cosmos_Predict1_WorldInterpolator_7B", + # Gen3c + "Gen3C-Cosmos-7B": "GEN3C_Cosmos_7B", +} + + +class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + disable_prompt_encoder: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + ): + """Initialize the diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + disable_prompt_encoder: Whether to disable prompt encoder + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + """ + assert inference_type in [ + "text2world", + "video2world", + "world_interpolator", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + self.model_name = MODEL_NAME_DICT[checkpoint_name] + self.guidance = guidance + self.num_steps = num_steps + self.height = height + self.width = width + self.fps = fps + self.num_video_frames = num_video_frames + self.seed = seed + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + disable_prompt_encoder=disable_prompt_encoder, + ) + self.prompt_upsampler_dir = prompt_upsampler_dir + self.enable_prompt_upsampler = enable_prompt_upsampler + self.offload_prompt_upsampler = offload_prompt_upsampler + + self.prompt_upsampler = None + if enable_prompt_upsampler and not offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionT2WModel, + ) + + def 
_load_network(self): + load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}/model.pt") + + def _load_tokenizer(self): + load_tokenizer_model(self.model, f"{self.checkpoint_dir}/Cosmos-Tokenize1-CV8x8x8-720p") + + def _offload_prompt_upsampler_model(self): + """Move prompt enhancement model to CPU/disk. + + Offloads prompt upsampling model after processing input + to reduce GPU memory usage. + """ + if self.prompt_upsampler: + del self.prompt_upsampler + self.prompt_upsampler = None + gc.collect() + torch.cuda.empty_cache() + + def _run_prompt_upsampler_on_prompt(self, prompt: str) -> str: + """Enhance the input prompt using the prompt upsampler model. + + Args: + prompt: Raw text prompt to be enhanced + + Returns: + str: Enhanced version of the input prompt with more descriptive details + """ + upsampled_prompt = run_chat_completion(self.prompt_upsampler, prompt) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _run_prompt_upsampler_on_prompt_with_offload(self, *args: Any, **kwargs: Any) -> str: + """Enhance prompt with prompt upsampler model. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Enhanced prompt string + """ + if self.offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + enhanced_prompt = self._run_prompt_upsampler_on_prompt(*args, **kwargs) + + if self.offload_prompt_upsampler: + self._offload_prompt_upsampler_model() + + return enhanced_prompt + + def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: + """Decode latent samples to video frames using the tokenizer decoder. + + Args: + sample: Latent tensor from diffusion model [B, C, T, H, W] + + Returns: + np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] + with values in range [0, 255] + """ + # Decode video + video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] + video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + + return video + + def _run_model( + self, + embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Generate video latents using the diffusion model. + + Args: + embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + torch.Tensor: Generated video latents before tokenizer decoding + + Note: + The model and tokenizer are automatically offloaded after inference + if offloading is enabled in the config. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + sample = generate_world_from_text( + model=self.model, + state_shape=state_shape, + is_negative_prompt=True if negative_prompt_embedding is not None else False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + ) + + return sample + + def _run_model_with_offload( + self, prompt_embedding: torch.Tensor, negative_prompt_embedding: Optional[torch.Tensor] = None + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
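+        When offloading is enabled, the network and tokenizer are loaded on demand: the network is loaded for diffusion sampling and released right afterwards, then the tokenizer decodes the latents and is released in turn.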
+ + Args: + prompt_embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation + """ + if self.offload_network: + self._load_network() + + if self.offload_tokenizer: + self._load_tokenizer() + + sample = self._run_model(prompt_embedding, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + if self.offload_tokenizer: + self._load_tokenizer() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + return sample + + def generate( + self, + prompt: str, + negative_prompt: Optional[str] = None, + word_limit_to_skip_upsampler: Optional[int] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + negative_prompt: Optional text to guide what not to generate + word_limit_to_skip_upsampler: Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if not self.disable_guardrail: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + # Enhance prompt + if self.enable_prompt_upsampler: + word_count = len(prompt.split()) + if word_limit_to_skip_upsampler is None or word_count <= word_limit_to_skip_upsampler: + log.info("Run prompt upsampler on prompt") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(prompt) + if not self.disable_guardrail: + log.info("Run guardrail on upsampled prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt=prompt) + if not is_safe: + log.critical("Upsampled text prompt is not safe") + return None + log.info("Pass guardrail on upsampled prompt") + else: + log.info( + f"Skip prompt upsampler for better robustness because the number of words ({word_count}) in the prompt is greater than {word_limit_to_skip_upsampler}" + ) + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if 
video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + +class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + disable_prompt_encoder: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + num_input_frames: int = 1, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + disable_prompt_encoder: Whether to disable prompt encoder + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + disable_prompt_encoder=disable_prompt_encoder, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + ) + + def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str: + """Enhance the input prompt using visual context from the conditioning image. 
+ + Args: + image_or_video_path: Path to conditioning image or video used for visual context + + Returns: + str: Enhanced prompt incorporating visual details from the image + """ + dialog = prepare_dialog(image_or_video_path) + upsampled_prompt = run_chat_completion_vlm( + self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False + ) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_vlm_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionV2WModel, + ) + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + ) -> torch.Tensor: + """Generate video frames using the diffusion model. + + Args: + embedding: Text embedding tensor from T5 encoder + condition_latent: Latent tensor from conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + Tensor of generated video frames + + Note: + Model and tokenizer are automatically offloaded after inference + if offloading is enabled. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=self.model.state_shape, + is_negative_prompt=True, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video + + def _run_tokenizer_encoding(self, image_or_video_path: str) -> torch.Tensor: + """ + Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent = get_condition_latent( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=self.model.state_shape, + ) + + return condition_latent + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
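+        The tokenizer is loaded first to encode the conditioning image or video into a latent; the diffusion network then samples latents conditioned on that latent and the text embedding, and the tokenizer decodes the result before each component is offloaded again (when offloading is enabled).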
+ + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + condition_latent = self._run_tokenizer_encoding(image_or_video_path) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def generate( + self, + prompt: str, + image_or_video_path: str, + negative_prompt: Optional[str] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_or_video_path: Path to conditioning image or video + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + + log.info(f"Run with image or video path: {image_or_video_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if self.enable_prompt_upsampler: + log.info("Run prompt upsampler on image or video, input prompt is not used") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path) + + log.info(f"Run with prompt: {prompt}") + if not self.disable_guardrail: + log.info(f"Run guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical(f"Input {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt is not safe") + return None + log.info(f"Pass guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + else: + log.info("Not running guardrail") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_or_video_path, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + +class 
DiffusionText2WorldMultiviewGenerationPipeline(DiffusionText2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + n_views: int = 6, + frame_repeat_negative_condition: int = 10, + seed: int = 0, + ): + """Initialize the diffusion multi-view world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + n_views: Number of views + frame_repeat_negative_condition: Number of frames to repeat to be used as negative condition. + seed: Random seed for sampling + """ + assert inference_type in [ + "text2world", + "video2world", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + self.n_views = n_views + self.frame_repeat_negative_condition = frame_repeat_negative_condition + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=False, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionMultiviewT2WModel, + ) + + def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: + """Decode latent samples to video frames using the tokenizer decoder. 
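+        Unlike the base class, this variant returns a two-element list: a 2 x 3 grid video that tiles the six camera views (rows follow the [1, 0, 2, 4, 3, 5] view order used in the body) and the full concatenated per-view video.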
+ + Args: + sample: Latent tensor from diffusion model [B, C, T, H, W] + + Returns: + np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] + with values in range [0, 255] + """ + # Decode video + video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] + video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.n_views) + grid_video = torch.stack( + [video_segments[:, :, i] for i in [1, 0, 2, 4, 3, 5]], + dim=2, + ) + grid_video = einops.rearrange(grid_video, "b c (h w) t h1 w1 -> b c t (h h1) (w w1)", h=2, w=3) + grid_video = (grid_video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + + return [grid_video, video] + + def _run_model( + self, + embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Generate video latents using the diffusion model. + + Args: + embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + torch.Tensor: Generated video latents before tokenizer decoding + + Note: + The model and tokenizer are automatically offloaded after inference + if offloading is enabled in the config. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch_for_multiview_model( + model=self.model, + prompt_embedding=embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames * len(embedding), # number of views + frame_repeat_negative_condition=self.frame_repeat_negative_condition, + ) + + # Generate video frames + sample = generate_world_from_text( + model=self.model, + state_shape=state_shape, + is_negative_prompt=False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + ) + + return sample + + def generate( + self, + prompt: dict, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Convert prompt to embeddings + 2. Generate video frames using diffusion + + Args: + prompt: A dictionary of text description of desired video. 
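+                Expected keys: "prompt", "prompt_left", "prompt_right", "prompt_back", "prompt_back_left", and "prompt_back_right".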
+ Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + + prompts = [ + prompt["prompt"], + prompt["prompt_left"], + prompt["prompt_right"], + prompt["prompt_back"], + prompt["prompt_back_left"], + prompt["prompt_back_right"], + ] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + videos = self._run_model_with_offload( + prompt_embeddings, + ) + log.info("Finish generation") + + return videos, prompt + + +class DiffusionVideo2WorldMultiviewGenerationPipeline(DiffusionText2WorldMultiviewGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + num_input_frames: int = 1, + n_views: int = 6, + frame_repeat_negative_condition: int = 10, + ): + """Initialize diffusion world multi-view generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + n_views=n_views, + frame_repeat_negative_condition=frame_repeat_negative_condition, + ) + + def _load_model(self): + self.model = load_model_by_config( + 
config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionMultiviewV2WModel, + ) + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + data_batch: dict = None, + state_shape: list = None, + ) -> torch.Tensor: + """Generate video frames using the diffusion model. + + Args: + embedding: Text embedding tensor from T5 encoder + condition_latent: Latent tensor from conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + Tensor of generated video frames + + Note: + Model and tokenizer are automatically offloaded after inference + if offloading is enabled. + """ + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=state_shape, + is_negative_prompt=False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video + + def _run_tokenizer_encoding(self, image_or_video_path: str, state_shape: list) -> torch.Tensor: + """ + Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent, condition_frames = get_condition_latent_multiview( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=state_shape, + ) + + return condition_latent, condition_frames + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. + + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + data_batch, state_shape = get_video_batch_for_multiview_model( + model=self.model, + prompt_embedding=prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames * len(prompt_embedding), # number of views + frame_repeat_negative_condition=self.frame_repeat_negative_condition, + ) + + condition_latent, condition_frames = self._run_tokenizer_encoding(image_or_video_path, state_shape) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding, data_batch, state_shape) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def generate( + self, + prompt: dict, + image_or_video_path: str, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Convert prompt to embeddings + 2. Generate video frames using diffusion + + Args: + prompt: A dictionary of text description of desired video. 
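+                Expected keys: "prompt", "prompt_left", "prompt_right", "prompt_back", "prompt_back_left", and "prompt_back_right".
+            image_or_video_path: Path to the conditioning image or video used to build the multiview condition latent.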
+ Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + + prompts = [ + prompt["prompt"], + prompt["prompt_left"], + prompt["prompt_right"], + prompt["prompt_back"], + prompt["prompt_back_left"], + prompt["prompt_back_right"], + ] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embeddings, + image_or_video_path=image_or_video_path, + ) + log.info("Finish generation") + + return video, prompt + + +class DiffusionWorldInterpolatorGenerationPipeline(DiffusionVideo2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = -1.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 11, + num_input_frames: int = 1, + num_frame_pairs: int = 1, + frame_index_start: int = 0, + frame_stride: int = 1, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + self.num_frame_pairs = num_frame_pairs + self.frame_index_start = frame_index_start + self.frame_stride = frame_stride + self.num_steps = num_steps + self.height = height + self.width = width + self.fps = fps + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + 
disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + num_input_frames=num_input_frames, + ) + + def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str: + """Enhance the input prompt using visual context from the conditioning image. + + Args: + image_or_video_path: Path to conditioning image or video used for visual context + + Returns: + str: Enhanced prompt incorporating visual details from the image + """ + dialog = prepare_dialog(image_or_video_path) + upsampled_prompt = run_chat_completion_vlm( + self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False + ) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_vlm_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionWorldInterpolatorWModel, + ) + + @torch.inference_mode() + def _run_model( + self, + condition_latent: torch.Tensor | None = None, + negative_prompt_embedding: torch.Tensor | None = None, + num_of_loops: int = 1, + num_of_latent_overlap_list: list[int] = [1], + augment_sigma_list: list[float] = [0.001], + add_input_frames_guidance: float = 0, + skip_reencode: int = 0, + state_shape: list = None, + raw_data_batch: dict = None, + ) -> np.ndarray: + """Generate video frames using the diffusion model, supporting chunk processing for video extension. + + Args: + condition_latent: Latent tensor from conditioning image or video (optional for video extension). + negative_prompt_embedding: Optional embedding for negative prompt guidance. + num_of_loops: Number of loops for generating video segments. + num_of_latent_overlap_list: List of overlaps for latent conditions in each loop. + augment_sigma_list: List of sigma values for augmentation. + add_input_frames_guidance: Guidance strength for input frames. + skip_reencode: Whether to skip reencoding. + frame_index_start: Starting index for frame pairs. + num_frame_pairs: Number of frame pairs to process. + frame_stride: Stride between frame pairs. + is_interpolator_model: Whether the model is an interpolator. + input_frames: Input video frames for interpolation (optional). + + Returns: + np.ndarray: Generated video frames in shape (T, H, W, C). 
+ """ + video_np_THWC, _, _ = generate_video_from_batch_with_loop( + model=self.model, + data_batch=raw_data_batch, + condition_latent=condition_latent, + num_of_loops=num_of_loops, + num_of_latent_overlap_list=num_of_latent_overlap_list, + guidance=self.guidance, + state_shape=state_shape, + num_steps=self.num_steps, + seed=self.seed, + is_negative_prompt=True if negative_prompt_embedding is not None else False, + visualize=False, + save_fig_path=None, + augment_sigma_list=augment_sigma_list, + add_input_frames_guidance=add_input_frames_guidance, + skip_reencode=skip_reencode, + ) + + return video_np_THWC + + def _run_tokenizer_encoding( + self, image_or_video_path: str, frame_index: int = 0, frame_stride: int = 1 + ) -> torch.Tensor: + """Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + frame_index: Starting frame index for encoding + frame_stride: Stride between frames for encoding + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent = get_condition_latent( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=self.model.state_shape, + frame_index=frame_index, + frame_stride=frame_stride, + ) + + return condition_latent + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + frame_index_start: int = 0, + num_frame_pairs: int = 1, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. + + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + frame_index_start: Starting index for frame pairs + num_frame_pairs: Number of frame pairs to process + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + # Prepare video batch and state shape + raw_data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + H, W = ( + state_shape[-2] * self.model.tokenizer.spatial_compression_factor, + state_shape[-1] * self.model.tokenizer.spatial_compression_factor, + ) + + input_path_format = image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + + num_frames = input_frames.shape[2] + num_frame_pairs = num_frame_pairs or num_frames // self.frame_stride + frame_stride = self.frame_stride + + video_output = [] + for frame_index in range(frame_index_start, num_frame_pairs): + print(f"Processing frame pair {frame_index + 1} / {num_frame_pairs}...") + + condition_latent = self._run_tokenizer_encoding(image_or_video_path, frame_index, frame_stride) + + video_np_THWC = self._run_model( + condition_latent=condition_latent, + negative_prompt_embedding=negative_prompt_embedding, + raw_data_batch=raw_data_batch, + state_shape=state_shape, + ) + + # Convert to tensor, rearrange, and normalize to [0, 1] + video_0_1 = einops.rearrange(torch.from_numpy(video_np_THWC), "t h w c 
-> c t h w") / 255.0 + + # Handle overlap by skipping the first frame of subsequent segments + if len(video_output) == 0: + video_output.append(video_0_1) + else: + video_output.append(video_0_1[:, 1:, :, :]) # Skip first frame to avoid duplication + + # Concatenate all segments + video_tensor = torch.cat(video_output, dim=1) # Shape: (C, total_num_frames, H, W) + + # Convert to NumPy array for guardrail: [T, H, W, C], uint8, [0, 255] + video_np = (video_tensor.permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() # Shape: (T, H, W, C) + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + + return video_np + + def generate( + self, + prompt: str, + image_or_video_path: str, + negative_prompt: Optional[str] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_or_video_path: Path to conditioning image or video + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + log.info(f"Run with image or video path: {image_or_video_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if not self.disable_guardrail and not self.enable_prompt_upsampler: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + else: + log.info("Run prompt upsampler on image or video, input prompt is not used") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path) + + if not self.disable_guardrail: + log.info("Run guardrail on upsampled prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Upsampled text prompt is not safe") + return None + log.info("Pass guardrail on upsampled prompt") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_or_video_path, + frame_index_start=self.frame_index_start, + num_frame_pairs=self.num_frame_pairs, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return 
video, prompt diff --git a/cosmos_predict1/diffusion/inference/world_interpolator.py b/cosmos_predict1/diffusion/inference/world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fd640a52438062c85919083225254fc58a98dc --- /dev/null +++ b/cosmos_predict1/diffusion/inference/world_interpolator.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CUDA_VISIBLE_DEVICES=1 python3 -m cosmos_predict1.diffusion.inference.world_interpolator \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-Predict1-7B-WorldInterpolator \ + --input_image_or_video_path assets/diffusion/interpolation_example.mp4 \ + --num_input_frames 1 \ + --offload_prompt_upsampler \ + --video_save_name diffusion-world-interpolator-7b \ + --num_video_frames 10 \ + --num_frame_pairs 2 +""" + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, check_input_frames, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionWorldInterpolatorGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +# from cosmos_predict1.utils.visualize.video import save_img_or_video +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-WorldInterpolator", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-WorldInterpolator", + "Cosmos-Predict1-7B-WorldInterpolator_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=2, + help="The minimum number of input frames for world_interpolator predictions.", + ) + # parser.add_argument("--num_video_frames", type=int, default=118, help="numer of video frames to sample") + parser.add_argument("--pixel_chunk_duration", type=int, default=121, help="pixel chunk duration") + parser.add_argument( + "--frame_stride", + type=int, + default=1, + help="Specifies the gap between frames used for interpolation. 
A frame_stride of 1 means consecutive frame " + "pairs are treated as inputs (e.g., (x0, x1), (x1, x2)), while a frame_stride of 2 pairs frames with one " + "frame in between (e.g., (x0, x2), (x2, x4) are treated as input at a time). Increasing this value " + "results in interpolation over a larger temporal range. Default is 1.", + ) + parser.add_argument( + "--frame_index_start", + type=int, + default=0, + help="Index of the first frame in the input video from which frame pairs are taken for interpolation. " + "Default is 0.", + ) + parser.add_argument( + "--num_frame_pairs", + type=int, + default=None, + help="Limits the number of unique frame pairs processed for interpolation. By default (None), the interpolator " + "runs on all possible pairs extracted from the input video with the given frame_stride. If set to 1, only the first " + "frame pair is processed (e.g., (x0, x1) for frame_stride=1, (x0, x2) for frame_stride=2). Higher values allow processing more " + "pairs up to the maximum possible with the given frame_stride.", + ) + return parser.parse_args() + + +def demo(args): + """Run world-interpolator generation demo. + + This function handles the main world-interpolator generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + args (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + # import ipdb; ipdb.set_trace() + misc.set_random_seed(args.seed) + inference_type = "world_interpolator" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video_interpolator generation model pipeline + pipeline = DiffusionWorldInterpolatorGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + num_input_frames=args.num_input_frames, + num_frame_pairs=args.num_frame_pairs, + frame_stride=args.frame_stride, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_or_video_path}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_or_video_path = input_dict.get("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_or_video_path=current_image_or_video_path, + negative_prompt=args.negative_prompt, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + # Save video + + video_save_path = os.path.join(args.video_save_folder, args.video_save_name + ".mp4") + prompt_save_path = os.path.join(args.video_save_folder, args.video_save_name + ".txt") + + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + with open(prompt_save_path, "w") as f: + f.write(prompt) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py 
b/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8d6b31e91d8e612ad3561aebd61533a1e80e4c --- /dev/null +++ b/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.modules.res_sampler import Sampler +from cosmos_predict1.utils import log, misc + +IS_PREPROCESSED_KEY = "is_preprocessed" +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.types import DenoisePrediction + + +class DiffusionWorldInterpolatorWModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.input_image_key = getattr(self.config, "input_image_key", None) + self.input_data_key = self.config.input_data_key + self.sampler = Sampler() # Added to resolve the AttributeError + self.scaling = EDMScaling(self.sigma_data) + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: VideoExtendCondition) -> DenoisePrediction: + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + condition_dict = { + k: v.to(self.precision) if isinstance(v, torch.Tensor) else v for k, v in condition.to_dict().items() + } + net_output = self.net( + x=batch_mul(c_in, xt), + timesteps=c_noise, + **condition_dict, + ) + logvar = self.model.logvar(c_noise) if hasattr(self.model, "logvar") else None + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + return DenoisePrediction(x0_pred, eps_pred, logvar) + + def _normalize_video_databatch_inplace(self, data_batch: Dict[str, Tensor]) -> None: + if self.input_data_key in data_batch: + if IS_PREPROCESSED_KEY not in data_batch or not data_batch[IS_PREPROCESSED_KEY]: + assert data_batch[self.input_data_key].dtype == torch.uint8, "Video data must be uint8." 
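+            # Map uint8 pixels from [0, 255] to [-1, 1] via x / 127.5 - 1.0, then mark the batch as
+            # preprocessed so the normalization is not applied twice on subsequent calls.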
+ data_batch[self.input_data_key] = data_batch[self.input_data_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: Dict[str, Tensor]) -> None: + if self.input_image_key in data_batch: + if IS_PREPROCESSED_KEY not in data_batch or not data_batch[IS_PREPROCESSED_KEY]: + data_batch[self.input_image_key] = rearrange( + data_batch[self.input_image_key], "b c h w -> b c 1 h w" + ).contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def is_image_batch(self, data_batch: Dict[str, Tensor]) -> bool: + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert is_image != is_video, "Batch must contain either image or video data, not both or neither." + return is_image + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Optional[int] = None + ) -> VideoExtendCondition: + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.debug( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + B, C, T, H, W = latent_state.shape + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + if condition.video_cond_bool: + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: + condition.condition_video_input_mask = zeros_padding + return condition + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + assert condition.gt_latent.allclose(uncondition.gt_latent) + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + return condition, uncondition + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + augment_sigma = 
condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + noise = misc.arch_invariant_rand(latent.shape, torch.float32, self.tensor_kwargs["device"], seed) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = split_inputs_cp(condition.gt_latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator + + def _reverse_precondition_input(self, xt: Tensor, sigma: Tensor) -> Tensor: + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + xt_unscaled = xt / c_in + return xt_unscaled + + def _reverse_precondition_output(self, latent: Tensor, xt: Tensor, sigma: Tensor) -> Tensor: + sigma_data = self.scheduler.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + latent_unscaled = latent / c_out - c_skip * xt + return latent_unscaled + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + assert condition_latent is not None, "condition_latent must be provided for video generation." 
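+        # Build condition/uncondition once, then return a closure x0_fn(noise_x, sigma) that:
+        #   (1) replaces the condition region of the noisy latent with a lightly re-noised copy of the
+        #       ground-truth latent (controlled by condition_video_augment_sigma_in_inference),
+        #   (2) denoises under both the conditioned and unconditioned branches, and
+        #   (3) applies classifier-free guidance as cond_x0 + guidance * (cond_x0 - uncond_x0).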
+ condition, uncondition = self._get_conditions( + data_batch, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + add_input_frames_guidance=add_input_frames_guidance, + ) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_xt, cond_latent, cond_indicator = self._augment_noise_with_latent( + noise_x, + sigma, + condition, + condition_augment_sigma=condition_video_augment_sigma_in_inference or 0.001, + seed=seed_inference, + ) + cond_pred = self.denoise(cond_xt, sigma, condition) + cond_x0 = cond_pred.x0_pred_replaced if hasattr(cond_pred, "x0_pred_replaced") else cond_pred.x0 + uncond_xt, _, _ = self._augment_noise_with_latent( + noise_x, + sigma, + uncondition, + condition_augment_sigma=condition_video_augment_sigma_in_inference or 0.001, + seed=seed_inference, + ) + uncond_pred = self.denoise(uncond_xt, sigma, uncondition) + uncond_x0 = uncond_pred.x0_pred_replaced if hasattr(uncond_pred, "x0_pred_replaced") else uncond_pred.x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + self._normalize_video_databatch_inplace(data_batch) + # self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. 
{self.state_shape}") + state_shape = self.state_shape + assert condition_latent is not None, "condition_latent should be provided" + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, + ) + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * 80 + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=80) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / 80 + return samples diff --git a/cosmos_predict1/diffusion/model/model_gen3c.py b/cosmos_predict1/diffusion/model/model_gen3c.py new file mode 100644 index 0000000000000000000000000000000000000000..2c77228d50013b89559f1f7bd3a831101861b5fc --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_gen3c.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition + + +class DiffusionGen3CModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.frame_buffer_max = config.frame_buffer_max + self.chunk_size = 121 + + def encode_warped_frames( + self, + condition_state: torch.Tensor, + condition_state_mask: torch.Tensor, + dtype: torch.dtype, + ): + + assert condition_state.dim() == 6 + condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1) + latent_condition = [] + for i in range(condition_state.shape[2]): + current_video_latent = self.encode( + condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype) + ).contiguous() # 1, 16, 8, 88, 160 + + current_mask_latent = self.encode( + condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype) + ).contiguous() + latent_condition.append(current_video_latent) + latent_condition.append(current_mask_latent) + for _ in range(self.frame_buffer_max - condition_state.shape[2]): + latent_condition.append(torch.zeros_like(current_video_latent)) + latent_condition.append(torch.zeros_like(current_mask_latent)) + + latent_condition = torch.cat(latent_condition, dim=1) + return latent_condition + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + """Get the conditions for the model. + + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + # encode warped frames + condition_state, condition_state_mask = ( + data_batch["condition_state"], + data_batch["condition_state_mask"], + ) + latent_condition = self.encode_warped_frames( + condition_state, condition_state_mask, self.tensor_kwargs["dtype"] + ) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + condition = self.add_condition_pose(latent_condition, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + uncondition = self.add_condition_pose(latent_condition, uncondition, drop_out_latent = True) + assert condition.gt_latent.allclose(uncondition.gt_latent) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + def add_condition_pose(self, latent_condition: torch.Tensor, condition: 
VideoExtendCondition, + drop_out_latent: bool = False) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + if drop_out_latent: + condition.condition_video_pose = torch.zeros_like(latent_condition.contiguous()) + else: + condition.condition_video_pose = latent_condition.contiguous() + + to_cp = self.net.is_context_parallel_enabled + + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition diff --git a/cosmos_predict1/diffusion/model/model_t2w.py b/cosmos_predict1/diffusion/model/model_t2w.py new file mode 100644 index 0000000000000000000000000000000000000000..b2910e09ec3e9779f6b962251a3bba8c70dddbb8 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_t2w.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from diffusers import EDMEulerScheduler +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition +from cosmos_predict1.diffusion.module import parallel +from cosmos_predict1.diffusion.module.blocks import FourierFeatures +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.distributed import get_rank +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +class DiffusionT2WModel(torch.nn.Module): + """Text-to-world diffusion model that generates video frames from text descriptions. + + This model implements a diffusion-based approach for generating videos conditioned on text input. + It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling, + and classifier-free guidance. + """ + + def __init__(self, config): + """Initialize the diffusion model. 
+ + Args: + config: Configuration object containing model parameters and architecture settings + """ + super().__init__() + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.config = config + + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.debug(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.0002, sigma_data=self.sigma_data) + self.tokenizer = None + self.model = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + @property + def logvar(self): + return self.model.logvar + + def set_up_tokenizer(self, tokenizer_dir: str): + self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer) + self.tokenizer.load_weights(tokenizer_dir) + if hasattr(self.tokenizer, "reset_dtype"): + self.tokenizer.reset_dtype() + + @misc.timer("DiffusionModel: set_up_model") + def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format): + """Initialize the core model components including network, conditioner and logvar.""" + self.model = self.build_model() + if self.config.peft_control and self.config.peft_control.enabled: + log.info("Setting up LoRA layers") + peft_control_config_parser = LayerControlConfigParser(config=self.config.peft_control) + peft_control_config = peft_control_config_parser.parse() + add_lora_layers(self.model, peft_control_config) + num_lora_params = setup_lora_requires_grad(self.model) + self.model.requires_grad_(False) + if num_lora_params == 0: + raise ValueError("No LoRA parameters found. Please check the model configuration.") + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + def build_model(self) -> torch.nn.ModuleDict: + """Construct the model's neural network components. + + Returns: + ModuleDict containing the network, conditioner and logvar components + """ + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """Encode input state into latent representation using VAE. + + Args: + state: Input tensor to encode + + Returns: + Encoded latent representation scaled by sigma_data + """ + return self.tokenizer.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """Decode latent representation back to pixel space using VAE. 
+ + Args: + latent: Latent tensor to decode + + Returns: + Decoded tensor in pixel space + """ + return self.tokenizer.decode(latent / self.sigma_data) + + def setup_data_key(self) -> None: + """Configure input data keys for video and image data.""" + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + ) -> Tensor: + """Generate samples from a data batch using diffusion sampling. + + This function generates samples from either image or video data batches using diffusion sampling. + It handles both conditional and unconditional generation with classifier-free guidance. + + Args: + data_batch (dict): Raw data batch from the training data loader + guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5. + seed (int, optional): Random seed for reproducibility. Defaults to 1. + state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None. + n_sample (int | None, optional): Number of samples to generate. Defaults to 1. + is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False. + num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35. + + Returns: + Tensor: Generated samples after diffusion sampling + """ + condition, uncondition = self._get_conditions(data_batch, is_negative_prompt) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + ): + """Get the conditions for the model. 
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + +def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" + condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition diff --git a/cosmos_predict1/diffusion/model/model_t2w_multiview.py b/cosmos_predict1/diffusion/model/model_t2w_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..8008cac54bd74d4ef835fca63218dffd5fc3fe31 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_t2w_multiview.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
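+
+# Multiview layout note: latents carry the views along the temporal axis as "B C (V T) H W". Around
+# tokenizer encode/decode and context-parallel split/concat, the view axis is folded into the batch and
+# back again, e.g. (illustrative only):
+#     per_view = rearrange(x, "B C (V T) H W -> (B V) C T H W", V=n_views)
+#     ...  # per-view tokenizer call or split_inputs_cp along the T dimension
+#     x = rearrange(per_view, "(B V) C T H W -> B C (V T) H W", V=n_views)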
+ +from typing import Optional, Union + +import torch +from einops import rearrange +from torch import Tensor + +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionMultiviewT2WModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.net.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.tokenizer.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.tokenizer.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + ) -> Tensor: + """Generate samples from a data batch using diffusion sampling. + + This function generates samples from either image or video data batches using diffusion sampling. + It handles both conditional and unconditional generation with classifier-free guidance. + + Args: + data_batch (dict): Raw data batch from the training data loader + guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5. + seed (int, optional): Random seed for reproducibility. Defaults to 1. + state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None. + n_sample (int | None, optional): Number of samples to generate. Defaults to 1. + is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False. + num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35. 
+ + Returns: + Tensor: Generated samples after diffusion sampling + """ + condition, uncondition = self._get_conditions(data_batch, is_negative_prompt) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = rearrange(xt, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + xt = rearrange(xt, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + return samples diff --git a/cosmos_predict1/diffusion/model/model_v2w.py b/cosmos_predict1/diffusion/model/model_v2w.py new file mode 100644 index 0000000000000000000000000000000000000000..246485e94c8f4400149469fbce8b0d2965de1271 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_v2w.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionV2WModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Optional[int] = None + ) -> VideoExtendCondition: + """Adds conditioning masks to VideoExtendCondition object. + + Creates binary indicators and input masks for conditional video generation. 
+ + Args: + latent_state: Input latent tensor (B,C,T,H,W) + condition: VideoExtendCondition object to update + num_condition_t: Number of frames to condition on + + Returns: + Updated VideoExtendCondition with added masks: + - condition_video_indicator: Binary tensor marking condition regions + - condition_video_input_mask: Input mask for network + - gt_latent: Ground truth latent tensor + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.debug( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + return condition + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + condition_augment_sigma: float = None, + add_input_frames_guidance: bool = False, + ) -> Tensor: + """Generates video samples conditioned on input frames. 
+ + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + seed: Random seed for reproducibility + state_shape: Shape of output tensor (defaults to model's state shape) + n_sample: Number of samples to generate (defaults to batch size) + is_negative_prompt: Whether to use negative prompting + num_steps: Number of denoising steps + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_augment_sigma: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + Generated video samples tensor + """ + assert condition_latent is not None, "condition_latent should be provided" + condition, uncondition = self._get_conditions( + data_batch, is_negative_prompt, condition_latent, num_condition_t, add_input_frames_guidance + ) + + self.scheduler.set_timesteps(num_steps) + if n_sample is None: + n_sample = condition_latent.shape[0] + xt = torch.randn(size=(n_sample,) + tuple(state_shape), **self.tensor_kwargs) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + self.scheduler._init_step_index(t) + sigma = self.scheduler.sigmas[self.scheduler.step_index].to(**self.tensor_kwargs) + # Form new noise from latent + xt = xt.to(**self.tensor_kwargs) + new_xt, latent, indicator = self._augment_noise_with_latent( + xt, sigma, condition, condition_augment_sigma=condition_augment_sigma, seed=seed + ) + new_xt = new_xt.to(**self.tensor_kwargs) + new_xt_scaled = self.scheduler.scale_model_input(new_xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=new_xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=new_xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Replace indicated output with latent + latent_unscaled = self._reverse_precondition_output(latent, xt=new_xt, sigma=sigma) + new_output = indicator * latent_unscaled + (1 - indicator) * net_output + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(new_output, t, new_xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + """Get the conditions for the model. 
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + assert condition.gt_latent.allclose(uncondition.gt_latent) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + """Augments the conditional frames with noise during inference. + + Args: + xt (Tensor): noise + sigma (Tensor): noise level for the generation region + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. 
+ condition_augment_sigma (float): sigma for condition video augmentation in inference + seed (int): random seed for reproducibility + Returns: + new_xt (Tensor): new latent-augmented noise tensor in shape B,C,T,H,W + latent (Tensor): ground-truth latent tensor in shape B,C,T,H,W + indicator (Tensor): ground-truth latent binary indicator tensor in shape B,C,T,H,W + + """ + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + augment_sigma = condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + # Now apply the augment_sigma to the gt_latent + noise = misc.arch_invariant_rand( + latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = split_inputs_cp(condition.gt_latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator + + def _reverse_precondition_input(self, xt: Tensor, sigma: Tensor) -> Tensor: + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + xt_unscaled = xt / c_in + return xt_unscaled + + def _reverse_precondition_output(self, latent: Tensor, xt: Tensor, sigma: Tensor) -> Tensor: + sigma_data = self.scheduler.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + latent_unscaled = (latent - c_skip * xt) / c_out + return latent_unscaled diff --git a/cosmos_predict1/diffusion/model/model_v2w_multiview.py b/cosmos_predict1/diffusion/model/model_v2w_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..ece347fd51e50163b9d0d8c8eb7cbf3a57cce58f --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_v2w_multiview.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
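+
+# Multiview video-to-world variant: the sampling loop mirrors DiffusionV2WModel, but the condition
+# indicator, ground-truth latents and noise are folded from "B C (V T) H W" to "(B V) C T H W" before any
+# context-parallel split and restored afterwards, keeping the conditioning masks consistent per view.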
+ +from typing import Optional, Union + +import torch +from einops import rearrange +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionMultiviewV2WModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.net.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.tokenizer.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.tokenizer.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional 
region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + return condition + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + condition_augment_sigma: float = None, + add_input_frames_guidance: bool = False, + ) -> Tensor: + """Generates video samples conditioned on input frames. + + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + seed: Random seed for reproducibility + state_shape: Shape of output tensor (defaults to model's state shape) + n_sample: Number of samples to generate (defaults to batch size) + is_negative_prompt: Whether to use negative prompting + num_steps: Number of denoising steps + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_augment_sigma: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + Generated video samples tensor + """ + assert condition_latent is not None, "condition_latent should be provided" + condition, uncondition = self._get_conditions( + data_batch, is_negative_prompt, condition_latent, num_condition_t, add_input_frames_guidance + ) + + self.scheduler.set_timesteps(num_steps) + if n_sample is None: + n_sample = condition_latent.shape[0] + xt = torch.randn(size=(n_sample,) + tuple(state_shape), **self.tensor_kwargs) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = rearrange(xt, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + xt = rearrange(xt, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + for t in self.scheduler.timesteps: + self.scheduler._init_step_index(t) + sigma = self.scheduler.sigmas[self.scheduler.step_index].to(**self.tensor_kwargs) + # Form new noise from latent + new_xt, latent, indicator = self._augment_noise_with_latent( + xt, sigma, condition, condition_augment_sigma=condition_augment_sigma, seed=seed + ) + new_xt = new_xt.to(**self.tensor_kwargs) + new_xt_scaled = self.scheduler.scale_model_input(new_xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=new_xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=new_xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Replace indicated output with latent + latent_unscaled = self._reverse_precondition_output(latent, xt=new_xt, sigma=sigma) + new_output = indicator * latent_unscaled + (1 - indicator) * net_output + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(new_output, t, new_xt).prev_sample + samples = xt + + if to_cp: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + 
return samples + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + """Augments the conditional frames with noise during inference. + + Args: + xt (Tensor): noise + sigma (Tensor): noise level for the generation region + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + condition_augment_sigma (float): sigma for condition video augmentation in inference + seed (int): random seed for reproducibility + Returns: + new_xt (Tensor): new latent-augmented noise tensor in shape B,C,T,H,W + latent (Tensor): ground-truth latent tensor in shape B,C,T,H,W + indicator (Tensor): ground-truth latent binary indicator tensor in shape B,C,T,H,W + + """ + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + augment_sigma = condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + # Now apply the augment_sigma to the gt_latent + noise = misc.arch_invariant_rand( + latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + indicator = rearrange(indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + augment_latent_unscaled = rearrange( + augment_latent_unscaled, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + + latent = split_inputs_cp(latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + + latent = rearrange(latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + indicator = rearrange(indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + augment_latent_unscaled = rearrange( + augment_latent_unscaled, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator diff --git a/cosmos_predict1/diffusion/model/model_world_interpolator.py b/cosmos_predict1/diffusion/model/model_world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..c82c7ab8c5e975e644fab120f749d30de93ff370 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_world_interpolator.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from statistics import NormalDist +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.config.base.conditioner import VideoCondBoolConfig +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.modules.res_sampler import Sampler +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.models.model import _broadcast +from cosmos_predict1.utils import log, misc + +IS_PREPROCESSED_KEY = "is_preprocessed" +from dataclasses import dataclass, fields + +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.diffusion.types import DenoisePrediction + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction + net_in: Optional[torch.Tensor] = None # input to the network + net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network + xt: Optional[torch.Tensor] = None # input to the network, before multiply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +class DiffusionWorldInterpolatorWModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.setup_data_key() # Initialize input_data_key and input_image_key + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + self.sde = EDMSDE( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ) + + def setup_data_key(self) -> None: + """Initialize data keys for image and video inputs.""" + self.input_data_key = self.config.input_data_key + self.input_image_key = self.config.input_image_key + + def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: + """Determine if the data batch is an image batch or a video batch. + + Args: + data_batch (dict[str, Tensor]): Input data batch. + + Returns: + bool: True if the batch is an image batch, False if it is a video batch. 
+ + Raises: + AssertionError: If both or neither of input_image_key and input_data_key are present. + """ + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert ( + is_image != is_video + ), "Only one of the input_image_key or input_data_key should be present in the data_batch." + return is_image + + def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """Normalizes video data in-place on a CUDA device to reduce data loading overhead. + + Args: + data_batch (dict[str, Tensor]): Dictionary containing the video data. + input_key (str, optional): Key for the video data in the batch. Defaults to self.input_data_key. + + Side Effects: + Modifies the video data tensor in-place to scale from [0, 255] to [-1, 1]. + """ + input_key = self.input_data_key if input_key is None else input_key + if input_key in data_batch: + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." + assert torch.all( + (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) + ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" + else: + assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." + data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """Augments image data in-place by adding a temporal dimension. + + Args: + data_batch (dict[str, Tensor]): Dictionary containing the image data. + input_key (str, optional): Key for the image data in the batch. Defaults to self.input_image_key. + + Side Effects: + Modifies the image data tensor in-place to add a temporal dimension (B,C,H,W -> B,C,1,H,W). 
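+
+        Example (illustrative sketch; the key name "images" is an assumed placeholder,
+        the real key comes from the config as self.input_image_key):
+            >>> data_batch = {"images": torch.randn(2, 3, 256, 256)}  # B, C, H, W
+            >>> model._augment_image_dim_inplace(data_batch, input_key="images")
+            >>> print(data_batch["images"].shape)  # Expected shape: (2, 3, 1, 256, 256)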
+ """ + input_key = self.input_image_key if input_key is None else input_key + if input_key in data_batch: + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert ( + data_batch[input_key].shape[2] == 1 + ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" + return + else: + data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def normalize_condition_latent(self, condition_latent: torch.Tensor) -> torch.Tensor: + """Normalize the condition latent tensor to have zero mean and unit variance.""" + condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") + mean = condition_latent_2D.mean(dim=-1) + std = condition_latent_2D.std(dim=-1) + mean = mean.unsqueeze(-1).unsqueeze(-1) + std = std.unsqueeze(-1).unsqueeze(-1) + condition_latent = (condition_latent - mean) / std + return condition_latent + + def draw_augment_sigma_and_epsilon( + self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float + ) -> Tuple[Tensor, Tensor]: + """Draw sigma and epsilon for augmenting conditional latent frames.""" + is_video_batch = condition.data_type == DataType.VIDEO + del condition + batch_size = size[0] + epsilon = torch.randn(size, **self.tensor_kwargs) + + gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) + + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed_inference: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """Augment the condition input with noise.""" + if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": + augment_sigma, _ = self.draw_augment_sigma_and_epsilon( + gt_latent.shape, + condition, + cfg_video_cond_bool.augment_sigma_sample_p_mean, + cfg_video_cond_bool.augment_sigma_sample_p_std, + cfg_video_cond_bool.augment_sigma_sample_multiplier, + ) + noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) + elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": + log.debug( + f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" + ) + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed_inference, + ) + else: + raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") + + 
augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) + _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) + + if cfg_video_cond_bool.condition_on_augment_sigma: + if condition.condition_video_indicator.sum() > 0: + condition.condition_video_augment_sigma = c_noise_augment + else: + condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) + + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def super_denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + if getattr(self.config, "use_dummy_temporal_dim", False): + # When using video DiT model for image, we need to use a dummy temporal dimension. + xt = xt.unsqueeze(2) + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + if getattr(self.config, "use_dummy_temporal_dim", False): + x0_pred = x0_pred.squeeze(2) + eps_pred = eps_pred.squeeze(2) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + def drop_out_condition_region( + self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig + ) -> Tensor: + """Drop out the conditional region for CFG on input frames.""" + if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": + augment_latent_drop = torch.zeros_like(augment_latent) + elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": + augment_latent_drop = noise_x + else: + raise NotImplementedError( + f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" + ) + return augment_latent_drop + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed_inference: int = 1, + ) -> VideoDenoisePrediction: + """Denoise the noisy input tensor for video data.""" + assert ( + condition.gt_latent is not None + ), "find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = self.normalize_condition_latent(condition_latent) + + condition, augment_latent = self.augment_conditional_latent_frames( + 
condition, + cfg_video_cond_bool, + condition_latent, + condition_video_augment_sigma_in_inference, + sigma, + seed_inference=seed_inference, + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + if not condition.video_cond_bool: + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + denoise_pred = self.super_denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Supports condition latent for video generation. + + Args: + data_batch (Dict): Input data batch. + guidance (float): Guidance scale for classifier-free guidance. + seed (int): Random seed for reproducibility. + state_shape (Tuple | None): Shape of the latent state, defaults to self.state_shape if None. + n_sample (int | None): Number of samples to generate, inferred from batch if None. + is_negative_prompt (bool): Use negative prompt for unconditioned generation. + num_steps (int): Number of sampling steps. + condition_latent (torch.Tensor | None): Latent tensor (B,C,T,H,W) as condition for video generation. + num_condition_t (int | None): Number of condition frames in T dimension. + condition_video_augment_sigma_in_inference (float): Sigma for augmenting condition video in inference. + add_input_frames_guidance (bool): Apply guidance to input frames for CFG. + return_noise (bool): Return initial noise along with samples. + + Returns: + Tensor | Tuple[Tensor, Tensor]: Generated samples, or (samples, noise) if return_noise is True. 
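+
+        Example (minimal sketch; argument values are illustrative and `encode`/`decode` are
+        assumed to be the tokenizer helpers inherited from the base model, as in the
+        multiview variant above):
+            >>> latent = model.encode(input_video)  # B, C, T, H, W latent
+            >>> samples = model.generate_samples_from_batch(
+            ...     data_batch,
+            ...     guidance=7.0,
+            ...     state_shape=tuple(latent.shape[1:]),
+            ...     condition_latent=latent,
+            ...     num_condition_t=1,
+            ... )
+            >>> video = model.decode(samples)  # back to pixel space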
+ """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, + ) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` for denoising based on the data batch and condition latent. + + Args: + data_batch (Dict): Input data batch. + guidance (float): Guidance scale. + is_negative_prompt (bool): Use negative prompt for unconditioned generation. + condition_latent (torch.Tensor): Latent tensor (B,C,T,H,W) as condition. + num_condition_t (int | None): Number of condition frames. + condition_video_augment_sigma_in_inference (float): Sigma for condition augmentation. + add_input_frames_guidance (bool): Apply guidance to input frames. + seed_inference (int): Seed for inference noise. + + Returns: + Callable: Function `x0_fn(noise_x, sigma)` returning denoised prediction. 
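+
+        Example (illustrative sketch of how the returned callable is used):
+            >>> x0_fn = model.get_x0_fn_from_batch_with_condition_latent(
+            ...     data_batch, guidance=7.0, condition_latent=latent, num_condition_t=1
+            ... )
+            >>> x0_pred = x0_fn(noise_x, sigma)  # CFG-combined x0, condition region replaced by gt latent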
+ """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + # Should be used for both training and inference. The first and last frame will be condition frames. 
+ assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator[:, :, -num_condition_t:] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """ + Adds pose condition to the condition object for camera control. + + Args: + data_batch (Dict): Data batch with 'plucker_embeddings' or 'plucker_embeddings_downsample'. + condition (VideoExtendCondition): Condition object to update. + + Returns: + VideoExtendCondition: Updated condition object. + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. 
only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition diff --git a/cosmos_predict1/diffusion/module/__init__.py b/cosmos_predict1/diffusion/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/module/attention.py b/cosmos_predict1/diffusion/module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0a4b02dea9adbb9a320ed2631014d4102ed288 --- /dev/null +++ b/cosmos_predict1/diffusion/module/attention.py @@ -0,0 +1,313 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import numpy as np +import torch +import transformer_engine as te +from einops import rearrange +from torch import nn +from torch.utils.checkpoint import checkpoint +from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb + +# ---------------------- Feed Forward Network ----------------------- + + +class FeedForward(nn.Module): + """ + Transformer FFN with optional gating + + Parameters: + d_model (int): Dimensionality of input features. + d_ff (int): Dimensionality of the hidden layer. + dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. + activation (callable, optional): The activation function applied after the first linear layer. + Defaults to nn.ReLU(). + is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. + Defaults to False. + bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. 
+ + Example: + >>> ff = FeedForward(d_model=512, d_ff=2048) + >>> x = torch.randn(64, 10, 512) # Example input tensor + >>> output = ff(x) + >>> print(output.shape) # Expected shape: (64, 10, 512) + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + activation=nn.ReLU(), + is_gated: bool = False, + bias: bool = False, + ) -> None: + super().__init__() + + self.layer1 = nn.Linear(d_model, d_ff, bias=bias) + self.layer2 = nn.Linear(d_ff, d_model, bias=bias) + + self.dropout = nn.Dropout(dropout) + self.activation = activation + self.is_gated = is_gated + if is_gated: + self.linear_gate = nn.Linear(d_model, d_ff, bias=False) + + def forward(self, x: torch.Tensor): + g = self.activation(self.layer1(x)) + if self.is_gated: + x = g * self.linear_gate(x) + else: + x = g + assert self.dropout.p == 0.0, "we skip dropout" + return self.layer2(x) + + +class GPT2FeedForward(FeedForward): + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False): + super().__init__( + d_model=d_model, + d_ff=d_ff, + dropout=dropout, + activation=nn.GELU(), + is_gated=False, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + assert self.dropout.p == 0.0, "we skip dropout" + + x = self.layer1(x) + + def activation_layer2_forward(x): + x = self.activation(x) + x = self.layer2(x) + return x + + x = checkpoint(activation_layer2_forward, x, use_reentrant=False) + return x + + +# ---------------------- Normalization Layer ----------------------- + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. + + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. + """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def get_normalization(name: str, channels: int): + if name == "I": + return nn.Identity() + elif name == "R": + return te.pytorch.RMSNorm(channels, eps=1e-6) + else: + raise ValueError(f"Normalization {name} not found") + + +class BaseAttentionOp(nn.Module): + def __init__(self): + super().__init__() + + +class Attention(nn.Module): + """ + Generalized attention impl. + + Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. + If `context_dim` is None, self-attention is assumed. + + Parameters: + query_dim (int): Dimension of each query vector. + context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. + heads (int, optional): Number of attention heads. Defaults to 8. + dim_head (int, optional): Dimension of each head. Defaults to 64. + dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. + attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. + qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. + out_bias (bool, optional): If True, adds a learnable bias to the output projection. 
Defaults to False. + qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. + Defaults to "SSI". + qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections. + Defaults to 'per_head'. Only support 'per_head'. + + Examples: + >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) + >>> query = torch.randn(10, 128) # Batch size of 10 + >>> context = torch.randn(10, 256) # Batch size of 10 + >>> output = attn(query, context) # Perform the attention operation + + Note: + https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/cosmos_predict1/attention.py#L223 + """ + + def __init__( + self, + query_dim: int, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_op: Optional[BaseAttentionOp] = None, + qkv_bias: bool = False, + out_bias: bool = False, + qkv_norm: str = "SSI", + qkv_norm_mode: str = "per_head", + backend: str = "transformer_engine", + qkv_format: str = "bshd", + ) -> None: + super().__init__() + + self.is_selfattn = context_dim is None # self attention + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + self.qkv_norm_mode = qkv_norm_mode + self.qkv_format = qkv_format + + if self.qkv_norm_mode == "per_head": + norm_dim = dim_head + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + self.backend = backend + self.tp_size = 1 # TP is not included in this Attention implementation. + + self.to_q = nn.Sequential( + nn.Linear(query_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[0], norm_dim), + ) + self.to_k = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[1], norm_dim), + ) + self.to_v = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[2], norm_dim), + ) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim, bias=out_bias), + nn.Dropout(dropout), + ) + + if attn_op: # use what is given + self.attn_op = attn_op + elif self.backend == "transformer_engine": + self.attn_op: BaseAttentionOp = DotProductAttention( + self.heads, + self.dim_head, + num_gqa_groups=self.heads, + attention_dropout=0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + tp_size=self.tp_size, + tp_group=None, + sequence_parallel=False, + ) + elif self.backend == "torch": + self.attn_op = torch.nn.functional.scaled_dot_product_attention + else: + raise ValueError(f"Backend {backend} not found") + self.query_dim = query_dim + self.context_dim = context_dim + self.inner_dim = inner_dim + + def cal_qkv( + self, x, context=None, mask=None, rope_emb=None, **kwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + + """ + self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. + Before 07/24/2024, these modules normalize across all heads. + After 07/24/2024, to support tensor parallelism and follow the common practice in the community, + we support to normalize per head. + To keep the checkpoint copatibility with the previous code, + we keep the nn.Sequential but call the projection and the normalization layers separately. + We use a flag `self.qkv_norm_mode` to control the normalization behavior. + The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head. 
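+
+        Concretely, with `qkv_norm_mode == "per_head"` the projections are first reshaped to
+        (..., heads, dim_head) and the normalization selected by `qkv_norm` (RMSNorm or identity)
+        is applied over the trailing `dim_head` channels of each head only.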
+ """ + if self.qkv_norm_mode == "per_head": + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + q, k, v = map( + lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head), + (q, k, v), + ) + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) + return q, k, v + + def cal_attn(self, q, k, v, mask=None): + if self.backend == "transformer_engine": + seq_dim = self.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] + return self.to_out(out) + elif self.backend == "torch": + q = rearrange(q, "s b h d -> b h s d") + k = rearrange(k, "s b h d -> b h s d") + v = rearrange(v, "s b h d -> b h s d") + out = self.attn_op(q, k, v) # [B, Mq, H, V] + return self.to_out(rearrange(out, " b h s d -> s b (h d)")) + else: + raise ValueError(f"Backend {self.backend} not found") + + def forward( + self, + x, + context=None, + mask=None, + rope_emb=None, + **kwargs, + ): + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) + return self.cal_attn(q, k, v, mask) diff --git a/cosmos_predict1/diffusion/module/blocks.py b/cosmos_predict1/diffusion/module/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..46b55225b078ec04983af2b3ab61712c6de9258f --- /dev/null +++ b/cosmos_predict1/diffusion/module/blocks.py @@ -0,0 +1,558 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import nn + +from cosmos_predict1.diffusion.module.attention import Attention, GPT2FeedForward +from cosmos_predict1.utils import log + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class Timesteps(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +class FourierFeatures(nn.Module): + """ + Implements a layer that generates Fourier features from input tensors, based on randomly sampled + frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems. + + [B] -> [B, D] + + Parameters: + num_channels (int): The number of Fourier features to generate. + bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1. + normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize + the variance of the features. Defaults to False. + + Example: + >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True) + >>> x = torch.randn(10, 256) # Example input tensor + >>> output = layer(x) + >>> print(output.shape) # Expected shape: (10, 256) + """ + + def __init__(self, num_channels, bandwidth=1, normalize=False): + super().__init__() + self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True) + self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True) + self.gain = np.sqrt(2) if normalize else 1 + + def forward(self, x, gain: float = 1.0): + """ + Apply the Fourier feature transformation to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1. + + Returns: + torch.Tensor: The transformed tensor, with Fourier features applied. 
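+
+        Example (shape sketch; the input is 1-D, one scalar per batch element, since the
+        transform is an outer product with the sampled frequencies):
+            >>> layer = FourierFeatures(num_channels=256)
+            >>> t = torch.rand(16)  # e.g. per-sample noise levels, shape [B]
+            >>> feats = layer(t)
+            >>> print(feats.shape)  # Expected shape: (16, 256)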
+ """ + in_dtype = x.dtype + x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32)) + x = x.cos().mul(self.gain * gain).to(in_dtype) + return x + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. 
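+
+    It applies an adaLN-modulated LayerNorm (no learned affine) followed by a bias-free linear
+    projection from `hidden_size` to `spatial_patch_size**2 * temporal_patch_size * out_channels`
+    per token, mapping transformer features back to patch-level output channels.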
+ """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module processes video features while maintaining their spatio-temporal structure. It can perform + self-attention within the video features or cross-attention with external context features. + + Parameters: + x_dim (int): Dimension of input feature vectors + context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention + num_heads (int): Number of attention heads + bias (bool): Whether to include bias in attention projections. Default: False + qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head" + x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD" + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + + Input shape: + - x: (T, H, W, B, D) video features + - context (optional): (M, B, D) context features for cross-attention + where: + T: temporal dimension + H: height + W: width + B: batch size + D: feature dimension + M: context sequence length + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + n_views: int = 1, + ) -> None: + super().__init__() + self.x_format = x_format + self.n_views = n_views + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_norm_mode=qkv_norm_mode, + qkv_format="sbhd", + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. 
+ context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), + where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) + context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) + else: + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) + return x_T_H_W_B_D + + +def adaln_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + A building block for the DiT (Diffusion Transformer) architecture that supports different types of + attention and MLP operations with adaptive layer normalization. + + Parameters: + block_type (str): Type of block - one of: + - "cross_attn"/"ca": Cross-attention + - "full_attn"/"fa": Full self-attention + - "mlp"/"ff": MLP/feedforward block + x_dim (int): Dimension of input features + context_dim (Optional[int]): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + bias (bool): Whether to use bias in layers. Default: False + mlp_dropout (float): Dropout rate for MLP. Default: 0.0 + qkv_norm_mode (str): QKV normalization mode. Default: "per_head" + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. 
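+
+    Example (minimal sketch; dimensions are arbitrary):
+        >>> ca_block = DITBuildingBlock("ca", x_dim=1024, context_dim=1024, num_heads=16)
+        >>> fa_block = DITBuildingBlock("fa", x_dim=1024, context_dim=None, num_heads=16)
+        >>> ff_block = DITBuildingBlock("mlp", x_dim=1024, context_dim=None, num_heads=16, mlp_ratio=4.0)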
+ """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + bias: bool = False, + mlp_dropout: float = 0.0, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + qkv_norm_mode=qkv_norm_mode, + x_format=self.x_format, + n_views=n_views, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn( + x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format + ) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. 
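+
+        Regardless of block type, the update is residual and gated:
+        x = x + gate * block(adaln_norm_state(self.norm_state, x, scale, shift)),
+        where (shift, scale, gate) come from the adaLN modulation of `emb_B_D`
+        (plus the optional AdaLN-LoRA term `adaln_lora_B_3D`).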
+ """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer. + Each block in the sequence is specified by a block configuration string. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention blocks + num_heads (int): Number of attention heads + block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention, + full-attention, then MLP) + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + n_views (int): Extra parameter used in multi-view diffusion model. Default: 1 + + The block_config string uses "-" to separate block types: + - "ca"/"cross_attn": Cross-attention block + - "fa"/"full_attn": Full self-attention block + - "mlp"/"ff": MLP/feedforward block + + Example: + block_config = "ca-fa-mlp" creates a sequence of: + 1. Cross-attention block + 2. Full self-attention block + 3. 
MLP block + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=n_views, + ) + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x diff --git a/cosmos_predict1/diffusion/module/parallel.py b/cosmos_predict1/diffusion/module/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0837a2c4fec1506f70c8ba2b7c9f602b84c862 --- /dev/null +++ b/cosmos_predict1/diffusion/module/parallel.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size +from torch.distributed.utils import _verify_param_shape_across_processes + +from cosmos_predict1.utils import distributed + + +def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Split input tensor along the sequence dimension for checkpoint parallelism. + + This function divides the input tensor into equal parts along the specified + sequence dimension, based on the number of ranks in the checkpoint parallelism group. + It then selects the part corresponding to the current rank. + + Args: + x: Input tensor to be split. + seq_dim: The dimension along which to split the input (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A slice of the input tensor corresponding to the current rank. + + Raises: + AssertionError: If the sequence dimension is not divisible by the number of ranks. 
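# Minimal single-process sketch of the reshaping performed by the implementation that
# follows; cp_size and cp_rank are illustrative stand-ins for the ProcessGroup, and the
# tensor sizes below are assumptions chosen for the example.
import torch

def _split_like_cp(x: torch.Tensor, seq_dim: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Same view / index_select / view sequence as split_inputs_cp, without torch.distributed.
    assert x.shape[seq_dim] % cp_size == 0
    x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[seq_dim + 1:])
    x = x.index_select(seq_dim, torch.tensor([cp_rank]))
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

x = torch.randn(2, 16, 8, 8, 64)                        # (B, T, H, W, D)
shards = [_split_like_cp(x, 1, cp_size=4, cp_rank=r) for r in range(4)]
assert shards[0].shape == (2, 4, 8, 8, 64)              # each rank keeps T // cp_size frames
assert torch.equal(torch.cat(shards, dim=1), x)         # cat along seq_dim undoes the split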
+ """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_group.rank()], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the checkpoint parallelism group. + + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + world_size = get_world_size(cp_group) + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) + + +def broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: + """ + Broadcast the item from the minimum rank in the specified group(s). + Since global rank = tp_rank + cp_rank * tp_size + ... + First broadcast in the tp_group and then in the cp_group will + ensure that the item is broadcasted across ranks in cp_group and tp_group. + + Parameters: + - item: The item to broadcast (can be a torch.Tensor, str, or None). + - to_tp: Whether to broadcast to the tensor model parallel group. + - to_cp: Whether to broadcast to the context parallel group. 
+ """ + if not parallel_state.is_initialized(): + return item + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 + to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 + + if to_tp: + min_tp_rank = min(get_process_group_ranks(tp_group)) + + if to_cp: + min_cp_rank = min(get_process_group_ranks(cp_group)) + + if isinstance(item, torch.Tensor): # assume the device is cuda + # log.info(f"{item.shape}", rank0_only=False) + if to_tp: + # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) + item = _robust_broadcast(item, min_tp_rank, tp_group) + if to_cp: + # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) + item = _robust_broadcast(item, min_cp_rank, cp_group) + elif item is not None: + broadcastable_list = [item] + if to_tp: + # log.info(f"{broadcastable_list}", rank0_only=False) + broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) + if to_cp: + broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) + + item = broadcastable_list[0] + return item + + +def _robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: + """ + Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. + + Args: + tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). + src (int): The source rank for the broadcast. Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor on all ranks. + """ + # First, broadcast the shape of the tensor + if distributed.get_rank() == src: + shape = torch.tensor(tensor.shape).cuda() + else: + shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() + if is_check_shape: + _verify_param_shape_across_processes(pg, [shape]) + torch.distributed.broadcast(shape, src, group=pg) + + # Resize the tensor on non-src ranks if necessary + if distributed.get_rank() != src: + tensor = tensor.new_empty(shape.tolist()).type_as(tensor) + + # Now broadcast the tensor data + torch.distributed.broadcast(tensor, src, group=pg) + + return tensor diff --git a/cosmos_predict1/diffusion/module/position_embedding.py b/cosmos_predict1/diffusion/module/position_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4f93ededcb9619eba86cc61139778e18c54a8f86 --- /dev/null +++ b/cosmos_predict1/diffusion/module/position_embedding.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.diffusion.module.attention import normalize +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.module.timm import trunc_normal_ + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class VideoPositionEmb(nn.Module): + def __init__(self): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, VideoRopePosition3DEmb): + seq_dim = 0 + else: + seq_dim = 1 + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input 
size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." + half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
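# Shape sketch for the "crop" interpolation implemented below (sizes are assumptions):
#
#   emb = LearnablePosEmbAxis(interpolation="crop", model_channels=1024,
#                             len_h=88, len_w=160, len_t=16)
#   out = emb.generate_embeddings(torch.Size([2, 8, 44, 80, 1024]))
#   # out.shape == (2, 8, 44, 80, 1024): the three axis tables are cropped to T, H, W,
#   # broadcast-summed, and normalized over the channel axis via the module's normalize().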
+ """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) + self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) + self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class MultiviewVideoPositionEmb(nn.Module): + def __init__( + self, + ): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + seq_dim = 1 + embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float() + # rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float() + else: + seq_dim = 1 + embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views) + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views) + else: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float() + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class MultiviewVideoRopePosition3DEmb(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + 
self.max_h = len_h + self.max_w = len_w + self.n_views = n_views + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embedding_for_batch( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert uniform_fps # only support uniform fps now + + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." 
+ half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return em_T_H_W_D + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. The camera view dimension is merged in the T dimension + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + + B, T, H, W, C = B_T_H_W_C + + single_view_B_T_H_W_C = (B, T // self.n_views, H, W, C) + em_T_H_W_D = torch.cat( + [ + self.generate_embedding_for_batch( + single_view_B_T_H_W_C, + fps=fps, + h_ntk_factor=h_ntk_factor, + w_ntk_factor=w_ntk_factor, + t_ntk_factor=t_ntk_factor, + ) + for item in range(self.n_views) + ], + dim=0, + ) + return em_T_H_W_D + + +class MultiviewSinCosPosEmbAxis(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
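# Shape sketch for the view-merged time axis (numbers are assumptions): with n_views=4 and
# 8 latent frames per view, the input arrives as B_T_H_W_C = (B, 32, H, W, C); the
# generate_embeddings below slices pos_emb_t with T // n_views = 8 and concatenates the
# same per-view embedding n_views times along the time axis.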
+ """ + del kwargs # unused + self.n_views = n_views + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + single_view_T = T // self.n_views + + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:single_view_T] + emb = torch.cat( + [ + torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=single_view_T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=single_view_T, h=H), + ], + dim=-1, + ) + for _ in range(self.n_views) + ], + 1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") diff --git a/cosmos_predict1/diffusion/module/pretrained_vae.py b/cosmos_predict1/diffusion/module/pretrained_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c799bfb1341b060afade60958a8162804001f50f --- /dev/null +++ b/cosmos_predict1/diffusion/module/pretrained_vae.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +from einops import rearrange +from torch.nn.modules import Module + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. 
+ """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.name = name + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "image_mean_std.pt"), weights_only=True) + + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, latent_ch, is_image, is_bf16) + + def load_encoder(self, vae_dir: str) -> None: + """ + Load the encoder from the remote store. 
+ """ + self.encoder = torch.jit.load(os.path.join(vae_dir, "encoder.jit")) + + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, vae_dir: str) -> None: + """ + Load the decoder from the remote store. + """ + self.decoder = torch.jit.load(os.path.join(vae_dir, "decoder.jit")) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. 
The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. + """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=True) + + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. 
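# Worked example of the chunking arithmetic with the defaults above
# (pixel_chunk_duration=17, temporal_compress_factor=8) and the properties defined below:
#
#   latent_chunk_duration     = (17 - 1) // 8 + 1 = 3
#   get_latent_num_frames(34) = 34 // 17 * 3      = 6    # two pixel chunks -> two latent chunks
#   get_pixel_num_frames(6)   = 6 // 3 * 17       = 34
#   # single-frame inputs (T == 1) bypass chunking and map 1 <-> 1 in both directions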
+ """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__( + pixel_chunk_duration, + temporal_compression_factor, + max_enc_batch_size, + max_dec_batch_size, + ) + super(BasePretrainedVideoTokenizer, self).__init__( + name, + latent_ch, + False, + is_bf16, + ) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if 
self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + + def load_weights(self, vae_dir: str): + # Load for video_vae + self.video_vae.register_mean_std(vae_dir) + self.video_vae.load_decoder(vae_dir) + self.video_vae.load_encoder(vae_dir) + + # Load for image_vae + self.image_vae.register_mean_std(vae_dir) + self.image_vae.load_decoder(vae_dir) + self.image_vae.load_encoder(vae_dir) diff --git a/cosmos_predict1/diffusion/module/timm.py b/cosmos_predict1/diffusion/module/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e5b1fdd15cc11f0aad45aaecbd5c78c5f27ce1 --- /dev/null +++ b/cosmos_predict1/diffusion/module/timm.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch +import torch.nn as nn + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.activation = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/cosmos_predict1/diffusion/modules/denoiser_scaling.py b/cosmos_predict1/diffusion/modules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fb3df0f38d52de317177c248e22707a899beb4 --- /dev/null +++ b/cosmos_predict1/diffusion/modules/denoiser_scaling.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise diff --git a/cosmos_predict1/diffusion/modules/res_sampler.py b/cosmos_predict1/diffusion/modules/res_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..59a4952f69eb1bb1a708ee561941577cca7e5d75 --- /dev/null +++ b/cosmos_predict1/diffusion/modules/res_sampler.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general framework for various sampling algorithm from a diffusion model. 
+Impl based on +* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157 +* also clude other impl, DDIM, DEIS, DPM-Solver, EDM sampler. +Most of sampling algorihtm, Runge-Kutta, Multi-step, etc, can be impl in this framework by \ + adding new step function in get_runge_kutta_fn or get_multi_step_fn. +""" + +import math +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import attrs +import torch + +from cosmos_predict1.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported +from cosmos_predict1.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported +from cosmos_predict1.utils.config import make_freezable + +COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] + + +@make_freezable +@attrs.define(slots=False) +class SolverConfig: + is_multi: bool = False + rk: str = "2mid" + multistep: str = "2ab" + # following parameters control stochasticity, see EDM paper + # BY default, we use deterministic with no stochasticity + s_churn: float = 0.0 + s_t_max: float = float("inf") + s_t_min: float = 0.05 + s_noise: float = 1.0 + + +@make_freezable +@attrs.define(slots=False) +class SolverTimestampConfig: + nfe: int = 50 + t_min: float = 0.002 + t_max: float = 80.0 + order: float = 7.0 + is_forward: bool = False # whether generate forward or backward timestamps + + +@make_freezable +@attrs.define(slots=False) +class SamplerConfig: + solver: SolverConfig = attrs.field(factory=SolverConfig) + timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig) + sample_clean: bool = True # whether run one last step to generate clean image + + +def get_rev_ts( + t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False +) -> torch.Tensor: + """ + Generate a sequence of reverse time steps. + + Args: + t_min (float): The minimum time value. + t_max (float): The maximum time value. + num_steps (int): The number of time steps to generate. + ts_order (Union[int, float]): The order of the time step progression. + is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False. + + Returns: + torch.Tensor: A tensor containing the generated time steps in reverse or forward order. + + Raises: + ValueError: If `t_min` is not less than `t_max`. + TypeError: If `ts_order` is not an integer or float. 
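# Worked example (values chosen for illustration): with t_min=0.002, t_max=80.0,
# num_steps=4 and ts_order=7, the schedule interpolates linearly in t ** (1/7) and returns
# a float64 tensor of 5 sigmas decreasing from 80.0 to 0.002 (Karras/EDM-style spacing);
# is_forward=True simply flips that order.
#
#   sigmas = get_rev_ts(0.002, 80.0, num_steps=4, ts_order=7)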
+ """ + if t_min >= t_max: + raise ValueError("t_min must be less than t_max") + + if not isinstance(ts_order, (int, float)): + raise TypeError("ts_order must be an integer or float") + + step_indices = torch.arange(num_steps + 1, dtype=torch.float64) + time_steps = ( + t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order)) + ) ** ts_order + + if is_forward: + return time_steps.flip(dims=(0,)) + + return time_steps + + +class Sampler(torch.nn.Module): + def __init__(self, cfg: Optional[SamplerConfig] = None): + super().__init__() + if cfg is None: + cfg = SamplerConfig() + self.cfg = cfg + + @torch.no_grad() + def forward( + self, + x0_fn: Callable, + x_sigma_max: torch.Tensor, + num_steps: int = 35, + sigma_min: float = 0.002, + sigma_max: float = 80, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, + solver_option: str = "2ab", + ) -> torch.Tensor: + in_dtype = x_sigma_max.dtype + + def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor: + return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64) + + is_multistep = is_multi_step_fn_supported(solver_option) + is_rk = is_runge_kutta_fn_supported(solver_option) + assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}" + + solver_cfg = SolverConfig( + s_churn=S_churn, + s_t_max=S_max, + s_t_min=S_min, + s_noise=S_noise, + is_multi=is_multistep, + rk=solver_option, + multistep=solver_option, + ) + timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) + sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) + + return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) + + @torch.no_grad() + def _forward_impl( + self, + denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noisy_input_B_StateShape: torch.Tensor, + sampler_cfg: Optional[SamplerConfig] = None, + callback_fns: Optional[List[Callable]] = None, + ) -> torch.Tensor: + """ + Internal implementation of the forward pass. + + Args: + denoiser_fn: Function to denoise the input. + noisy_input_B_StateShape: Input tensor with noise. + sampler_cfg: Configuration for the sampler. + callback_fns: List of callback functions to be called during sampling. + + Returns: + torch.Tensor: Denoised output tensor. + """ + sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg + solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0]) + num_timestamps = sampler_cfg.timestamps.nfe // solver_order + + sigmas_L = get_rev_ts( + sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order + ).to(noisy_input_B_StateShape.device) + + denoised_output = differential_equation_solver( + denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns + )(noisy_input_B_StateShape) + + if sampler_cfg.sample_clean: + # Override denoised_output with fully denoised version + ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) + denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) + + return denoised_output + + +def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: + """ + Implements a for loop with a function. + + Args: + lower: Lower bound of the loop (inclusive). + upper: Upper bound of the loop (exclusive). 
+ body_fun: Function to be applied in each iteration. + init_val: Initial value for the loop. + + Returns: + The final result after all iterations. + """ + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val + + +def differential_equation_solver( + x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + sigmas_L: torch.Tensor, + solver_cfg: SolverConfig, + callback_fns: Optional[List[Callable]] = None, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Creates a differential equation solver function. + + Args: + x0_fn: Function to compute x0 prediction. + sigmas_L: Tensor of sigma values with shape [L,]. + solver_cfg: Configuration for the solver. + callback_fns: Optional list of callback functions. + + Returns: + A function that solves the differential equation. + """ + num_step = len(sigmas_L) - 1 + + if solver_cfg.is_multi: + update_step_fn = get_multi_step_fn(solver_cfg.multistep) + else: + update_step_fn = get_runge_kutta_fn(solver_cfg.rk) + + eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) + + def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: + """ + Samples from the differential equation. + + Args: + input_xT_B_StateShape: Input tensor with shape [B, StateShape]. + + Returns: + Output tensor with shape [B, StateShape]. + """ + ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64) + + def step_fn( + i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + input_x_B_StateShape, x0_preds = state + sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] + + # algorithm 2: line 4-6 + if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max: + hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 + input_x_B_StateShape = input_x_B_StateShape + ( + hat_sigma_cur_0**2 - sigma_cur_0**2 + ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape) + sigma_cur_0 = hat_sigma_cur_0 + + if solver_cfg.is_multi: + x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds + ) + else: + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn + ) + + if callback_fns: + for callback_fn in callback_fns: + callback_fn(**locals()) + + return output_x_B_StateShape, x0_preds + + x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) + return x_at_eps + + return sample_fn diff --git a/cosmos_predict1/diffusion/networks/__init__.py b/cosmos_predict1/diffusion/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/networks/general_dit.py b/cosmos_predict1/diffusion/networks/general_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c5031e1c189557b7c4cdff7dcdebedf196771d29 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit.py @@ -0,0 +1,569 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +""" + +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.attention import get_normalization +from cosmos_predict1.diffusion.module.blocks import ( + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + TimestepEmbedding, + Timesteps, +) +from cosmos_predict1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb +from cosmos_predict1.utils import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block. See Notes for supported block types. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + use_cross_attn_mask (bool): Whether to use mask in cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + affline_emb_norm (bool): Whether to normalize affine embeddings. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ + Notes: + Supported block types in block_config: + * cross_attn, ca: Cross attention + * full_attn: Full attention on all flattened tokens + * mlp, ff: Feed forward block + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "learnable", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.affline_emb_norm = affline_emb_norm + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.cp_group = None + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + + for idx in range(num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + + self.build_decode_head() + if self.affline_emb_norm: + log.debug("Building affine embedding normalization layer") + self.affline_norm = get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer 
layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. 
+ + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
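+        # The view below restores the (B * T // patch_temporal, (H // patch_spatial) * (W // patch_spatial), -1)
+        # token layout expected by the un-patchifying rearrange that follows, which maps the tokens
+        # back to a (B, C, T, H, W) latent.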
+ x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = 
DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to + augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + for _, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + return x_B_D_T_H_W + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + self.pos_embedder.enable_context_parallel(cp_group) + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.enable_context_parallel(cp_group) + # Loop through the model to set up context parallel. 
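+        # MLP and cross-attention layers are skipped below; for attention layers whose attention op
+        # runs on the transformer_engine backend, the CP group, ranks, and a CUDA stream are registered.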
+ for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff", "cross_attn", "ca"]: + continue + elif layer.block.attn.backend == "transformer_engine": + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + layer.block.attn.attn_op.cp_stream = None + + log.debug("[CP] Disable context parallelism.") + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None diff --git a/cosmos_predict1/diffusion/networks/general_dit_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..68ddbdb473365e6f4658a36d463f1e57479556e1 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_multiview.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
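Before moving on to the multiview variant, here is a minimal, self-contained sketch of the un-patchify step used by GeneralDIT.decoder_head above, using the same einops pattern. All sizes (B, C, T, H, W, p_s, p_t) are chosen purely for illustration and are not taken from the patch.

import torch
from einops import rearrange

# Toy sizes, for illustration only: 2 videos, 16 latent channels, 8 frames of 32x32 latents,
# spatial patch size 2, temporal patch size 1.
B, C, T, H, W = 2, 16, 8, 32, 32
p_s, p_t = 2, 1

# Tokens in the layout decoder_head produces just before its final rearrange:
# shape (B * T/p_t, (H/p_s) * (W/p_s), p_s * p_s * p_t * C).
tokens = torch.randn(B * T // p_t, (H // p_s) * (W // p_s), p_s * p_s * p_t * C)

video = rearrange(
    tokens,
    "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
    p1=p_s, p2=p_s, t=p_t, H=H // p_s, W=W // p_s, B=B,
)
assert video.shape == (B, C, T, H, W)  # tokens un-patchified back into a video latent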
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.blocks import GeneralDITTransformerBlock, PatchEmbed +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.module.position_embedding import ( + MultiviewSinCosPosEmbAxis, + MultiviewVideoRopePosition3DEmb, +) +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewGeneralDIT(GeneralDIT): + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + n_views: int = 3, + view_condition_dim: int = 3, + traj_condition_dim: int = 0, + concat_view_embedding: bool = True, + concat_traj_embedding: bool = False, + add_repeat_frame_embedding: bool = False, + ): + self.n_views = n_views + self.view_condition_dim = view_condition_dim + self.concat_view_embedding = concat_view_embedding + self.traj_condition_dim = traj_condition_dim + self.concat_traj_embedding = concat_traj_embedding + self.add_repeat_frame_embedding = add_repeat_frame_embedding + super().__init__( + max_img_h, + max_img_w, + max_frames, + in_channels, + out_channels, + patch_spatial, + patch_temporal, + concat_padding_mask, + block_config, + model_channels, + num_blocks, + num_heads, + mlp_ratio, + block_x_format, + crossattn_emb_channels, + use_cross_attn_mask, + pos_emb_cls, + pos_emb_learnable, + pos_emb_interpolation, + affline_emb_norm, # whether or not to normalize the affine embedding + use_adaln_lora, + adaln_lora_dim, + rope_h_extrapolation_ratio, + rope_w_extrapolation_ratio, + rope_t_extrapolation_ratio, + extra_per_block_abs_pos_emb, + extra_per_block_abs_pos_emb_type, + extra_h_extrapolation_ratio, + extra_w_extrapolation_ratio, + extra_t_extrapolation_ratio, + ) + # reinit self.blocks + del self.blocks + self.blocks = nn.ModuleDict() + for idx in range(self.num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=self.n_views, + ) + self.view_embeddings = nn.Embedding(n_views, view_condition_dim) # Learnable embedding layer + if 
self.concat_traj_embedding: + self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer + if self.add_repeat_frame_embedding: + self.repeat_frame_embedding = nn.Linear(1, view_condition_dim) # Learnable embedding layer + + self.initialize_weights() + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + view_condition_dim, + traj_condition_dim, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + self.view_condition_dim, + self.traj_condition_dim, + ) + if self.concat_view_embedding: + in_channels = in_channels + view_condition_dim if view_condition_dim > 0 else in_channels + + if self.concat_traj_embedding: + in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = MultiviewVideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=30, + min_fps=1, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + n_views=self.n_views, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = MultiviewSinCosPosEmbAxis(**kwargs) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + trajectory = kwargs.get("trajectory", None) + frame_repeat = kwargs.get("frame_repeat", None) + + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + trajectory=trajectory, + frame_repeat=frame_repeat, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + trajectory: Optional[torch.Tensor] = None, + frame_repeat: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. 
+ - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + + view_indices = torch.arange(self.n_views).to(x_B_C_T_H_W.device) # View indices [0, 1, ..., V-1] + view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] + view_embedding = rearrange(view_embedding, "V D -> D V") + view_embedding = view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) # Shape: [1, D, V, 1, 1, 1] + + if self.add_repeat_frame_embedding: + if frame_repeat is None: + frame_repeat = ( + torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) + .to(view_embedding.device) + .to(view_embedding.dtype) + ) + frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) + frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") + view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) + + x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) + view_embedding = view_embedding.expand( + x_B_C_V_T_H_W.shape[0], + view_embedding.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + if self.concat_traj_embedding: + traj_emb = self.traj_embeddings(trajectory) + traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + traj_emb = traj_emb.expand( + x_B_C_V_T_H_W.shape[0], + traj_emb.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) + else: + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) + + x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb diff --git a/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..e19f20cc9cf95a1ba0e474d9d6fe1bffd413c18e --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange +from torch import nn + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.blocks import TimestepEmbedding, Timesteps +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + Timesteps(self.model_channels), + TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def initialize_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().initialize_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass of the video-conditioned DIT model. 
+ + Args: + x: Input tensor of shape (B, C, T, H, W) + timesteps: Timestep tensor of shape (B,) + crossattn_emb: Cross attention embeddings of shape (B, N, D) + crossattn_mask: Optional cross attention mask of shape (B, N) + fps: Optional frames per second tensor + image_size: Optional image size tensor + padding_mask: Optional padding mask tensor + scalar_feature: Optional scalar features tensor + data_type: Type of data being processed (default: DataType.VIDEO) + video_cond_bool: Optional video conditioning boolean tensor + condition_video_indicator: Optional video condition indicator tensor + condition_video_input_mask: Required mask tensor for video data type + condition_video_augment_sigma: Optional sigma values for conditional input augmentation + **kwargs: Additional keyword arguments + + Returns: + torch.Tensor: Output tensor + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" + + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..702d72812b80d2327700203d021b80601a51b944 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
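Before the multiview extension below, a quick sketch of why the video-conditioned variants reserve one extra input channel (e.g. in_channels=16 + 1 above): forward() concatenates the condition mask onto the latent along the channel dimension before delegating to the base network. The shapes and mask values here are invented for illustration; the exact mask semantics come from the training setup.

import torch

# Invented toy shapes: one video with 16 latent channels, 8 latent frames, 32x32 spatial latents.
B, C, T, H, W = 1, 16, 8, 32, 32
x = torch.randn(B, C, T, H, W)                          # noisy video latent
condition_video_input_mask = torch.zeros(B, 1, T, H, W)
condition_video_input_mask[:, :, :2] = 1.0              # hypothetical: mark the first two latent frames as conditioning

x_with_mask = torch.cat([x, condition_video_input_mask], dim=1)
assert x_with_mask.shape == (B, C + 1, T, H, W)         # matches the extra "+1" input channel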
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewVideoExtendGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/prompt_upsampler/inference.py b/cosmos_predict1/diffusion/prompt_upsampler/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..494c9b882cafdbd7a7969966e8a51a97a759d1a1 --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/inference.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, TypedDict + +import torch + +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer + + +class ChatPrediction(TypedDict, total=False): + tokens: List[str] # not required + logprobs: List[float] # not required + + +def chat_completion( + model: AutoRegressiveModel, + dialogs: List, + seed: int = None, + temperature: float = 0.01, + top_k: int = None, + top_p: float = None, + max_gen_len: Optional[int] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + generation_prefix: str = "", + compile_sampling: bool = False, + compile_prefill: bool = False, + stop_tokens=None, + verbose: bool = False, +) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + model (AutoRegressiveModel): The language generation model. + dialogs (List): List of conversational dialogs, where each dialog is a list of messages. + NOTE if you are using a VLM, all dialogs must either all have images ("image" field) or all be pure text. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.01. + top_k (int, optional): Top-k probability threshold for nucleus sampling. Defaults to None. If not None, top-p sampling is ignored. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. If not None, top-k sampling is ignored. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + num_gen_seq (int, optional): Number of sequences to generate per prompt. Defaults to 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + generation_prefix (str, optional): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "". + compile_sampling (bool, optional): Flag indicating whether to compile the generation function. Defaults to False. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + stop_tokens (Set[int], optional): Set of tokens to stop generation. Defaults to None. If not None, it will override the model's stop tokens. + verbose (bool, optional): Flag indicating whether to print the generation throughput. Defaults to False. + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. 
+ If logprobs is True, token log probabilities are computed for each generated token. + """ + if max_gen_len is None: + max_gen_len = model.model.params.max_seq_len - 1 + images = None + if isinstance(model.tokenizer.text_tokenizer, ImageTextTokenizer): + # Vision-language model + prompt_dicts = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + prompt_tokens = [prompt_dict["input_ids"] for prompt_dict in prompt_dicts] + num_images = sum(["pixel_values" in prompt_dict for prompt_dict in prompt_dicts]) + assert num_images in [0, len(dialogs)], "For VLM, all dialogs must either all have images or all be pure text." + if num_images > 0: + images = torch.cat([prompt_dict["pixel_values"] for prompt_dict in prompt_dicts], dim=0) + else: + images = None + elif isinstance(model.tokenizer.text_tokenizer, TextTokenizer): + # Text-only model + prompt_tokens = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + else: + prompt_tokens = [model.formatter.encode_dialog_prompt(dialog) for dialog in dialogs] + + generation_tokens, generation_logprobs = model.generate( + prompt_tokens=prompt_tokens, + seed=seed, + max_gen_len=max_gen_len, + num_gen_seq=num_gen_seq, + temperature=temperature, + top_k=top_k, + top_p=top_p, + compile_sampling=compile_sampling, + compile_prefill=compile_prefill, + stop_tokens=stop_tokens, + verbose=verbose, + images=images, + ) + + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + "tokens": [model.tokenizer.text_tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + } + for t in generation_tokens + ] diff --git a/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py b/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..16076d1afc06c5f217bfe9f65663cb829b534ca2 --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Cosmos-UpsamplePrompt1-12B-Text2World. 
+Command: + CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py + +""" +import argparse +import os +import re + +from cosmos_predict1.autoregressive.configs.base.model_config import create_text_model_config +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.diffusion.prompt_upsampler.inference import chat_completion +from cosmos_predict1.utils import log + + +def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel: + model_config, tokenizer_config = create_text_model_config( + model_ckpt_path=os.path.join(checkpoint_dir, "model.pt"), + tokenizer_path=os.path.join(checkpoint_dir), + model_family="mistral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + rope_dim="1D", + add_special_tokens=True, + max_seq_len=1024, + pytorch_rope_version="v1", + ) + log.debug(f"Text prompt upsampler model config: {model_config}") + + # Create and return a LLM instance + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def run_chat_completion(model: AutoRegressiveModel, input: str, temperature: float = 0.01): + """ + text2world prompt upsampler model is finetuned for chat. + During training, the context window for the initial prompt upsampler models is 512 tokens. For inference, we set max_seq_len to 1024 to accommodate longer inputs. + Setting `max_gen_len` is optional as the finetuned models can naturally determine when to stop generating. + """ + + dialogs = [[{"role": "user", "content": f"Upsample the short caption to a long caption: {str(input)}"}]] + + results = chat_completion( + model, + dialogs, + max_gen_len=512, + temperature=temperature, + top_p=None, + top_k=None, + logprobs=False, + ) + upsampled_prompt = str(clean_text(results[0]["generation"]["content"])) + return upsampled_prompt + + +def clean_text(text: str) -> str: + """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace.""" + # Replace all variations of newlines with a space + text = text.replace("\n", " ").replace("\r", " ") + + # Use a regex to find sections of the form '- **...**' + pattern = r"(- \*\*)(.*?)(\*\*)" + + def replacement(match: re.Match[str]) -> str: + content = match.group(2) # The text inside - ** and ** + words = re.findall(r"\w+", content) + if len(words) < 10: + # If fewer than 10 words, remove the entire '- **...**' portion + return "" + else: + # If 10 or more words, keep the entire section as it is + return match.group(0) + + text = re.sub(pattern, replacement, text) + + # Remove common prefixes + prefixes = ["Caption:", "#####", "####", "- ", "* ", ","] + for prefix in prefixes: + # lstrip(prefix) won't strip entire strings, but character sets. 
+ # For more reliable prefix removal, do: + if text.startswith(prefix): + text = text[len(prefix) :].lstrip() + + # Remove extra spaces + text = " ".join(text.split()) + + # Strip any remaining leading/trailing punctuation, whitespace, and quotes + text = text.strip(' -,*:"\'"“”') + + return text + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument("--input", type=str, default="A dog is playing with a ball.") + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-UpsamplePrompt1-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner(args.checkpoint_dir) + is_safe = guardrail_presets.run_text_guardrail(args.input, guardrail_runner) + if not is_safe: + log.critical("Input text prompt is not safe.") + return + + prompt_upsampler = create_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + upsampled_prompt = run_chat_completion(prompt_upsampler, args.input, temperature=args.temperature) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py b/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f40c6cbfbfc078fb891bd75dc0f04236c12ced --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Pixtral-12B. 
+Command: + CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py + +""" + +import argparse +import os +from math import ceil + +from PIL import Image + +from cosmos_predict1.autoregressive.configs.base.model_config import create_vision_language_model_config +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.diffusion.prompt_upsampler.inference import chat_completion +from cosmos_predict1.utils import log +from cosmos_predict1.utils.io import load_from_fileobj + + +def create_vlm_prompt_upsampler( + checkpoint_dir: str, tokenizer_ckpt_path: str = "mistral-community/pixtral-12b" +) -> AutoRegressiveModel: + """ + Load the fine-tuned pixtral model for SimReady. + If pixtral_ckpt is not provided, use the pretrained checkpoint. + """ + model_ckpt_path = os.path.join(checkpoint_dir, "model.pt") + model_config, tokenizer_config = create_vision_language_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_family="pixtral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + max_seq_len=4300, + pytorch_rope_version="v1", + ) + # during instantiate, the weights will be downloaded (if not already cached) and loaded + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def resize_image(image: Image.Image, max_size: int = 1024) -> Image.Image: + """ + Ensure that the image is no larger than max_size in both dimensions. + """ + image_width, image_height = image.size + max_width, max_height = max_size, max_size + ratio = max(image_width / max_width, image_height / max_height) + if ratio > 1: + image = image.resize((ceil(image_width / ratio), ceil(image_height / ratio))) + return image + + +def prepare_dialog(image_or_video_path: str) -> list[dict]: + if image_or_video_path.endswith(".mp4"): + video_np, _ = load_from_fileobj(image_or_video_path, format="mp4") + image_frame = video_np[-1] + image = Image.fromarray(image_frame) + else: + image: Image.Image = Image.open(image_or_video_path) + + image = resize_image(image, max_size=1024) + prompt = """\ +Your task is to transform a given prompt into a refined and concise video description, no more than 150 words. +Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video. 
+ """.strip() + + return [ + { + "role": "user", + "content": "[IMG]\n" + prompt, + "images": [image], + } + ] + + +def run_chat_completion(pixtral: AutoRegressiveModel, dialog: list[dict], **inference_args) -> str: + default_args = { + "max_gen_len": 400, + "temperature": 0, + "top_p": 0.9, + "logprobs": False, + "compile_sampling": False, + "compile_prefill": False, + } + default_args.update(inference_args) + results = chat_completion( + pixtral, + [dialog], + **default_args, + ) + assert len(results) == 1 + upsampled_prompt = str(results[0]["generation"]["content"]) + return upsampled_prompt + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument("--image_or_video_path", type=str, default="assets/diffusion/video2world_input0.jpg") + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument("--top_p", type=float, default=0.9, help="Top-p value for top-p sampling") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner(args.checkpoint_dir) + + pixtral = create_vlm_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + dialog = prepare_dialog(args.image_or_video_path) + upsampled_prompt = run_chat_completion( + pixtral, + dialog, + max_gen_len=400, + temperature=args.temperature, + top_p=args.top_p, + logprobs=False, + ) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/diffusion/training/callbacks/every_n.py b/cosmos_predict1/diffusion/training/callbacks/every_n.py new file mode 100644 index 0000000000000000000000000000000000000000..25cab309a58336867ed5fc58849e71db7611d0f3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/every_n.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.callback import Callback +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class EveryN(Callback): + def __init__( + self, + every_n: Optional[int] = None, + step_size: int = 1, + barrier_after_run: bool = True, + run_at_start: bool = False, + ) -> None: + """Constructor for `EveryN`. + + Args: + every_n (int): Frequency with which callback is run during training. + step_size (int): Size of iteration step count. Default 1. + barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. + run_at_start (bool): Whether to run at the beginning of training. Default False. + """ + self.every_n = every_n + if self.every_n == 0: + log.warning( + f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." + ) + + self.step_size = step_size + self.barrier_after_run = barrier_after_run + self.run_at_start = run_at_start + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training + if self.every_n != 0: + trainer = self.trainer + global_step = iteration // self.step_size + should_run = (iteration == 1 and self.run_at_start) or ( + global_step % self.every_n == 0 + ) # (self.every_n - 1) + if should_run: + log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") + self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) + log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") + # add necessary barrier to avoid timeout + if self.barrier_after_run: + distributed.barrier() + + @abstractmethod + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + ... diff --git a/cosmos_predict1/diffusion/training/callbacks/grad_clip.py b/cosmos_predict1/diffusion/training/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..8c060fc1a7514734cd3e429f5e3aba7c4c57f7e0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/grad_clip.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
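+
+# Illustrative note (not part of the original code): the _MagnitudeRecord helper below
+# accumulates per-step gradient-norm tensors and reports their running average; despite
+# the Tuple[float, float] annotation, get_stat() returns a single averaged value and
+# then resets the record. A minimal sketch with hypothetical norms:
+#
+#   rec = _MagnitudeRecord()
+#   rec.update(torch.tensor(1.0))  # total_norm from one optimizer step
+#   rec.update(torch.tensor(3.0))  # total_norm from the next step
+#   avg = rec.get_stat()           # (1.0 + 3.0) / 2 == 2.0; the record resets afterwards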
+ +from dataclasses import dataclass +from typing import Tuple + +import torch +from megatron.core import parallel_state +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callbacks.grad_clip import GradClip as GradClipImage +from cosmos_predict1.utils.callbacks.grad_clip import _fused_nan_to_num +from cosmos_predict1.utils.model import Model + + +@dataclass +class _MagnitudeRecord: + state: float = 0 + iter_count: int = 0 + + def reset(self) -> None: + self.state = 0 + self.iter_count = 0 + + def update(self, cur_state: torch.Tensor) -> None: + self.state += cur_state + self.iter_count += 1 + + def get_stat(self) -> Tuple[float, float]: + if self.iter_count > 0: + avg_state = self.state / self.iter_count + avg_state = avg_state.item() + else: + avg_state = 0 + self.reset() + return avg_state + + +class GradClip(GradClipImage): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.img_mag_log = _MagnitudeRecord() + self.video_mag_log = _MagnitudeRecord() + self._cur_state = None + + def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None: + if model.is_image_batch(data_batch): + self._cur_state = self.img_mag_log + else: + self._cur_state = self.video_mag_log + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + params = [] + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + if self.force_finite: + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + if isinstance(model, FSDP) and self.fsdp_enabled: + total_norm = model.clip_grad_norm_(self.clip_norm) + else: + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + total_norm = model_ddp.module.clip_grad_norm_(self.clip_norm) + else: + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) + + self._cur_state.update(total_norm) diff --git a/cosmos_predict1/diffusion/training/callbacks/iter_speed.py b/cosmos_predict1/diffusion/training/callbacks/iter_speed.py new file mode 100644 index 0000000000000000000000000000000000000000..371a227c6db390c0e3764ad6bb3e278bdd1ae866 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/iter_speed.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +from torch import Tensor + +from cosmos_predict1.diffusion.training.callbacks.every_n import EveryN +from cosmos_predict1.utils import log +from cosmos_predict1.utils.distributed import rank0_only +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class IterSpeed(EveryN): + """ + Args: + hit_thres (int): Number of iterations to wait before logging. + """ + + def __init__(self, *args, hit_thres: int = 5, **kwargs): + super().__init__(*args, **kwargs) + self.time = None + self.hit_counter = 0 + self.hit_thres = hit_thres + self.name = self.__class__.__name__ + self.last_hit_time = time.time() + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + if self.hit_counter < self.hit_thres: + log.info( + f"Iteration {iteration}: " + f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | " + f"Loss: {loss.item():.4f} | " + f"Time: {time.time() - self.last_hit_time:.2f}s" + ) + self.hit_counter += 1 + self.last_hit_time = time.time() + #! useful for large scale training and avoid oom crash in the first two iterations!!! + torch.cuda.synchronize() + return + super().on_training_step_end(model, data_batch, output_batch, loss, iteration) + + @rank0_only + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, Tensor], + output_batch: dict[str, Tensor], + loss: Tensor, + iteration: int, + ) -> None: + if self.time is None: + self.time = time.time() + return + cur_time = time.time() + iter_speed = (cur_time - self.time) / self.every_n / self.step_size + + log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}") + + self.time = cur_time diff --git a/cosmos_predict1/diffusion/training/callbacks/low_precision.py b/cosmos_predict1/diffusion/training/callbacks/low_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..faf9562c413f33f40d01b69e4bb01883283f0895 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/low_precision.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cosmos_predict1.diffusion.training.trainer import Trainer +from cosmos_predict1.utils.callback import LowPrecisionCallback as BaseCallback +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.model import Model + + +class LowPrecisionCallback(BaseCallback): + """ + Config with non-primitive type makes it difficult to override the option. + The callback gets precision from model.precision instead. 
+ """ + + def __init__(self, config: Config, trainer: Trainer, update_iter: int): + self.config = config + self.trainer = trainer + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision diff --git a/cosmos_predict1/diffusion/training/conditioner.py b/cosmos_predict1/diffusion/training/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..fee7bfb14ad9d0f9346aed41304401d6383c8d2f --- /dev/null +++ b/cosmos_predict1/diffusion/training/conditioner.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from cosmos_predict1.diffusion.conditioner import GeneralConditioner +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.utils.misc import count_params + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + MIX = "mix" + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + + self._is_trainable = None + self._dropout_rate = None + self._input_key = None + self._return_dict = False + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def dropout_rate(self) -> Union[float, torch.Tensor]: + return self._dropout_rate + + @property + def input_key(self) -> str: + return self._input_key + + @property + def is_return_dict(self) -> bool: + return self._return_dict + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @dropout_rate.setter + def dropout_rate(self, value: Union[float, torch.Tensor]): + self._dropout_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_return_dict.setter + def is_return_dict(self, value: bool): + self._return_dict = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @dropout_rate.deleter + def dropout_rate(self): + del self._dropout_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + @is_return_dict.deleter + def is_return_dict(self): + del self._return_dict + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + return batch_mul( + torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), + in_tensor, + ) + + def details(self) -> str: + return "" + + 
def summary(self) -> str: + input_key = self.input_key if self.input_key is not None else getattr(self, "input_keys", None) + return ( + f"{self.__class__.__name__} \n\tinput key: {input_key}" + f"\n\tParam count: {count_params(self, False)} \n\tTrainable: {self.is_trainable}" + f"\n\tDropout rate: {self.dropout_rate}" + f"\n\t{self.details()}" + ) + + +class TrajectoryAttr(AbstractEmbModel): + def __init__(self, traj_dim: int): + super().__init__() + self.traj_dim = traj_dim + + def forward(self, traj: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "trajectory": traj, + } + + def details(self) -> str: + return f"Traj dim : {self.traj_dim} \n\tOutput key: [trajectory]" + + +class FrameRepeatAttr(AbstractEmbModel): + def __init__(self): + super().__init__() + + def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "frame_repeat": frame_repeat / 10.0, + } + + def details(self) -> str: + return "Frame repeat, Output key: [frame_repeat]" + + +@dataclass +class BaseVideoCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + data_type: DataType = DataType.VIDEO + padding_mask: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + image_size: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + trajectory: Optional[torch.Tensor] = None + frame_repeat: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +@dataclass +class VideoExtendCondition(BaseVideoCondition): + video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video + gt_latent: Optional[torch.Tensor] = None + condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region + + # condition_video_input_mask will concat to the input of network, along channel dim; + # Will be concat with the input tensor + condition_video_input_mask: Optional[torch.Tensor] = None + # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + condition_video_augment_sigma: Optional[torch.Tensor] = None + # pose conditional input, will be concat with the input tensor + condition_video_pose: Optional[torch.Tensor] = None + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. 
+ latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + def get_condition_for_cp(self, cp_group): + self.latent_condition = split_inputs_cp(x=self.latent_condition, seq_dim=2, cp_group=cp_group) + self.latent_condition_sigma = split_inputs_cp(x=self.latent_condition_sigma, seq_dim=2, cp_group=cp_group) + + +class VideoConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoDiffusionDecoderConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoLatentDiffusionDecoderCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoLatentDiffusionDecoderCondition(**output) + + +class VideoExtendConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) + + +class VideoConditionerWithTraingOnlyEmb(GeneralConditioner): + def get_condition_uncondition( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Processes the provided data batch to generate two sets of outputs: conditioned and unconditioned. This method + manipulates the dropout rates of embedders to simulate two scenarios — one where all conditions are applied + (conditioned), and one where they are removed or reduced to the minimum (unconditioned). + + This method first sets the dropout rates to zero for the conditioned scenario to fully apply the embedders' effects. + For the unconditioned scenario, it sets the dropout rates to 1 (or to 0 if the initial unconditional dropout rate + is insignificant) to minimize the embedders' influences, simulating an unconditioned generation. + + Parameters: + data_batch (Dict): The input data batch that contains all necessary information for embedding processing. The + data is expected to match the required format and keys expected by the embedders. + + Returns: + Tuple[Any, Any]: A tuple containing two condition: + - The first one contains the outputs with all embedders fully applied (conditioned outputs). + - The second one contains the outputs with embedders minimized or not applied (unconditioned outputs). 
+ """ + cond_dropout_rates, dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + if isinstance(embedder, FrameRepeatAttr): + cond_dropout_rates[emb_name] = 1.0 + else: + cond_dropout_rates[emb_name] = 0.0 + dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) + return condition, un_condition + + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoExtendConditionerWithTraingOnlyEmb(VideoConditionerWithTraingOnlyEmb): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) + + +@dataclass +class BaseWithCtrlCondition(VideoExtendCondition): + control_input_canny: Optional[torch.Tensor] = None + control_input_blur: Optional[torch.Tensor] = None + control_input_canny_blur: Optional[torch.Tensor] = None + control_input_depth: Optional[torch.Tensor] = None + control_input_segmentation: Optional[torch.Tensor] = None + control_input_depth_segmentation: Optional[torch.Tensor] = None + control_input_mask: Optional[torch.Tensor] = None + control_input_human_kpts: Optional[torch.Tensor] = None + control_input_upscale: Optional[torch.Tensor] = None + control_input_identity: Optional[torch.Tensor] = None + control_input_multi: Optional[torch.Tensor] = None + base_model: Optional[torch.nn.Module] = None + hint_key: Optional[str] = None + control_weight: Optional[float] = 1.0 + num_layers_to_use: Optional[int] = -1 + + +class VideoConditionerWithCtrl(VideoExtendConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseWithCtrlCondition: + output = super()._forward(batch, override_dropout_rate) + output["hint_key"] = batch["hint_key"] + if "control_weight" in batch: + output["control_weight"] = batch["control_weight"] + if "num_layers_to_use" in batch: + output["num_layers_to_use"] = batch["num_layers_to_use"] + return BaseWithCtrlCondition(**output) + + +class BooleanFlag(AbstractEmbModel): + def __init__(self, output_key: Optional[str] = None): + super().__init__() + self.output_key = output_key + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + del args, kwargs + key = self.output_key if self.output_key else self.input_key + return {key: self.flag} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) + return in_tensor + + def details(self) -> str: + key = self.output_key if self.output_key else self.input_key + return f"Output key: {key} \n\t This is a boolean flag" diff --git a/cosmos_predict1/diffusion/training/config/base/ema.py b/cosmos_predict1/diffusion/training/config/base/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..30eea4fe92403590d6812403e73b46ff8d4bded4 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/ema.py @@ -0,0 +1,27 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.utils.ema import EMAModelTracker, PowerEMATracker +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.10, num=3 +) + +RegEMAConfig: LazyDict = L(EMAModelTracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.999, num=1 +) diff --git a/cosmos_predict1/diffusion/training/config/base/model.py b/cosmos_predict1/diffusion/training/config/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..47e0b76154805d61876acc486ccc5e13eb7c7fdb --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/model.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
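+
+# Illustrative note (not part of the original config code): DefaultModelConfig below
+# wires PowerEMAConfig from ema.py as its default `ema` entry. That entry is a lazy
+# config -- L(fn)(**kwargs) records the callable and its arguments without invoking it,
+# and PLACEHOLDER marks values (here `model`) that are bound later. Experiments override
+# individual fields in place rather than replacing the whole entry, e.g. the pattern
+# used by the 14B text2world experiment later in this change:
+#
+#   model=dict(
+#       ema=dict(
+#           enabled=True,
+#           num=1,
+#       ),
+#   ),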
+ +from typing import List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.ema import PowerEMAConfig +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class FSDPConfig: + policy: str = "block" + checkpoint: bool = False + min_num_params: int = 1024 + sharding_group_size: int = 8 + sharding_strategy: str = "full" + + +@attrs.define(slots=False) +class DefaultModelConfig: + vae: LazyDict = None + conditioner: LazyDict = None + net: LazyDict = None + ema: LazyDict = PowerEMAConfig + sde: LazyDict = L(EDMSDE)( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ) + sigma_data: float = 0.5 + camera_sample_weight: LazyDict = LazyDict( + dict( + enabled=False, + weight=5.0, + ) + ) + aesthetic_finetuning: LazyDict = LazyDict( + dict( + enabled=False, + ) + ) + loss_mask_enabled: bool = False + loss_masking: LazyDict = None + loss_add_logvar: bool = True + precision: str = "bfloat16" + input_data_key: str = "video" # key to fetch input data from data_batch + input_image_key: str = "images_1024" # key to fetch input image from data_batch + loss_reduce: str = "sum" + loss_scale: float = 1.0 + latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames + fsdp_enabled: bool = False + use_torch_compile: bool = False + fsdp: FSDPConfig = attrs.field(factory=FSDPConfig) + use_dummy_temporal_dim: bool = False # Whether to use dummy temporal dimension in data + adjust_video_noise: bool = False # whether or not adjust video noise accroding to the video length + peft_control: LazyDict | None = None + + +@attrs.define(slots=False) +class MultiviewModelConfig(DefaultModelConfig): + n_views: int = 6 diff --git a/cosmos_predict1/diffusion/training/config/base/optim.py b/cosmos_predict1/diffusion/training/config/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6ddb236d2eb97576e521ca4308166e628047ed --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/optim.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
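+
+# Illustrative note (not part of the original config code): FusedAdamWConfig and
+# LambdaLinearSchedulerConfig below are the lazy configs that the experiment defaults
+# entries {"optimizer": "fusedadamw"} and {"scheduler": "lambdalinear"} presumably
+# resolve to (the registry that stores them under those names is not shown here).
+# Experiments then override individual fields in place, e.g. the 7B text2world example
+# later in this change uses:
+#
+#   optimizer=dict(
+#       lr=2 ** (-14.3),  # approximately 5e-5
+#       weight_decay=0.1,
+#       betas=[0.9, 0.99],
+#       eps=1e-10,
+#   ),
+#   scheduler=dict(
+#       warm_up_steps=[2500],
+#       cycle_lengths=[10000000000000],
+#       f_start=[1.0e-6],
+#       f_max=[1.0],
+#       f_min=[1.0],
+#   ),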
+ +from cosmos_predict1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler +from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_optimizer +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FusedAdamWConfig: LazyDict = L(get_base_optimizer)( + model=PLACEHOLDER, + lr=1e-4, + weight_decay=0.3, + betas=[0.9, 0.999], + optim_type="fusedadam", + eps=1e-8, + sharding=False, + master_weights=True, + capturable=True, +) + +LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( + warm_up_steps=[1000], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], +) diff --git a/cosmos_predict1/diffusion/training/config/base/vae.py b/cosmos_predict1/diffusion/training/config/base/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..8d87a1f78e0522a54cb4dc7772235b36856d6ad9 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/vae.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from cosmos_predict1.diffusion.training.module.pretrained_vae import VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_tokenizer_comp8x8x8( + resolution: str, + chunk_duration: int, +) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["512", "720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(VideoJITTokenizer)( + name="cosmos_diffusion_tokenizer_comp8x8x8", + enc_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + dec_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + mean_std_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ) diff --git a/cosmos_predict1/diffusion/training/config/config.py b/cosmos_predict1/diffusion/training/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..38e82f6f0f6e3b52dc66f960dc0ed019e7b381c1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/config.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.model import DefaultModelConfig +from cosmos_predict1.diffusion.training.config.text2world.registry import ( + register_configs as register_configs_text2world, +) +from cosmos_predict1.diffusion.training.config.video2world.registry import ( + register_configs as register_configs_video2world, +) +from cosmos_predict1.diffusion.training.config.video2world_action.registry import ( + register_configs as register_configs_video2world_action, +) +from cosmos_predict1.diffusion.training.config.video2world_instruction.registry import ( + register_configs as register_configs_video2world_instruction, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.trainer import Trainer + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": None}, + {"data_val": None}, + {"optimizer": "fusedadamw"}, + {"scheduler": "lambdalinear"}, + {"callbacks": None}, + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"fsdp": None}, + {"ema": "power"}, + {"vae": "vae1"}, + {"checkpoint": "pbss"}, + {"ckpt_klass": "fsdp"}, + # the list is with order, we need global experiment to be the last one + {"experiment": None}, + ] + ) + model_obj: LazyDict = L(DiffusionModel)( + config=PLACEHOLDER, + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig(), + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_predict1" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + c.trainer.max_iter = 400_000 + c.trainer.logging_iter = 10 + c.trainer.validation_iter = 100 + c.trainer.run_validation = False + c.trainer.callbacks = None + + c.checkpoint = None + + # Call this function to register config groups. 
+ register_configs_text2world() + register_configs_video2world() + register_configs_video2world_instruction() + register_configs_video2world_action() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_instruction", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_action", reload=True) + + return c diff --git a/cosmos_predict1/diffusion/training/config/config_multiview.py b/cosmos_predict1/diffusion/training/config/config_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..fabf37c5643d0d4c1f1d1a8d90df7eacae511963 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/config_multiview.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.model import MultiviewModelConfig +from cosmos_predict1.diffusion.training.config.text2world.registry import ( + register_configs as register_configs_text2world, +) +from cosmos_predict1.diffusion.training.config.text2world_multiview.registry import ( + register_configs as register_configs_text2world_multiview, +) +from cosmos_predict1.diffusion.training.config.video2world.registry import ( + register_configs as register_configs_video2world, +) +from cosmos_predict1.diffusion.training.config.video2world_multiview.registry import ( + register_configs as register_configs_video2world_multiview, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.trainer import Trainer + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": None}, + {"data_val": None}, + {"optimizer": "fusedadamw"}, + {"scheduler": "lambdalinear"}, + {"callbacks": None}, + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"fsdp": None}, + {"ema": "power"}, + {"vae": "vae1"}, + {"checkpoint": "pbss"}, + {"ckpt_klass": "fsdp"}, + # the list is with order, we need global experiment to be the last one + {"experiment": None}, + ] + ) + model_obj: 
LazyDict = L(DiffusionModel)( + config=PLACEHOLDER, + ) + + +def make_config(): + c = Config( + model=MultiviewModelConfig(), + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_predict1" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + # c.trainer.straggler_detection.enabled = False + c.trainer.max_iter = 400_000 + c.trainer.logging_iter = 10 + c.trainer.validation_iter = 100 + c.trainer.run_validation = False + c.trainer.callbacks = None + + c.checkpoint = None + + # Call this function to register config groups. + register_configs_text2world() + register_configs_video2world() + register_configs_text2world_multiview() + register_configs_video2world_multiview() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world_multiview", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_multiview", reload=True) + + return c diff --git a/cosmos_predict1/diffusion/training/config/text2world/experiment.py b/cosmos_predict1/diffusion/training/config/text2world/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..5e67f0585f515472185ef545b272376da27374df --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world/experiment.py @@ -0,0 +1,1020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
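+
+# Illustrative note (not part of the original experiment code): the frame counts used
+# by the example datasets below are chosen to line up with the 8x temporal compression
+# of the Cosmos-Tokenize1-CV8x8x8 tokenizer, assuming the causal mapping of T pixel
+# frames to 1 + (T - 1) // 8 latent frames:
+#
+#   n_length = 15
+#   num_frames = 8 * n_length + 1          # 121 pixel frames, passed to
+#                                          # vae=dict(pixel_chunk_duration=num_frames)
+#   latent_t = 1 + (num_frames - 1) // 8   # 16, matching the "Latent temporal dim"
+#                                          # entry of latent_shape in the experiments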
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.model import FSDPDiffusionModel +from cosmos_predict1.diffusion.training.models.model_peft import PEFTVideoDiffusionModel +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +n_length = 15 +num_frames = 8 * n_length + 1 # 121 + +# HDVILA example +example_video_dataset_hdvila = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets example +example_video_dataset_cosmos_nemo_assets = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets 480x848 example for lora +example_video_dataset_cosmos_nemo_assets_480_848 = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(480, 848), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets examples with 
more affordable GPUs setup (4 GPUs or 40GB VRAM) +n_length_4gpu_80gb = 15 +num_frames_4gpu_80gb = 8 * n_length_4gpu_80gb + 1 # 121 +example_video_dataset_cosmos_nemo_assets_4gpu_80gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_80gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering the content aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_8gpu_40gb = 4 +num_frames_8gpu_40gb = 8 * n_length_8gpu_40gb + 1 # 33 +example_video_dataset_cosmos_nemo_assets_8gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_8gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_4gpu_40gb = 2 +num_frames_4gpu_40gb = 8 * n_length_4gpu_40gb + 1 # 17 +example_video_dataset_cosmos_nemo_assets_4gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. 
+ start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + + +text2world_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + grad_accum_iter=2, + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_hdvila, + dataloader_val=dataloader_val_hdvila, + ) +) + + +text2world_14b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_14b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-16), + weight_decay=0.2, + betas=[0.9, 0.99], + eps=1e-11, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-14B-Text2World/model.pt", + load_training_state=False, + 
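+            # (Editor's note, added) load_path points at the pre-trained 14B Text2World
+            # checkpoint; together with load_training_state=False and strict_resume=False
+            # below, this looks like a weights-only warm start for post-training (optimizer
+            # and scheduler state are not restored) -- an inference from the field names,
+            # not something verified against the checkpointer implementation.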
strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=8, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + num=1, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=64, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + use_memory_save=True, + ), + adjust_video_noise=True, + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[90_000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1e-1], + ), + dataloader_train=dataloader_train_hdvila, + dataloader_val=dataloader_val_hdvila, + ) +) + +text2world_7b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + 
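+        # (Editor's note, added) The latent_shape above mirrors the comp8x8x8 tokenizer:
+        # 121 input frames map to 16 latent frames, and 720x1280 / 384x384 inputs map to
+        # 88x160 / 48x48 latents elsewhere in this file, i.e. roughly 8x temporal and 8x
+        # spatial compression. This is read off the configs themselves, not taken from
+        # tokenizer documentation, so treat it as a working assumption.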
model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + dataloader_val=dataloader_val_cosmos_nemo_assets, + ) +) + +text2world_7b_example_cosmos_nemo_assets_4gpu_80gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_4gpu_80gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_80gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_80gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_4gpu_80gb, + ) +) + +text2world_7b_example_cosmos_nemo_assets_8gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_8gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + 
async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + pixel_chunk_duration=num_frames_8gpu_40gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_8gpu_40gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_8gpu_40gb, + ) +) + +text2world_7b_example_cosmos_nemo_assets_4gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_4gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + 
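+                # (Editor's note, added) pixel_chunk_duration is kept equal to the training
+                # clip length (num_frames_4gpu_40gb = 17 here), so each clip is encoded as a
+                # single temporal chunk -- inferred from how every config in this file pairs
+                # the two values, not from tokenizer documentation.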
pixel_chunk_duration=num_frames_4gpu_40gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_40gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_4gpu_40gb, + ) +) + + +text2world_14b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_14b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-16), + weight_decay=0.2, + betas=[0.9, 0.99], + eps=1e-11, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-14B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=16, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + num=1, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=64, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + use_memory_save=True, + ), + adjust_video_noise=True, + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[90_000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1e-1], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + dataloader_val=dataloader_val_cosmos_nemo_assets, + ) +) + +text2world_7b_lora_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "peft"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_lora_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=1e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=1000, + 
broadcast_via_filesystem=True, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, + ), + trainer=dict( + max_iter=5000, + distributed_parallelism="ddp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=False, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=4, + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1), + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=False, + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(PEFTVideoDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + scheduler=dict( + warm_up_steps=[0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_480_848, + dataloader_val=dataloader_val_cosmos_nemo_assets_480_848, + ) +) + + +def register_experiments(cs: ConfigStore) -> None: + # Register the experiments + for _item in [ + text2world_7b_example_hdvila, + text2world_14b_example_hdvila, + text2world_7b_example_cosmos_nemo_assets, + text2world_14b_example_cosmos_nemo_assets, + text2world_7b_example_cosmos_nemo_assets_4gpu_80gb, + text2world_7b_example_cosmos_nemo_assets_8gpu_40gb, + text2world_7b_example_cosmos_nemo_assets_4gpu_40gb, + text2world_7b_lora_example_cosmos_nemo_assets, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/text2world/registry.py b/cosmos_predict1/diffusion/training/config/text2world/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..4a292447ec51a1c9a710d9c8591803f41a59464c --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world/registry.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
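+
+# (Editor's note, added) This registry wires the reusable pieces -- net, conditioner,
+# vae, ema, optimizer, scheduler and checkpointing -- into Hydra's ConfigStore and then
+# registers the experiments defined in experiment.py. A minimal sketch of how it is
+# typically consumed, assuming a Hydra entry point that calls register_configs() before
+# composing the config (the exact entry point is an assumption, not shown in this diff):
+#
+#     from cosmos_predict1.diffusion.training.config.text2world.registry import register_configs
+#     register_configs()
+#     # then select an experiment node registered below, e.g.
+#     #   experiment=text2world_7b_example_cosmos_nemo_assets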
+ +import copy +from typing import Dict + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.checkpointer.peft_checkpointer import Checkpointer as PEFTCheckpointer +from cosmos_predict1.diffusion.checkpointers.ema_fsdp_checkpointer import CheckpointConfig, FSDPCheckpointer +from cosmos_predict1.diffusion.conditioner import VideoExtendConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, + VideoCondBoolConfig, +) +from cosmos_predict1.diffusion.training.conditioner import VideoConditioner +from cosmos_predict1.diffusion.training.config.base.optim import FusedAdamWConfig, LambdaLinearSchedulerConfig +from cosmos_predict1.diffusion.training.config.base.vae import get_cosmos_tokenizer_comp8x8x8 +from cosmos_predict1.diffusion.training.config.text2world.experiment import register_experiments +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.utils.ema import PowerEMATracker +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FSDP_CHECKPOINTER: Dict[str, str] = L(FSDPCheckpointer)() +PEFT_CHECKPOINTER: Dict[str, str] = L(PEFTCheckpointer)() +VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), +) + + +VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), +) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond", + node=VideoExtendConditionerConfig, + ) + + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingConfig, + ) + + +def register_checkpoint_credential(cs): + CHECKPOINT_LOCAL = CheckpointConfig( + save_iter=1000, + load_path="", + load_training_state=False, + strict_resume=True, + ) + + cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) + + +def register_checkpointer(cs): + cs.store(group="ckpt_klass", package="checkpoint.type", name="fsdp", node=FSDP_CHECKPOINTER) + cs.store(group="ckpt_klass", package="checkpoint.type", name="peft", node=PEFT_CHECKPOINTER) + + +FADITV2Config: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + spatial_attn_win_size=1, + temporal_attn_win_size=1, + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + additional_timestamp_channels=None, + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, + legacy_patch_emb=False, +) + +FADITV2_14B_Config = copy.deepcopy(FADITV2Config) +FADITV2_14B_Config.model_channels = 5120 +FADITV2_14B_Config.num_heads = 40 +FADITV2_14B_Config.num_blocks = 36 + + +def register_net(cs): + cs.store(group="net", package="model.net", name="faditv2_7b", node=FADITV2Config) + cs.store(group="net", package="model.net", 
name="faditv2_14b", node=FADITV2_14B_Config) + + +def register_vae(cs): + cs.store( + group="vae", + package="model.vae", + name="cosmos_diffusion_tokenizer_comp8x8x8", + node=get_cosmos_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + + +PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.10, num=3 +) + + +def register_ema(cs): + cs.store(group="ema", package="model.ema", name="power", node=PowerEMAConfig) + + +def register_optimizer(cs): + cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig) + + +def register_scheduler(cs): + cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) + + +def register_configs(): + cs = ConfigStore.instance() + + register_optimizer(cs) + register_scheduler(cs) + + register_net(cs) + register_conditioner(cs) + register_vae(cs) + + register_ema(cs) + + register_checkpoint_credential(cs) + register_checkpointer(cs) + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py b/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9b9a213f89879cc8cfe8e5854743e86bbeec33 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_multiview import Dataset +from cosmos_predict1.diffusion.training.models.model_multiview import FSDPDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 57 +num_views = 5 +view_keys = ["pinhole_front_left", "pinhole_front", "pinhole_front_right", "pinhole_side_left", "pinhole_side_right"] +example_multiview_dataset_waymo = L(Dataset)( + dataset_dir="datasets/waymo", + sequence_interval=1, + num_frames=num_frames, + view_keys=view_keys, + video_size=(480, 848), +) + + +text2world_multiview_7b_example_waymo = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_multiview_7b_example_waymo", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # broadcast_via_filesystem=True, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=200, + hit_thres=5, + ), + # manual_gc=L(ManualGarbageCollection)(every_n=5), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + n_views=num_views, + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(MultiviewGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + n_views=num_views, + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 
steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + dataloader_val=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + text2world_multiview_7b_example_waymo, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py b/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c12886435fede2a3ab1a9d6f5fafda1c8ef4f019 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.text2world_multiview.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world/experiment.py b/cosmos_predict1/diffusion/training/config/video2world/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..0817cf30119c6229cef2e8bfb0e3960117664986 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world/experiment.py @@ -0,0 +1,846 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
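+
+# (Editor's note, added) The Video2World experiments below extend the Text2World recipes
+# with frame conditioning: VideoExtendGeneralDIT plus the "video_cond" conditioner, where
+# condition_location="first_random_n" conditions on the leading frames, the conditioned
+# region is corrupted with a sampled augment sigma ("noise_with_sigma"), and
+# cfg_unconditional_type zeroes the condition region for classifier-free guidance. This is
+# a restatement of the settings used below, not an authoritative description of the method.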
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model_peft import PEFTExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +n_length = 15 +num_frames = 8 * n_length + 1 # 121 + +# HDVILA example +example_video_dataset_hdvila = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +# Cosmos-NeMo-Assets example +example_video_dataset_cosmos_nemo_assets = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +# Cosmos-NeMo-Assets examples with more affordable GPUs setup (4 GPUs or 40GB VRAM) +n_length_4gpu_80gb = 15 +num_frames_4gpu_80gb = 8 * n_length_4gpu_80gb + 1 # 121 +example_video_dataset_cosmos_nemo_assets_4gpu_80gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_80gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering the content aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_8gpu_40gb = 3 +num_frames_8gpu_40gb = 8 * n_length_8gpu_40gb + 1 # 25 +example_video_dataset_cosmos_nemo_assets_8gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_8gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. 
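+    # (Editor's note, added) num_frames is kept at 8 * n + 1 (25 here; the other datasets in
+    # this file use 121 or 25), presumably so the tokenizer's 8x temporal compression yields
+    # a whole number of latent frames plus one -- an inference from the pattern, not a
+    # documented constraint.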
+ start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_4gpu_40gb = 3 +num_frames_4gpu_40gb = 8 * n_length_4gpu_40gb + 1 # 25 +example_video_dataset_cosmos_nemo_assets_4gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_40gb, + video_size=(192, 192), # a low-res example for lower VRAM utilization without considering aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=0, + pin_memory=True, +) + +# Cosmos-NeMo-Assets 480x848 example for lora +example_video_dataset_cosmos_nemo_assets_480_848 = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(480, 848), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +dataloader_val_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +video2world_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + 
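+            # (Editor's note, added) rope_t_extrapolation_ratio=2 presumably stretches the
+            # temporal RoPE positions so the 121-frame post-training clips fit the base
+            # model's positional range; this reading is based only on the parameter name and
+            # is not confirmed by the network code in this diff.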
adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_hdvila, + ) +) + + +video2world_7b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + 
f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + ) +) + +video2world_7b_example_cosmos_nemo_assets_4gpu_80gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_4gpu_80gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_80gb, + spatial_resolution="384", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_80gb, + ) +) + +video2world_7b_example_cosmos_nemo_assets_8gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_8gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + 
load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_8gpu_40gb, + spatial_resolution="384", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_8gpu_40gb, + ) +) + +video2world_7b_example_cosmos_nemo_assets_4gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_4gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + 
model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 24, # Latent height dim + 24, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_40gb, + spatial_resolution="192", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_40gb, + ) +) + +video2world_7b_lora_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "peft"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_lora_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=1e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=1000, + broadcast_via_filesystem=True, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=5000, + distributed_parallelism="ddp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=False, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=4, + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1), + latent_shape=[ + 16, + 16, + 88, + 160, + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=False, + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + 
normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(PEFTExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + scheduler=dict( + warm_up_steps=[0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_480_848, + dataloader_val=dataloader_val_cosmos_nemo_assets_480_848, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_7b_example_hdvila, + video2world_7b_example_cosmos_nemo_assets, + video2world_7b_example_cosmos_nemo_assets_4gpu_80gb, + video2world_7b_example_cosmos_nemo_assets_8gpu_40gb, + video2world_7b_example_cosmos_nemo_assets_4gpu_40gb, + video2world_7b_lora_example_cosmos_nemo_assets, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world/registry.py b/cosmos_predict1/diffusion/training/config/video2world/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5b49f93e8b6bb932a7755e5bf8585104916e27 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..522fb093ca4a8dcde618aaddb0e8a821634c1a36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
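+
+# (Editor's note, added) The action-conditional experiment below fine-tunes on the Bridge
+# dataset with 2-frame clips (current frame plus next frame) at 256x320, an action vector
+# as extra conditioning (load_action=True), and the image VAE (pixel_chunk_duration=1),
+# giving a 16x2x32x40 latent. This summary is assembled from the config that follows.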
+ +import os + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_action import ActionConditionalVideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() +base_path = "datasets/bridge/" +train_annotation_path = os.path.join(base_path, "annotation/train") +val_annotation_path = os.path.join(base_path, "annotation/val") +test_annotation_path = os.path.join(base_path, "annotation/test") + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +bridge_train_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_action=True, + load_t5_embeddings=False, +) + +bridge_val_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="val", + load_action=True, + load_t5_embeddings=False, +) + + +dataloader_train = L(DataLoader)( + dataset=bridge_train_dataset, + sampler=L(get_sampler)(dataset=bridge_train_dataset), + batch_size=8, + drop_last=True, + pin_memory=True, + num_workers=8, +) +dataloader_val = L(DataLoader)( + dataset=bridge_val_dataset, + sampler=L(get_sampler)(dataset=bridge_val_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + + +video2world_action_bridge_2frames = LazyDict( # This experiment is used to verify the expanded config is the same as BASE002_101_512N_FSDP_LR-143_VideoImage_1-1 + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "action_conditional_video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world_action", + name="video2world_action_bridge_2frames", + ), + optimizer=dict( + lr=4e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=500, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2_000, + 
distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + # Use 16x2x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 2, # Latent temporal dim + 32, # Latent height dim + 40, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(ActionConditionalVideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + first_random_n_num_condition_t_max=1, + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + ) + ), + # Use Image VAE for training + vae=dict(pixel_chunk_duration=1), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + dataloader_val=dataloader_val, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_action_bridge_2frames, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_action/registry.py b/cosmos_predict1/diffusion/training/config/video2world_action/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c0ba551a3297c923ca69e6a0dcdf2c43f391b5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_action/registry.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
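A quick sanity check on how the 16x2x32x40 latent_shape above lines up with the data configuration, assuming the comp8x8x8 tokenizer compresses height and width by 8x and that pixel_chunk_duration=1 makes it encode each of the two frames independently (the image-VAE path noted in the config):

# Values taken from the Dataset_3D / model config above.
video_size = (256, 320)   # [H, W]
num_frames = 2
latent_c = 16             # tokenizer latent channels
latent_t = num_frames     # one latent per frame when pixel_chunk_duration=1 (assumed)
latent_h = video_size[0] // 8
latent_w = video_size[1] // 8
assert [latent_c, latent_t, latent_h, latent_w] == [16, 2, 32, 40]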
+ +from dataclasses import dataclass +from typing import Dict, Optional + +import attrs +import torch +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition, VideoExtendConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + NumFramesConfig, + PaddingMaskConfig, + ReMapkey, + TextConfig, + VideoCondBoolConfig, +) +from cosmos_predict1.diffusion.training.config.video2world_action.experiment import register_experiments +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@dataclass +class ActionConditionalVideoExtendCondition(VideoExtendCondition): + action: Optional[torch.Tensor] = None + + +class ActionConditionalVideoExtendConditioner(VideoExtendConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> ActionConditionalVideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + assert "action" in batch, "ActionConditionalVideoExtendConditioner requires 'action' in batch" + output["action"] = batch["action"] + return ActionConditionalVideoExtendCondition(**output) + + +@attrs.define(slots=False) +class ActionConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `action`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="action", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "action" + + +ActionConditionalVideoExtendConditionerConfig: LazyDict = L(ActionConditionalVideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), + action=ActionConfig(), +) + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) + + cs.store( + group="conditioner", + package="model.conditioner", + name="action_conditional_video_cond", + node=ActionConditionalVideoExtendConditionerConfig, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..f33bc5cd87bec43be6ea668814f6c60b5ab4b645 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
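The action conditioner above follows a three-part pattern: extend the condition dataclass with a new field, subclass the conditioner so its forward copies that field out of the batch, and register a conditioner config whose extra entry remaps the batch key. A minimal sketch of reusing the same pattern for another per-sample tensor; the "state" key and the class names below are hypothetical, purely for illustration:

from dataclasses import dataclass
from typing import Dict, Optional

import torch

from cosmos_predict1.diffusion.conditioner import VideoExtendCondition, VideoExtendConditioner


@dataclass
class StateConditionalVideoExtendCondition(VideoExtendCondition):
    state: Optional[torch.Tensor] = None  # hypothetical extra conditioning signal


class StateConditionalVideoExtendConditioner(VideoExtendConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> StateConditionalVideoExtendCondition:
        # Same flow as the action conditioner above: build the base conditions,
        # then attach the extra tensor from the batch.
        output = super()._forward(batch, override_dropout_rate)
        assert "state" in batch, "StateConditionalVideoExtendConditioner requires 'state' in batch"
        output["state"] = batch["state"]
        return StateConditionalVideoExtendCondition(**output)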
+ +import os + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() +base_path = "datasets/bridge/" +train_annotation_path = os.path.join(base_path, "annotation/train") +val_annotation_path = os.path.join(base_path, "annotation/val") +test_annotation_path = os.path.join(base_path, "annotation/test") + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +bridge_train_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=57, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_action=False, + load_t5_embeddings=True, +) + +bridge_val_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=57, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="val", + load_action=False, + load_t5_embeddings=True, +) + + +dataloader_train = L(DataLoader)( + dataset=bridge_train_dataset, + sampler=L(get_sampler)(dataset=bridge_train_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) +dataloader_val = L(DataLoader)( + dataset=bridge_val_dataset, + sampler=L(get_sampler)(dataset=bridge_val_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + + +video2world_instruction_bridge_57frames = LazyDict( # This experiment is used to verify the expanded config is the same as BASE002_101_512N_FSDP_LR-143_VideoImage_1-1 + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world_instruction", + name="video2world_instruction_bridge_57frames", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=500, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + 
max_iter=2_000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + # Use 16x8x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 8, # Latent temporal dim + 32, # Latent height dim + 40, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + # Use Image VAE for training + vae=dict(pixel_chunk_duration=57), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + first_random_n_num_condition_t_max=1, + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + ) + ), + ), + # using the video extend model for training + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + dataloader_val=dataloader_val, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_instruction_bridge_57frames, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py b/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bea45494ca7826258e9f9629dc480c52cfbbdec5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
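Two quick checks on the numbers in the experiment above: the learning rate 2**(-14.3) is indeed about 5e-5, and the 16x8x32x40 latent_shape follows from the 57-frame, 256x320 clips if the comp8x8x8 tokenizer compresses H/W by 8x and handles time causally (first frame kept, then 8x compression), which is an assumption about the tokenizer made here:

num_frames = 57
video_size = (256, 320)

lr = 2 ** (-14.3)
assert abs(lr - 5e-5) < 1e-6          # ~4.96e-05, matching the "approx 5e-5" comment

latent_t = (num_frames - 1) // 8 + 1  # 8, assuming causal temporal compression
latent_h = video_size[0] // 8         # 32
latent_w = video_size[1] // 8         # 40
assert [16, latent_t, latent_h, latent_w] == [16, 8, 32, 40]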
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world_instruction.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d6e2245873867f455e38654b2edab069b034bf --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_multiview import Dataset +from cosmos_predict1.diffusion.training.models.extend_model_multiview import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg_multiview import VideoExtendMultiviewGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 57 +num_views = 5 +view_keys = ["pinhole_front_left", "pinhole_front", "pinhole_front_right", "pinhole_side_left", "pinhole_side_right"] +example_multiview_dataset_waymo = L(Dataset)( + dataset_dir="datasets/waymo", + sequence_interval=1, + num_frames=num_frames, + view_keys=view_keys, + video_size=(480, 848), +) + + +video2world_multiview_7b_example_waymo = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_multiview_7b_example_waymo", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # broadcast_via_filesystem=True, + broadcast_via_filesystem=False, + 
load_path="checkpoints/Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=200, + hit_thres=5, + ), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + n_views=num_views, + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendMultiviewGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + n_views=num_views, + ), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + dataloader_val=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_multiview_7b_example_waymo, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py b/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..b30c4cdc00f18675329364a4f64679578dbf8c97 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world_multiview.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py b/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..92ad62cd053b2dba6daa0b02fe89b0b5c063e87a --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.interpolator import FSDPInterpolatorDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 18 +example_video_dataset = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train = L(DataLoader)( + dataset=example_video_dataset, + sampler=L(get_sampler)(dataset=example_video_dataset), + batch_size=1, + drop_last=True, +) +dataloader_val = L(DataLoader)( + dataset=example_video_dataset, + sampler=L(get_sampler)(dataset=example_video_dataset), + batch_size=1, + drop_last=True, +) + + +world_interpolator_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": 
"faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_world_interpolator", + name="world_interpolator_7b_example_hdvila", + ), + optimizer=dict( + # lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + lr=0.0, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # save_iter=1, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-WorldInterpolator/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + # max_iter=2, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, + 4, + 88, + 160, + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + # checkpoint=False, + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + context_parallel_size=1, + num_latents_to_drop=1, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_and_last_1", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ), + text=dict( + dropout_rate=0.5, + ), + ), + vae=dict(pixel_chunk_duration=9), # 9 frames per chunk for video vae (18 frames / 2 chunks = 9) + ), + model_obj=L(FSDPInterpolatorDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + world_interpolator_7b_example_hdvila, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py b/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..601f2e4da4bf89586d26f28e6cb1bd1103ca662b --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.world_interpolator.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/context_parallel.py b/cosmos_predict1/diffusion/training/context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..4758db33f1d63546d01be917fa19a821ed4bc1f7 --- /dev/null +++ b/cosmos_predict1/diffusion/training/context_parallel.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, get_process_group_ranks, get_world_size + + +def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Split input tensor along the sequence dimension for checkpoint parallelism. + + This function divides the input tensor into equal parts along the specified + sequence dimension, based on the number of ranks in the checkpoint parallelism group. + It then selects the part corresponding to the current rank. + + Args: + x: Input tensor to be split. + seq_dim: The dimension along which to split the input (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A slice of the input tensor corresponding to the current rank. + + Raises: + AssertionError: If the sequence dimension is not divisible by the number of ranks. + """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_group.rank()], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the checkpoint parallelism group. 
+ + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + world_size = get_world_size(cp_group) + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) + + +def cat_outputs_cp_with_grad(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the context parallelism group. + + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + It retains computational graph locally for each rank by replacing the portion of the tensor with original output. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + cp_size = cp_group.size() + assert cp_size > 0, "cp_size should be greater than 0" + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + rank = cp_group.rank() + gathered_tensors[rank] = x + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) diff --git a/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py b/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4dcd91920a9ab1364af8a2a482495cb4bec0f9 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
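The helpers in context_parallel.py above split a tensor evenly along its sequence dimension (one contiguous slice per rank of cp_group) and reassemble the full sequence with an all-gather; despite the phrase "checkpoint parallelism" in some of the docstrings, cp_group here is the context-parallel process group. A single-process sketch of the same arithmetic, with torch.chunk standing in for the per-rank slices and no torch.distributed setup required:

import torch

# (batch, seq, hidden) toy tensor; seq must be divisible by cp_size,
# mirroring the assert in split_inputs_cp.
x = torch.arange(2 * 8 * 4, dtype=torch.float32).reshape(2, 8, 4)
cp_size, seq_dim = 4, 1

# Each "rank" would hold one contiguous slice of the sequence dimension,
# which is what split_inputs_cp computes via view + index_select.
per_rank = list(x.chunk(cp_size, dim=seq_dim))
assert all(p.shape[seq_dim] == x.shape[seq_dim] // cp_size for p in per_rank)

# cat_outputs_cp is the inverse: concatenating the gathered slices
# along seq_dim reconstructs the original tensor.
assert torch.equal(torch.cat(per_rank, dim=seq_dim), x)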
+ +import dataclasses + + +@dataclasses.dataclass +class ItemDatasetConfig: + path: str + length: int diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_3D.py b/cosmos_predict1/diffusion/training/datasets/dataset_3D.py new file mode 100644 index 0000000000000000000000000000000000000000..e561400a34ba4e504ee717a1edabdf387108bf9a --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_3D.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import json +import os +import pickle +import random +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import imageio +import numpy as np +import torch +from decord import VideoReader, cpu +from einops import rearrange +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import ( + Resize_Preprocess, + ToTensorVideo, + euler2rotm, + rotm2euler, +) + + +class Dataset_3D(Dataset): + def __init__( + self, + train_annotation_path, + val_annotation_path, + test_annotation_path, + video_path, + sequence_interval, + num_frames, + cam_ids, + accumulate_action, + video_size, + val_start_frame_interval, + debug=False, + normalize=False, + pre_encode=False, + do_evaluate=False, + load_t5_embeddings=False, + load_action=True, + mode="train", + ): + """Dataset class for loading 3D robot action-conditional data. + + This dataset loads robot trajectories consisting of RGB video frames, robot states (arm positions and gripper states), + and computes relative actions between consecutive frames. + + Args: + train_annotation_path (str): Path to training annotation files + val_annotation_path (str): Path to validation annotation files + test_annotation_path (str): Path to test annotation files + video_path (str): Base path to video files + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + cam_ids (list): List of camera IDs to sample from + accumulate_action (bool): Whether to accumulate actions relative to first frame + video_size (list): Target size [H,W] for video frames + val_start_frame_interval (int): Frame sampling interval for validation/test + debug (bool, optional): If True, only loads subset of data. Defaults to False. + normalize (bool, optional): Whether to normalize video frames. Defaults to False. + pre_encode (bool, optional): Whether to pre-encode video frames. Defaults to False. + do_evaluate (bool, optional): Whether in evaluation mode. Defaults to False. 
+ load_t5_embeddings (bool, optional): Whether to load T5 embeddings. Defaults to False. + load_action (bool, optional): Whether to load actions. Defaults to True. + mode (str, optional): Dataset mode - 'train', 'val' or 'test'. Defaults to 'train'. + + The dataset loads robot trajectories and computes: + - RGB video frames from specified camera views + - Robot arm states (xyz position + euler angles) + - Gripper states (binary open/closed) + - Relative actions between consecutive frames + + Actions are computed as relative transforms between frames: + - Translation: xyz offset in previous frame's coordinate frame + - Rotation: euler angles of relative rotation + - Gripper: binary gripper state + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - action: Action tensor [T-1,7] + - video_name: Dict with episode/frame metadata + - latent: Pre-encoded video features if pre_encode=True + """ + + super().__init__() + if mode == "train": + self.data_path = train_annotation_path + self.start_frame_interval = 1 + elif mode == "val": + self.data_path = val_annotation_path + self.start_frame_interval = val_start_frame_interval + elif mode == "test": + self.data_path = test_annotation_path + self.start_frame_interval = val_start_frame_interval + self.video_path = video_path + self.sequence_interval = sequence_interval + self.mode = mode + self.sequence_length = num_frames + self.normalize = normalize + self.pre_encode = pre_encode + self.load_t5_embeddings = load_t5_embeddings + self.load_action = load_action + + self.cam_ids = cam_ids + self.accumulate_action = accumulate_action + + self.action_dim = 7 # ee xyz (3) + ee euler (3) + gripper(1) + self.c_act_scaler = [20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0] + self.c_act_scaler = np.array(self.c_act_scaler, dtype=float) + self.ann_files = self._init_anns(self.data_path) + + self.samples = self._init_sequences(self.ann_files) + + self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0])) + if debug and not do_evaluate: + self.samples = self.samples[0:10] + self.wrong_number = 0 + self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]) + self.training = False + self.preprocess = T.Compose( + [ + ToTensorVideo(), + Resize_Preprocess(tuple(video_size)), # 288 512 + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + def __str__(self): + return f"{len(self.ann_files)} samples from {self.data_path}" + + def _init_anns(self, data_dir): + ann_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")] + return ann_files + + def _init_sequences(self, ann_files): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_ann_file = { + executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files + } + for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)): + samples.extend(future.result()) + return samples + + def _load_and_process_ann_file(self, ann_file): + samples = [] + with open(ann_file, "r") as f: + ann = json.load(f) + + n_frames = len(ann["state"]) + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["ann_file"] = ann_file + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == 
self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + return frame_data + + def _get_frames(self, label, frame_ids, cam_id, pre_encode): + if pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video_path = label["videos"][cam_id]["video_path"] + video_path = os.path.join(self.video_path, video_path) + frames = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + + def printvideo(videos, filename): + t_videos = rearrange(videos, "f c h w -> f h w c") + t_videos = ( + ((t_videos / 2.0 + 0.5).clamp(0, 1) * 255).detach().to(dtype=torch.uint8).cpu().contiguous().numpy() + ) + print(t_videos.shape) + writer = imageio.get_writer(filename, fps=4) # fps 是帧率 + for frame in t_videos: + writer.append_data(frame) # 1 4 13 23 # fp16 24 76 456 688 + + if self.normalize: + frames = self.preprocess(frames) + else: + frames = self.not_norm_preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames + + def _get_obs(self, label, frame_ids, cam_id, pre_encode): + if cam_id is None: + temp_cam_id = random.choice(self.cam_ids) + else: + temp_cam_id = cam_id + frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode) + return frames, temp_cam_id + + def _get_robot_states(self, label, frame_ids): + all_states = np.array(label["state"]) + all_cont_gripper_states = np.array(label["continuous_gripper_state"]) + states = all_states[frame_ids] + cont_gripper_states = all_cont_gripper_states[frame_ids] + arm_states = states[:, :6] + assert arm_states.shape[0] == self.sequence_length + assert cont_gripper_states.shape[0] == self.sequence_length + return arm_states, cont_gripper_states + + def _get_all_robot_states(self, label, frame_ids): + all_states = np.array(label["state"]) + all_cont_gripper_states = np.array(label["continuous_gripper_state"]) + states = all_states[frame_ids] + cont_gripper_states = all_cont_gripper_states[frame_ids] + arm_states = states[:, :6] + return arm_states, cont_gripper_states + + def _get_all_actions(self, arm_states, gripper_states, accumulate_action): + action_num = arm_states.shape[0] - 1 + action = np.zeros((action_num, self.action_dim)) + if accumulate_action: + first_xyz = arm_states[0, 0:3] + first_rpy = arm_states[0, 3:6] + first_rotm = euler2rotm(first_rpy) + for k in range(1, action_num + 1): + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) + rel_rotm = first_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + else: + for k in range(1, action_num + 1): + prev_xyz = arm_states[k - 1, 0:3] + prev_rpy = arm_states[k - 1, 3:6] + prev_rotm = euler2rotm(prev_rpy) + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + 
curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) + rel_rotm = prev_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + return torch.from_numpy(action) # (l - 1, act_dim) + + def _get_actions(self, arm_states, gripper_states, accumulate_action): + action = np.zeros((self.sequence_length - 1, self.action_dim)) + if accumulate_action: + first_xyz = arm_states[0, 0:3] + first_rpy = arm_states[0, 3:6] + first_rotm = euler2rotm(first_rpy) + for k in range(1, self.sequence_length): + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) + rel_rotm = first_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + else: + for k in range(1, self.sequence_length): + prev_xyz = arm_states[k - 1, 0:3] + prev_rpy = arm_states[k - 1, 3:6] + prev_rotm = euler2rotm(prev_rpy) + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) + rel_rotm = prev_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + return torch.from_numpy(action) # (l - 1, act_dim) + + def __getitem__(self, index, cam_id=None, return_video=False): + if self.mode != "train": + np.random.seed(index) + random.seed(index) + + try: + sample = self.samples[index] + ann_file = sample["ann_file"] + frame_ids = sample["frame_ids"] + with open(ann_file, "r") as f: + label = json.load(f) + arm_states, gripper_states = self._get_robot_states(label, frame_ids) + actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) + actions *= self.c_act_scaler + + data = dict() + if self.load_action: + data["action"] = actions.float() + + if self.pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video.to(dtype=torch.uint8) + + data["annotation_file"] = ann_file + + # NOTE: __key__ is used to uniquely identify the sample, required for callback functions + if "episode_id" in label: + data["__key__"] = label["episode_id"] + else: + data["__key__"] = label["original_path"] + + # Just add these to fit the interface + if self.load_t5_embeddings: + t5_embedding_path = ann_file.replace(".json", ".pickle") + with open(t5_embedding_path, "rb") as f: + data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) + else: + data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) + data["fps"] = 4 + data["image_size"] = 256 * torch.ones(4) # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 256, 256) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." 
+ ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset_3D( + train_annotation_path="datasets/bridge/annotation/train", + val_annotation_path="datasets/bridge/annotation/val", + test_annotation_path="datasets/bridge/annotation/test", + video_path="datasets/bridge/", + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_t5_embeddings=True, + ) + + indices = [0, 13, 200, -1] + for idx in indices: + print( + ( + f"{idx=} " + f"{dataset[idx]['video'].sum()=}\n" + f"{dataset[idx]['video'].shape=}\n" + f"{dataset[idx]['video_name']=}\n" + f"{dataset[idx]['action'].sum()=}\n" + "---" + ) + ) + + from IPython import embed + + embed() diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py b/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py new file mode 100644 index 0000000000000000000000000000000000000000..c760c92c7f6531335c6194cace332c5664ab94c0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import json +import pickle +import random +import traceback +import warnings + +import numpy as np +import torch + +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.utils import log + + +class Dataset_3DBinary(Dataset_3D): + def __init__( + self, + train_annotation_path, + val_annotation_path, + test_annotation_path, + video_path, + sequence_interval, + num_frames, + cam_ids, + accumulate_action, + video_size, + val_start_frame_interval, + debug=False, + normalize=False, + pre_encode=False, + do_evaluate=False, + load_t5_embeddings=False, + load_action=True, + mode="train", + ): + """Dataset class for loading 3D robot action-conditional data. + + This dataset loads robot trajectories consisting of RGB video frames, robot states + (arm positions and binary gripper states), and computes relative actions between + consecutive frames. 
+ """ + + super().__init__( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=video_path, + sequence_interval=sequence_interval, + num_frames=num_frames, + cam_ids=cam_ids, + accumulate_action=accumulate_action, + video_size=video_size, + val_start_frame_interval=val_start_frame_interval, + debug=debug, + normalize=normalize, + pre_encode=pre_encode, + do_evaluate=do_evaluate, + load_t5_embeddings=load_t5_embeddings, + load_action=load_action, + mode=mode, + ) + + log.info("Dataset_3DBinary: in this dataset, we binarize the gripper state to 0 or 1.") + + def _get_json_action(self, label, frame_ids): + all_action = np.array(label["action"]) + actions = all_action[frame_ids[:-1]] + return torch.from_numpy(actions) + + def __getitem__(self, index, cam_id=None, return_video=False): + if self.mode != "train": + np.random.seed(index) + random.seed(index) + + try: + sample = self.samples[index] + ann_file = sample["ann_file"] + frame_ids = sample["frame_ids"] + with open(ann_file, "r") as f: + label = json.load(f) + arm_states, gripper_states = self._get_robot_states(label, frame_ids) + actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) + actions *= self.c_act_scaler + + data = dict() + if self.load_action: + data["action"] = actions.float() + json_action = self._get_json_action(label, frame_ids).float() + json_action[:, :6] = data["action"][:, :6] + data["action"] = json_action + + if self.pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video.to(dtype=torch.uint8) + + data["annotation_file"] = ann_file + + if "episode_id" in label: + data["__key__"] = label["episode_id"] + else: + data["__key__"] = label["original_path"] + + # Just add these to fit the interface + if self.load_t5_embeddings: + t5_embedding_path = ann_file.replace(".json", ".pickle") + with open(t5_embedding_path, "rb") as f: + data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) + else: + data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) + data["fps"] = 4 + data["image_size"] = 256 * torch.ones(4) # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 256, 256) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." 
+ ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset_3DBinary( + train_annotation_path="datasets/bridge/annotation/train", + val_annotation_path="datasets/bridge/annotation/val", + test_annotation_path="datasets/bridge/annotation/test", + video_path="datasets/bridge/", + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_t5_embeddings=True, + ) + + indices = [0, 13, 200, -1] + for idx in indices: + print( + ( + f"{idx=} " + f"{dataset[idx]['video'].sum()=}\n" + f"{dataset[idx]['video'].shape=}\n" + f"{dataset[idx]['video_name']=}\n" + f"{dataset[idx]['action'].sum()=}\n" + f"{dataset[idx]['json_action'].sum()=}\n" + "---" + ) + ) + + from IPython import embed + + embed() diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py b/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..6f86e8a8e5c0518afa193b13ff8a9c01b37ed5f3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/training/datasets/dataset_multiview.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import os +import pickle +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + view_keys, + video_size, + start_frame_interval=1, + ): + """Dataset class for loading image-text-to-video generation data. 
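# A minimal sketch of batching the Dataset_3DBinary instance constructed in the
# __main__ block above with a standard PyTorch DataLoader, assuming the bridge
# annotation files referenced there exist locally. The shapes in the comments
# follow the keys built in __getitem__ and depend on the constructor arguments.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["video"].shape)               # e.g. [2, 3, num_frames, 256, 320]
print(batch["action"].shape)              # e.g. [2, num_frames - 1, action_dim]
print(batch["t5_text_embeddings"].shape)  # [2, 512, 1024]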
+ + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + self.view_keys = view_keys + + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [ + os.path.join(video_dir, view_keys[0], f) for f in os.listdir(os.path.join(video_dir, view_keys[0])) + ] + print(f"{len(self.video_paths)} videos in total") + + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + cache_dir = os.path.join(self.dataset_dir, "cache") + self.prefix_t5_embeddings = {} + for view_key in view_keys: + with open(os.path.join(cache_dir, f"prefix_t5_embeddings_{view_key}.pickle"), "rb") as f: + self.prefix_t5_embeddings[view_key] = pickle.load(f)[0] + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + self.t5_dir, + os.path.basename(os.path.dirname(video_path)), + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): 
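# A minimal sketch of the frame-id windows built in _load_and_process_video_path
# above: starting from every start_frame_interval-th frame, frames are taken
# every sequence_interval steps, and a window is kept only if it reaches
# sequence_length frames. The numbers below are illustrative.
def frame_windows(n_frames, start_frame_interval, sequence_interval, sequence_length):
    windows = []
    for frame_i in range(0, n_frames, start_frame_interval):
        ids = list(range(frame_i, n_frames, sequence_interval))[:sequence_length]
        if len(ids) == sequence_length:
            windows.append(ids)
    return windows

print(frame_windows(n_frames=10, start_frame_interval=1,
                    sequence_interval=2, sequence_length=3))
# -> [[0, 2, 4], [1, 3, 5], ..., [5, 7, 9]]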
+ try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + t5_embedding_path = sample["t5_embedding_path"] + + data = dict() + + videos = [] + t5_embeddings = [] + for view_key in self.view_keys: + video, fps = self._get_frames( + os.path.join(os.path.dirname(os.path.dirname(video_path)), view_key, os.path.basename(video_path)), + frame_ids, + ) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + videos.append(video) + + with open( + os.path.join( + os.path.dirname(os.path.dirname(t5_embedding_path)), + view_key, + os.path.basename(t5_embedding_path), + ), + "rb", + ) as f: + t5_embedding = pickle.load(f)[0] + t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0) + t5_embedding = torch.from_numpy(t5_embedding) + if t5_embedding.shape[0] < 512: + t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0) + t5_embeddings.append(t5_embedding) + video = torch.cat(videos, dim=1) + t5_embedding = torch.cat(t5_embeddings, dim=0) + + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": t5_embedding_path, + "start_frame_id": str(frame_ids[0]), + } + data["t5_text_embeddings"] = t5_embedding + data["t5_text_mask"] = torch.ones(512 * len(self.view_keys), dtype=torch.int64) + data["fps"] = fps + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + dataset_dir="datasets/waymo/", + sequence_interval=1, + num_frames=57, + view_keys=[ + "pinhole_front_left", + "pinhole_front", + "pinhole_front_right", + "pinhole_side_left", + "pinhole_side_right", + ], + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_utils.py b/cosmos_predict1/diffusion/training/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..963e4c9de3d2e7958dbaa0526b284650671965b3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_utils.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
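# A minimal sketch of how the per-view T5 embeddings above are zero-padded to a
# fixed 512 tokens and then concatenated across views. The token counts and the
# number of views below are made up for illustration.
import torch

def pad_to_512(emb):                      # emb: [n_tokens, 1024]
    n = emb.shape[0]
    if n < 512:
        emb = torch.cat([emb, torch.zeros(512 - n, 1024)], dim=0)
    return emb

per_view = [torch.randn(37, 1024), torch.randn(52, 1024), torch.randn(21, 1024)]
stacked = torch.cat([pad_to_512(e) for e in per_view], dim=0)
print(stacked.shape)                      # torch.Size([1536, 1024]) = 512 * 3 views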
+ +""" +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_util.py +""" + +import base64 +import math +import os +from io import BytesIO + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from PIL import Image + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def b64_2_img(data: str): + image_b64 = base64.b64decode(data) + img = Image.open(BytesIO(image_b64)).convert("RGB") + return img + + +def get_continuous_action(d_acts, c_act_max, c_act_min, n_bins): + c_act_max = c_act_max.to(d_acts.device) + c_act_min = c_act_min.to(d_acts.device) + c_acts = d_acts / (n_bins - 1) * (c_act_max - c_act_min) + c_act_min + return c_acts + + +def alpha2rotm(a): + """Alpha euler angle to rotation matrix.""" + rotm = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]]) + return rotm + + +def beta2rotm(b): + """Beta euler angle to rotation matrix.""" + rotm = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]]) + return rotm + + +def gamma2rotm(c): + """Gamma euler angle to rotation matrix.""" + rotm = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]]) + return rotm + + +def euler2rotm(euler_angles): + """Euler angle (ZYX) to rotation matrix.""" + alpha = euler_angles[0] + beta = euler_angles[1] + gamma = euler_angles[2] + + rotm_a = alpha2rotm(alpha) + rotm_b = beta2rotm(beta) + rotm_c = gamma2rotm(gamma) + + rotm = rotm_c @ rotm_b @ rotm_a + + return rotm + + +def isRotm(R): + # Checks if a matrix is a valid rotation matrix. 
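# A quick shape check for the sin/cos positional-embedding helpers defined
# above: for a grid_size x grid_size grid, get_2d_sincos_pos_embed returns one
# embed_dim-dimensional row per grid cell, plus an optional cls-token row.
# The sizes below are arbitrary.
pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
print(pos_embed.shape)                        # (64, 64), i.e. (8*8, embed_dim)
pos_embed_cls = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8, cls_token=True)
print(pos_embed_cls.shape)                    # (65, 64)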
+ # Forked from Andy Zeng + Rt = np.transpose(R) + shouldBeIdentity = np.dot(Rt, R) + I = np.identity(3, dtype=R.dtype) + n = np.linalg.norm(I - shouldBeIdentity) + return n < 1e-6 + + +def rotm2euler(R): + # Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/ + # R = Rz * Ry * Rx + assert isRotm(R) + sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) + singular = sy < 1e-6 + + if not singular: + x = math.atan2(R[2, 1], R[2, 2]) + y = math.atan2(-R[2, 0], sy) + z = math.atan2(R[1, 0], R[0, 0]) + else: + x = math.atan2(-R[1, 2], R[1, 1]) + y = math.atan2(-R[2, 0], sy) + z = 0 + + # (-pi , pi] + while x > np.pi: + x -= 2 * np.pi + while x <= -np.pi: + x += 2 * np.pi + while y > np.pi: + y -= 2 * np.pi + while y <= -np.pi: + y += 2 * np.pi + while z > np.pi: + z -= 2 * np.pi + while z <= -np.pi: + z += 2 * np.pi + return np.array([x, y, z]) + + +def get_converted_fp32_paths(deepspeed_ckpt_path): + deepspeed_ckpt_path = deepspeed_ckpt_path.rstrip("/") + ckpt_dir = os.path.dirname(deepspeed_ckpt_path) + ckpt_name = os.path.basename(deepspeed_ckpt_path) + fp32_ckpt_name = f"{ckpt_name}.fp32.pt" + converted_path = os.path.join(ckpt_dir, fp32_ckpt_name) + return converted_path + + +def quat2rotm(quat): + """Quaternion to rotation matrix. + + Args: + quat (4, numpy array): quaternion x, y, z, w + Returns: + rotm (3x3 numpy array): rotation matrix + """ + w = quat[3] + x = quat[0] + y = quat[1] + z = quat[2] + + s = w * w + x * x + y * y + z * z + + rotm = np.array( + [ + [1 - 2 * (y * y + z * z) / s, 2 * (x * y - z * w) / s, 2 * (x * z + y * w) / s], + [2 * (x * y + z * w) / s, 1 - 2 * (x * x + z * z) / s, 2 * (y * z - x * w) / s], + [2 * (x * z - y * w) / s, 2 * (y * z + x * w) / s, 1 - 2 * (x * x + y * y) / s], + ] + ) + + return rotm + + +class Resize_Preprocess: + def __init__(self, size): + """ + Initialize the preprocessing class with the target size. + Args: + size (tuple): The target height and width as a tuple (height, width). + """ + self.size = size + + def __call__(self, video_frames): + """ + Apply the transformation to each frame in the video. + Args: + video_frames (torch.Tensor): A tensor representing a batch of video frames. + Returns: + torch.Tensor: The transformed video frames. 
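# A quick consistency check for the two helpers above: euler2rotm composes
# R = Rz @ Ry @ Rx and rotm2euler inverts that composition, so a round trip
# should recover the original angles away from the gimbal-lock case
# |beta| ~ pi/2. The angles below are arbitrary test values.
import numpy as np

euler = np.array([0.3, -0.7, 1.2])             # [alpha (x), beta (y), gamma (z)]
R = euler2rotm(euler)
print(np.allclose(rotm2euler(R), euler))       # True
print(np.allclose(R @ R.T, np.eye(3)))         # True: R is a valid rotation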
+ """ + # Resize each frame in the video + resized_frames = torch.stack([F.resize(frame, self.size, antialias=True) for frame in video_frames]) + return resized_frames + + +class Preprocess: + def __init__(self, size): + self.size = size + + def __call__(self, clip): + clip = Preprocess.resize_scale(clip, self.size[0], self.size[1], interpolation_mode="bilinear") + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + @staticmethod + def resize_scale(clip, target_height, target_width, interpolation_mode): + target_ratio = target_height / target_width + H = clip.size(-2) + W = clip.size(-1) + clip_ratio = H / W + if clip_ratio > target_ratio: + scale_ = target_width / W + else: + scale_ = target_height / H + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_video.py b/cosmos_predict1/diffusion/training/datasets/dataset_video.py new file mode 100644 index 0000000000000000000000000000000000000000..d728129d6c9bb6d2e73dbba246b1ce5558cf6509 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_video.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. 
python cosmos_predict1/diffusion/training/datasets/dataset_gear.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import os +import pickle +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + video_size, + start_frame_interval=1, + ): + """Dataset class for loading image-text-to-video generation data. + + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] + # print(f"{len(self.video_paths)} trajectories in total") + print(f"{len(self.video_paths)} videos in total") + + # self.t5_dir = os.path.join(self.dataset_dir, "labels") + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + # self.t5_dir, os.path.basename(video_path).replace(".mp4", ".npy") + self.t5_dir, + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert 
(np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video, fps = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": sample["t5_embedding_path"], + "start_frame_id": str(frame_ids[0]), + } + + # Just add these to fit the interface + # t5_embedding = np.load(sample["t5_embedding_path"])[0] + with open(sample["t5_embedding_path"], "rb") as f: + t5_embedding = pickle.load(f)[0] # [n_tokens, 1024] + n_tokens = t5_embedding.shape[0] + if n_tokens < 512: + t5_embedding = np.concatenate( + [t5_embedding, np.zeros((512 - n_tokens, 1024), dtype=np.float32)], axis=0 + ) + t5_text_mask = torch.zeros(512, dtype=torch.int64) + t5_text_mask[:n_tokens] = 1 + + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) + data["t5_text_mask"] = t5_text_mask + data["fps"] = fps + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=57, + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/diffusion/training/functional/loss.py b/cosmos_predict1/diffusion/training/functional/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..27d138371e86ae2ede2044e2f175306bbc63f59a --- /dev/null +++ b/cosmos_predict1/diffusion/training/functional/loss.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul + + +def create_per_sample_loss_mask( + loss_masking_cfg: dict, + data_batch: dict, + x_shape: Tuple[int], + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", +): + """ + Creates a per-sample loss mask based on the given configuration and input data batch. + + This function generates a dictionary of loss masks for each specified key in the loss masking configuration. + For keys present in both the configuration and the data batch, the corresponding data batch value is used. + For keys present only in the configuration, a tensor of zeros with the specified shape is created. + Additionally, it computes loss mask weights for each key based on the configuration values and adjusts them + based on the presence of certain keys in the data batch, such as "skip_face" and "object_loss_map". + + Note: + - The original `loss_masking_cfg` and `data_batch` are not modified by this function. + - For image data, it is assumed that the channel is always the first dimension. + - `skip_face` is for face regions that should be skipped during training, the key is provided so that we can generate + diverse human and avoid collapse to a single face given certain prompts. The issue happens for getty projects, + where face distribution in the dataset is high unbalanced that single man face can be shown in more than 100+ images. + + Parameters: + loss_masking_cfg (dict): Configuration for loss masking, specifying which keys to include and their weights. + data_batch (dict): The batch of data containing actual data points and potential mask indicators like "skip_face". + x_shape (tuple): The shape of the input data, used to initialize zero masks for keys not in the data batch. + dtype (torch.dtype): The data type for the tensors in the loss masks. + device (str, optional): The device on which to create the tensors. Defaults to 'cuda'. + + Returns: + dict: A dictionary containing combined loss masks adjusted according to the `loss_masking_cfg` and `data_batch`. + + Raises: + AssertionError: If "skip_face" is not present in `data_batch`. + + Note: `create_combined_loss_mask` is assumed to be a separate function that combines individual loss masks into a + single mask or set of masks based on the given parameters. Its behavior should be documented separately. + """ + loss_mask_data: dict = {} + for key in loss_masking_cfg: + if key not in data_batch: + loss_mask_data[key] = torch.zeros((x_shape[0], 1, x_shape[2], x_shape[3]), device=device) + else: + loss_mask_data[key] = data_batch[key] + + if "skip_face" not in data_batch: + # When skip_face is not there in data_dict, use 0 as default. This will not skip any sample. 
+ data_batch["skip_face"] = torch.zeros((x_shape[0],), dtype=dtype, device=device) + + loss_mask_weight: dict = {} + for k, v in loss_masking_cfg.items(): + loss_mask_weight[k] = torch.tensor(v, device=device).expand(data_batch["skip_face"].size()) + + if "human_face_mask" in loss_mask_weight: + loss_mask_weight["human_face_mask"] = (1 - data_batch["skip_face"]) * loss_mask_weight["human_face_mask"] + + if "object_loss_map" in data_batch: + loss_mask_weight["object_loss_map"] = torch.ones(data_batch["object_loss_map"].shape[0], device=device) + + return create_combined_loss_mask(loss_mask_data, x_shape, dtype, device, loss_mask_weight) + + +def create_combined_loss_mask(data, x_shape, dtype, device="cuda", loss_masking=None): + """ + Creates a combined loss mask from multiple input masks. + + This function combines several loss masks into a single mask. In regions where masks overlap, + the highest value is assigned. Non-overlapping regions are assigned a default value of 1. + Regions with a mask value of zero are explicitly zeroed out, which is essential for padded loss calculations. + + Example: + Given the following masks and weights: + mask1: [0, 1, 1, 1, 0, 0], weight: 2 + mask2: [1, 0, 1, 0, 0, 0], weight: 4 + mask3: [0, 1, 0, 0, 0, 0], weight: 0 + The resulting combined loss mask would be: + [4, 0, 4, 2, 1, 1] + + Parameters: + data (dict): Contains the loss masks and their weights. + x_shape (tuple): The shape of the output mask. + dtype: The data type for the output mask. + device: The device on which the output mask will be allocated. + loss_masking: The loss masking weight configuration. + + Returns: + torch.Tensor: The combined loss mask. + """ + + loss_mask = torch.ones(x_shape, dtype=dtype, device=device) + zero_mask = torch.ones(x_shape, dtype=dtype, device=device) + + if loss_masking: + for key in loss_masking: + # Repeat mask along channel's dimension. ndim=4 for images. + repeat_dims = (1, x_shape[1]) + tuple([1] * (data[key].ndim - 2)) + mask_key = torch.tile(data[key], dims=repeat_dims) + weight_key = loss_masking[key] + + # handle zero weight case + is_zero_weight = (weight_key == 0).float()[:, None, None, None] + zero_mask = zero_mask * ( + (1 - is_zero_weight) * torch.ones(x_shape, dtype=dtype, device=device) + + is_zero_weight * (1 - mask_key.bool().float()) + ) + + # calculate weights + no_mask_region = (mask_key.bool() == 0).float() + loss_mask = batch_mul(mask_key, weight_key) + batch_mul(no_mask_region, loss_mask) + + loss_mask_final = loss_mask * zero_mask + return loss_mask_final diff --git a/cosmos_predict1/diffusion/training/functional/lr_scheduler.py b/cosmos_predict1/diffusion/training/functional/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..579d6debaceeefb13d2304f7389090ca9b496a2d --- /dev/null +++ b/cosmos_predict1/diffusion/training/functional/lr_scheduler.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
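# A 1-D toy version of the combination rule implemented above, applied to the
# docstring example: masks are applied in order, masked positions take that
# mask's weight, untouched positions keep the default of 1, and positions
# covered by a zero-weight mask are zeroed out at the end. The real function
# operates on 4-D tensors with per-sample weights; this is only an illustration.
import torch

masks_and_weights = [
    (torch.tensor([0., 1., 1., 1., 0., 0.]), 2.0),
    (torch.tensor([1., 0., 1., 0., 0., 0.]), 4.0),
    (torch.tensor([0., 1., 0., 0., 0., 0.]), 0.0),
]
loss_mask = torch.ones(6)
zero_mask = torch.ones(6)
for mask, weight in masks_and_weights:
    if weight == 0:
        zero_mask = zero_mask * (1 - mask)        # zero out this mask's region
    else:
        loss_mask = mask * weight + (1 - mask) * loss_mask
print(loss_mask * zero_mask)                       # tensor([4., 0., 4., 2., 1., 1.])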
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np + +from cosmos_predict1.utils import distributed, log + + +class TeroPolyScheduler: + def __init__( + self, + total_Mimg: int, + batch_size: int, + ref_Mimg: Optional[int] = None, + ref_batches: float = 70e3 / 1024, + max_lr_ratio: Optional[float] = 1.0, + min_lr_ratio: Optional[float] = None, + rampup_Mimg: float = 0, + rampdown_Mimg: int = 0, + verbosity_interval: int = 0, + formula: str = "poly", + poly_exp: float = 0.5, + ): + self.total_Mimg = total_Mimg + self.batch_size = batch_size * distributed.get_world_size() + self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6 + self.ref_batches = ref_batches + self.max_lr_ratio = max_lr_ratio + self.min_lr_ratio = min_lr_ratio + self.rampup_Mimg = rampup_Mimg + self.rampdown_Mimg = rampdown_Mimg + self.verbosity_interval = verbosity_interval + self.formula = formula + self.poly_exp = poly_exp + + self._model = None + + @property + def model(self): + return self._model + + @model.setter + def model(self, model): + self._model = model + + def schedule(self, n, **kwargs): + cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6 + + if self.formula == "constant": + lr = 1.0 + elif self.formula == "poly": + lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp + else: + raise ValueError(f'Invalid learning rate formula "{self.formula}"') + + if self.max_lr_ratio is not None: + lr = min(lr, self.max_lr_ratio) + if self.min_lr_ratio is not None: + lr = max(lr, self.min_lr_ratio) + + if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg: + lr *= cur_Mimg / self.rampup_Mimg + if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg: + lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg + + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler: + """ + A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles. + It supports different configurations for each cycle, including the number of warm-up steps, minimum + and maximum scaling factors for the learning rate. + + The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning + rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler. + + Parameters: + warm_up_steps (list[int]): List of integers where each element represents the number of warm-up + steps for the corresponding cycle. + f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up. + f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle. + f_start (list[float]): List of starting scaling factors for each warm-up phase. + cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps. + verbosity_interval (int, optional): Interval of training steps at which to print current step and + scaling factor information. Set to 0 by default to disable verbosity. 
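# A minimal sketch of the multiplier TeroPolyScheduler.schedule returns for the
# "poly" formula (before scaling by the optimizer's base learning rate): it
# decays as (cur_Mimg / ref_Mimg) ** -poly_exp, is clamped to max_lr_ratio, and
# is linearly ramped up over the first rampup_Mimg images. The reference of
# 0.07 Mimg (70k images) and the ramp-up length are illustrative values only.
def poly_multiplier(cur_Mimg, ref_Mimg=0.07, poly_exp=0.5,
                    max_lr_ratio=1.0, rampup_Mimg=0.5):
    lr = max(cur_Mimg / ref_Mimg, 1e-8) ** -poly_exp
    lr = min(lr, max_lr_ratio)
    if rampup_Mimg > 0 and cur_Mimg < rampup_Mimg:
        lr *= cur_Mimg / rampup_Mimg
    return lr

print(poly_multiplier(7.0))     # 0.1     -> decays as 1/sqrt(images seen)
print(poly_multiplier(70.0))    # ~0.0316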
+ + Examples: + >>> scheduler = LambdaWarmUpCosineScheduler2( + warm_up_steps=[10, 10], + f_min=[0.1, 0.1], + f_max=[1.0, 1.0], + f_start=[0.01, 0.01], + cycle_lengths=[50, 50], + verbosity_interval=10) + >>> for step in range(100): + >>> lr_multiplier = scheduler(step) + >>> print(f"Step {step}: LR Multiplier = {lr_multiplier}") + """ + + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler): + """ + Linear instead of cosine decay for the main part of the cycle. + """ + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + self.last_f = f + return f diff --git a/cosmos_predict1/diffusion/training/models/extend_model.py b/cosmos_predict1/diffusion/training/models/extend_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb8ed2b37908fff47d1ec6dfddcce8a44756177 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/extend_model.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
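# A small numerical comparison of the two post-warm-up decays defined in the
# schedulers above, for a single cycle with f_max=1.0 and f_min=0.1: the cosine
# variant follows f_min + 0.5*(f_max - f_min)*(1 + cos(pi*t)), while the linear
# variant interpolates straight between f_max and f_min. Here t is the fraction
# of the cycle elapsed after warm-up; the values are illustrative.
import numpy as np

f_min, f_max = 0.1, 1.0
for t in [0.0, 0.25, 0.5, 0.75, 1.0]:
    cosine = f_min + 0.5 * (f_max - f_min) * (1 + np.cos(t * np.pi))
    linear = f_min + (f_max - f_min) * (1 - t)
    print(f"t={t:.2f}  cosine={cosine:.3f}  linear={linear:.3f}")
# At t=0.50 both give 0.550; the cosine decays slower early and faster late.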
+ +from dataclasses import dataclass +from statistics import NormalDist +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.config.base.conditioner import VideoCondBoolConfig +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as BaseModel +from cosmos_predict1.diffusion.training.models.model import _broadcast, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log, misc + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty + net_in: Optional[torch.Tensor] = None # input to the network + net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network + xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +def normalize_condition_latent(condition_latent): + """Normalize the condition latent tensor to have zero mean and unit variance + Args: + condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W + """ + condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") + mean = condition_latent_2D.mean(dim=-1) + std = condition_latent_2D.std(dim=-1) + # bct -> bct11 + mean = mean.unsqueeze(-1).unsqueeze(-1) + std = std.unsqueeze(-1).unsqueeze(-1) + condition_latent = (condition_latent - mean) / std + return condition_latent + + +class ExtendDiffusionModel(BaseModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None + ) -> Tuple[Tensor, VideoExtendCondition]: + raw_state, latent_state, condition = super().get_data_and_condition(data_batch) + if condition.data_type == DataType.VIDEO: + if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + latent_state = self.sample_tokens_start_from_p_or_i(latent_state) + condition = self.add_condition_video_indicator_and_video_input_mask( + latent_state, condition, num_condition_t=num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + log.debug(f"condition.data_type {condition.data_type}") + return raw_state, latent_state, condition + + def draw_augment_sigma_and_epsilon( + self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float + ) -> Tensor: + is_video_batch = condition.data_type == DataType.VIDEO + del condition + batch_size = size[0] + epsilon = torch.randn(size, **self.tensor_kwargs) + + gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = 
torch.tensor(samples_interval_gaussian, device="cuda") + sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) + + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed_inference: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """This function is used to augment the condition input with noise + Args: + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config + gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + sigma (Tensor): noise level for the generation region + Returns: + VideoExtendCondition: updated condition object + condition_video_augment_sigma: sigma for the condition region, feed to the network + augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W + + """ + + if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": + # Training only, sample sigma for the condition region + augment_sigma, _ = self.draw_augment_sigma_and_epsilon( + gt_latent.shape, + condition, + cfg_video_cond_bool.augment_sigma_sample_p_mean, + cfg_video_cond_bool.augment_sigma_sample_p_std, + cfg_video_cond_bool.augment_sigma_sample_multiplier, + ) + noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) + + elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": + # Inference only, use fixed sigma for the condition region + log.debug( + f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" + ) + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. + # This is achieved by setting all region as `generation`, i.e. 
value=0 + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition.condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + + # Inference, use fixed seed + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed_inference, + ) + else: + raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") + + # Now apply the augment_sigma to the gt_latent + + augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) + _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) + + if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input + if condition.condition_video_indicator.sum() > 0: # has condition frames + condition.condition_video_augment_sigma = c_noise_augment + else: # no condition frames + condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) + + # Multiply the whole latent with c_in_augment + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + + # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def drop_out_condition_region( + self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig + ) -> Tensor: + """Use for CFG on input frames, we drop out the conditional region + There are two option: + 1. when we dropout, we set the region to be zero + 2. when we dropout, we set the region to be noise_x + """ + # Unconditional case, use for cfg + if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": + # Set the condition location input to be zero + augment_latent_drop = torch.zeros_like(augment_latent) + elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": + # Set the condition location input to be noise_x, i.e., same as base model training + augment_latent_drop = noise_x + else: + raise NotImplementedError( + f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" + ) + return augment_latent_drop + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed_inference: int = 1, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor. 
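# A minimal numerical sketch of the pre-scaling trick used above, assuming the
# EDM-style preconditioning c_in(sigma) = 1 / sqrt(sigma_data**2 + sigma**2)
# (the same factor this file applies to net_in in denoise). The condition
# latent is noised at a small augment_sigma, scaled by c_in(augment_sigma), and
# divided by c_in(sigma) so that the later multiplication by c_in(sigma)
# cancels out. All tensor shapes and sigma values below are toy values.
import torch

sigma_data, sigma, augment_sigma = 0.5, 10.0, 0.001
gt_latent = torch.randn(1, 4, 2, 8, 8)
noise = torch.randn_like(gt_latent)

def c_in(s):
    return 1.0 / (sigma_data**2 + s**2) ** 0.5

augment_latent = gt_latent + noise * augment_sigma
augment_latent_cin = augment_latent * c_in(augment_sigma) / c_in(sigma)

# After the network input is scaled by c_in(sigma), the condition region is
# effectively augment_latent * c_in(augment_sigma):
print(torch.allclose(augment_latent_cin * c_in(sigma),
                     augment_latent * c_in(augment_sigma)))   # True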
+ """ + if condition.data_type == DataType.IMAGE: + pred = super().denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, + cfg_video_cond_bool, + condition_latent, + condition_video_augment_sigma_in_inference, + sigma, + seed_inference=seed_inference, + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + denoise_pred = super().denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. 
+ If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, # Use for noise of augment sigma + ) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. 
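# A small sketch of the masking used in denoise above: a binary per-frame
# indicator of shape (B, 1, T, 1, 1) selects the condition region from the
# (augmented) ground-truth latent and the generation region from the noisy
# input, and the same indicator later swaps the ground truth back into the x0
# prediction. The shapes below are arbitrary toy values.
import torch

B, C, T, H, W = 1, 4, 6, 8, 8
num_condition_t = 2
indicator = torch.zeros(B, 1, T, 1, 1)
indicator[:, :, :num_condition_t] = 1.0            # "first_n" conditioning

augment_latent = torch.randn(B, C, T, H, W)         # condition frames (noised)
noise_x = torch.randn(B, C, T, H, W)                # generation frames
net_input = indicator * augment_latent + (1 - indicator) * noise_x
print(torch.equal(net_input[:, :, :num_condition_t],
                  augment_latent[:, :, :num_condition_t]))    # True
print(torch.equal(net_input[:, :, num_condition_t:],
                  noise_x[:, :, num_condition_t:]))           # True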
+ + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. 
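# A tiny sketch of the guidance blend returned by x0_fn above: the conditional
# and unconditional x0 predictions are combined as
# cond_x0 + guidance * (cond_x0 - uncond_x0), so guidance=0 reduces to the
# conditional prediction. The tensors here are random stand-ins for model
# outputs, with an arbitrary latent shape.
import torch

guidance = 1.5
cond_x0 = torch.randn(1, 4, 8, 32, 32)
uncond_x0 = torch.randn(1, 4, 8, 32, 32)
guided_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0)
print(guided_x0.shape)                     # same shape as the per-branch predictions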
+ Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should 
be turned off." + + return condition + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: + """Sample the PPP... from the IPPP... sequence, only for video sequence + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + Returns: + torch.Tensor: sampled PPP tensor in shape B,C,T,H,W + """ + B, C, T, H, W = latent_state.shape + latent_dtype = latent_state.dtype + T_target = self.state_shape[1] + latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) + t_start = torch.randint(0, T - T_target + 1, (1,)) + # broadcast to other device + latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() + if parallel_state.is_initialized(): + latent_state_sample = _broadcast(latent_state_sample, to_tp=True, to_cp=True) + + return latent_state_sample + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(ExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/extend_model_multiview.py b/cosmos_predict1/diffusion/training/models/extend_model_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..3c66f32f04015229b2723b06030171b8f5ead3a3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/extend_model_multiview.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
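# Illustrative sketch (assumptions: n_views and all tensor sizes below are arbitrary,
# not taken from a real config): MultiviewExtendDiffusionModel below repeatedly folds
# the camera/view axis V into the batch axis so that per-view operations (VAE
# encode/decode, context-parallel splits) see ordinary B,C,T,H,W tensors, then unfolds
# back afterwards. A minimal, self-contained round trip of that pattern:
import torch
from einops import rearrange

n_views = 3
B, C, T, H, W = 2, 16, 8, 32, 32  # illustrative latent-space sizes

latent = torch.randn(B, C, n_views * T, H, W)  # views stacked along the time axis

# Fold views into the batch axis: each view becomes an independent B,C,T,H,W sample.
per_view = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=n_views)
assert per_view.shape == (B * n_views, C, T, H, W)

# ... per-view work (e.g. encoding or context-parallel splitting) would happen here ...

# Unfold back so downstream code sees the original B,C,(V T),H,W layout.
restored = rearrange(per_view, "(B V) C T H W -> B C (V T) H W", V=n_views)
assert torch.equal(restored, latent)  # the round trip is lossless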
+ +from typing import Callable, Dict, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.extend_model import ( + ExtendDiffusionModel, + VideoDenoisePrediction, + normalize_condition_latent, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log + + +class MultiviewExtendDiffusionModel(ExtendDiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. 
+ condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor. + """ + if condition.data_type == DataType.IMAGE: + pred = super(DiffusionModel, self).denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + + denoise_pred = super(DiffusionModel, self).denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def add_condition_video_indicator_and_video_input_mask( + self, 
latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = 
broadcast_condition(condition, to_tp=True, to_cp=to_cp)
+        else:
+            assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."
+
+        return condition
+
+    def get_x0_fn_from_batch_with_condition_latent(
+        self,
+        data_batch: Dict,
+        guidance: float = 1.5,
+        is_negative_prompt: bool = False,
+        condition_latent: torch.Tensor = None,
+        num_condition_t: Union[int, None] = None,
+        condition_video_augment_sigma_in_inference: float = None,
+        add_input_frames_guidance: bool = False,
+        guidance_other: Union[float, None] = None,
+    ) -> Callable:
+        """
+        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.
+        Different from the base model, this function supports a condition latent as input and adds the conditioning information to both the condition and uncondition objects.
+
+        This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states.
+
+        Args:
+        - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of `self.conditioner`.
+        - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5.
+        - is_negative_prompt (bool): use the negative-prompt T5 embedding for the uncondition if True
+        - condition_latent (torch.Tensor): latent tensor of shape B,C,T,H,W used as the condition for video generation.
+        - num_condition_t (int): number of condition latent frames T, used in inference to decide the condition region when config.conditioner.video_cond_bool.condition_location == "first_n"
+        - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
+        - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames
+        - guidance_other (float, optional): if provided, adds a second guidance term computed from a copy of the uncondition that carries the trajectory condition; not supported when parallel_state is initialized.
+
+        Returns:
+        - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and returns the x0 prediction
+
+        The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence.
+ """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + if guidance_other is not None: # and guidance_other != guidance: + import copy + + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + cond_other_x0 = self.denoise( + noise_x, + sigma, + condition_other, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. 
+ If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + guidance_other=guidance_other, + ) + + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + x_sigma_max = ( + torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return samples + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(MultiviewExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/interpolator.py b/cosmos_predict1/diffusion/training/models/interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1f883d69dc5b5976549fff2f5410647fa25ae5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/interpolator.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as BaseModel +from cosmos_predict1.diffusion.training.models.model import broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log + + +class InterpolatorDiffusionModel(ExtendDiffusionModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.pixel_chunk_duration = config.vae.video_vae.pixel_chunk_duration + self.input_image_key = getattr(self.config, "input_image_key", None) + self.input_data_key = self.config.input_data_key + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None + ) -> Tuple[Tensor, VideoExtendCondition]: + raw_state, latent_state, condition = BaseModel.get_data_and_condition(self, data_batch) + num_valid_frames = raw_state.shape[2] - self.pixel_chunk_duration + 1 + raw_state, latent_state = ( + raw_state[:, :, :num_valid_frames, ...], + latent_state[:, :, : self.num_valid_latents, ...], + ) # [B, C, T, H, W] + raw_state, latent_state = raw_state.contiguous(), latent_state.contiguous() + if condition.data_type == DataType.VIDEO: + if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + latent_state = self.sample_tokens_start_from_p_or_i(latent_state) + condition = self.add_condition_video_indicator_and_video_input_mask( + latent_state, condition, num_condition_t=1 + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + log.debug(f"condition.data_type {condition.data_type}") + return raw_state, latent_state, condition + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. 
+ Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + # Should be used for both training and inference. The first and last frame will be condition frames. + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator[:, :, -num_condition_t:] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is 
conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + +@diffusion_fsdp_class_decorator +class FSDPInterpolatorDiffusionModel(InterpolatorDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model.py b/cosmos_predict1/diffusion/training/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e1f7892451d410e5496540454fe8e08c37ff49 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model.py @@ -0,0 +1,662 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + +import amp_C +import torch +from apex.multi_tensor_apply import multi_tensor_applier +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import broadcast_object_list, get_process_group_ranks +from torch.distributed.utils import _verify_param_shape_across_processes + +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_predict1.diffusion.training.conditioner import BaseVideoCondition, DataType +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition +from cosmos_predict1.diffusion.training.models.model_image import DiffusionModel as ImageModel +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import distributed, log, misc + +l2_norm_impl = amp_C.multi_tensor_l2norm +multi_tensor_scale_impl = amp_C.multi_tensor_scale + +# key to check if the video data is normalized or image data is converted to video data +# to avoid apply normalization or augment image dimension multiple times +# It is due to we do not have normalization and augment image dimension in the dataloader and move it to the model +IS_PREPROCESSED_KEY = "is_preprocessed" + + +def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: + """ + Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. 
+ + Args: + tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). + src (int): The source rank for the broadcast. Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor on all ranks. + """ + # First, broadcast the shape of the tensor + if distributed.get_rank() == src: + shape = torch.tensor(tensor.shape).cuda() + else: + shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() + if is_check_shape: + _verify_param_shape_across_processes(pg, [shape]) + torch.distributed.broadcast(shape, src, group=pg) + + # Resize the tensor on non-src ranks if necessary + if distributed.get_rank() != src: + tensor = tensor.new_empty(shape.tolist()).type_as(tensor) + + # Now broadcast the tensor data + torch.distributed.broadcast(tensor, src, group=pg) + + return tensor + + +def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: + """ + Broadcast the item from the minimum rank in the specified group(s). + Since global rank = tp_rank + cp_rank * tp_size + ... + First broadcast in the tp_group and then in the cp_group will + ensure that the item is broadcasted across ranks in cp_group and tp_group. + + Parameters: + - item: The item to broadcast (can be a torch.Tensor, str, or None). + - to_tp: Whether to broadcast to the tensor model parallel group. + - to_cp: Whether to broadcast to the context parallel group. + """ + if not parallel_state.is_initialized(): + return item + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 + to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 + + if to_tp: + min_tp_rank = min(get_process_group_ranks(tp_group)) + + if to_cp: + min_cp_rank = min(get_process_group_ranks(cp_group)) + + if isinstance(item, torch.Tensor): # assume the device is cuda + # log.info(f"{item.shape}", rank0_only=False) + if to_tp: + # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) + item = robust_broadcast(item, min_tp_rank, tp_group) + if to_cp: + # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) + item = robust_broadcast(item, min_cp_rank, cp_group) + elif item is not None: + broadcastable_list = [item] + if to_tp: + # log.info(f"{broadcastable_list}", rank0_only=False) + broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) + if to_cp: + broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) + + item = broadcastable_list[0] + return item + + +def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. 
the current impl does not support it" + condition_kwargs[k] = _broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition + + +class DiffusionModel(ImageModel): + def __init__(self, config): + super().__init__(config) + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.trained_data_record = { + "image": 0, + "video": 0, + "iteration": 0, + } + if parallel_state.is_initialized(): + self.data_parallel_size = parallel_state.get_data_parallel_world_size() + else: + self.data_parallel_size = 1 + + if self.config.adjust_video_noise: + self.video_noise_multiplier = math.sqrt(self.state_shape[1]) + else: + self.video_noise_multiplier = 1.0 + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + self.input_image_key = self.config.input_image_key + + def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: + """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. + Another comes from a dataloader which we by default assumes as video_data for video model training. + """ + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert ( + is_image != is_video + ), "Only one of the input_image_key or input_data_key should be present in the data_batch." + return is_image + + def draw_training_sigma_and_epsilon(self, size: int, condition: BaseVideoCondition) -> Tensor: + sigma_B, epsilon = super().draw_training_sigma_and_epsilon(size, condition) + is_video_batch = condition.data_type == DataType.VIDEO + multiplier = self.video_noise_multiplier if is_video_batch else 1 + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + save generated videos + """ + raw_data, x0, condition = self.get_data_and_condition(data) + guidance = data["guidance"] + data = misc.to(data, **self.tensor_kwargs) + sample = self.generate_samples_from_batch( + data, + guidance=guidance, + # make sure no mismatch and also works for cp + state_shape=x0.shape[1:], + n_sample=x0.shape[0], + ) + sample = self.decode(sample) + gt = raw_data + caption = data["ai_caption"] + return {"gt": gt, "result": sample, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) + + def training_step(self, data_batch: Dict[str, Tensor], iteration: int) -> Tuple[Dict[str, Tensor] | Tensor]: + input_key = self.input_data_key # by default it is video key + if self.is_image_batch(data_batch): + input_key = self.input_image_key + batch_size = data_batch[input_key].shape[0] + self.trained_data_record["image" if self.is_image_batch(data_batch) else "video"] += ( + batch_size * self.data_parallel_size + ) + self.trained_data_record["iteration"] += 1 + return super().training_step(data_batch, iteration) + + def state_dict(self) -> Dict[str, Any]: + state_dict = super().state_dict() + state_dict["trained_data_record"] = self.trained_data_record + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + if "trained_data_record" in state_dict and hasattr(self, "trained_data_record"): + trained_data_record = 
state_dict.pop("trained_data_record") + if trained_data_record: + assert set(trained_data_record.keys()) == set(self.trained_data_record.keys()) + for k, v in trained_data_record.items(): + self.trained_data_record[k] = v + else: + log.warning("trained_data_record not found in the state_dict.") + return super().load_state_dict(state_dict, strict, assign) + + def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """ + Normalizes video data in-place on a CUDA device to reduce data loading overhead. + + This function modifies the video data tensor within the provided data_batch dictionary + in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. + + Warning: + A warning is issued if the data has not been previously normalized. + + Args: + data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. + This tensor is expected to be on a CUDA device and have dtype of torch.uint8. + + Side Effects: + Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. + + Note: + This operation is performed directly on the CUDA device to avoid the overhead associated + with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device + and has the correct dtype (torch.uint8) to avoid unexpected behaviors. + """ + input_key = self.input_data_key if input_key is None else input_key + # only handle video batch + if input_key in data_batch: + # Check if the data has already been normalized and avoid re-normalizing + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." + assert torch.all( + (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) + ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" + else: + assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." + data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + input_key = self.input_image_key if input_key is None else input_key + if input_key in data_batch: + # Check if the data has already been augmented and avoid re-augmenting + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert ( + data_batch[input_key].shape[2] == 1 + ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" + return + else: + data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, BaseVideoCondition]: + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + input_key = self.input_data_key # by default it is video key + is_image_batch = self.is_image_batch(data_batch) + is_video_batch = not is_image_batch + + # Broadcast data and condition across TP and CP groups. + # sort keys to make sure the order is same, IMPORTANT! otherwise, nccl will hang! 
+ local_keys = sorted(list(data_batch.keys())) + # log.critical(f"all keys {local_keys}", rank0_only=False) + for key in local_keys: + data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) + + if is_image_batch: + input_key = self.input_image_key + + # Latent state + raw_state = data_batch[input_key] + latent_state = self.encode(raw_state).contiguous() + + # Condition + condition = self.conditioner(data_batch) + if is_image_batch: + condition.data_type = DataType.IMAGE + else: + condition.data_type = DataType.VIDEO + + # VAE has randomness. CP/TP group should have the same encoded output. + + latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) + condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch) + + return raw_state, latent_state, condition + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + super().on_train_start(memory_format) + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + if sequence_parallel: + self.net.enable_sequence_parallel() + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + output_batch, kendall_loss, pred_mse, edm_loss = super().compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. 
+ - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + def get_x_from_clean( + self, + in_clean_img: torch.Tensor, + sigma_max: float | None, + seed: int = 1, + ) -> Tensor: + """ + in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising + sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video + """ + if in_clean_img is None: + return None + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) + if sigma_max is None: + sigma_max = self.sde.sigma_max + x_sigma_max = in_clean_img + noise * sigma_max + return x_sigma_max + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+ guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def on_after_backward(self, iteration: int = 0): + finalize_model_grads([self]) + + def get_grad_norm( + self, + norm_type: Union[int, float] = 2, + filter_fn: Callable[[str, torch.nn.Parameter], bool] | None = None, + ) -> float: + """Calculate the norm of gradients, handling model parallel parameters. + + This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ + with added functionality to handle model parallel parameters. + + Args: + norm_type (float or int): Type of norm to use. Can be 2 for L2 norm. + 'inf' for infinity norm is not supported. + filter_fn (callable, optional): Function to filter parameters for norm calculation. + Takes parameter name and parameter as input, returns True if this parameter is sharded else False. + + Returns: + float: Total norm of the parameters (viewed as a single vector). + + Note: + - Uses NVIDIA's multi-tensor applier for efficient norm calculation. + - Handles both model parallel and non-model parallel parameters separately. + - Currently only supports L2 norm (norm_type = 2). 
+ """ + # Get model parallel group if parallel state is initialized + if parallel_state.is_initialized(): + model_parallel_group = parallel_state.get_model_parallel_group() + else: + model_parallel_group = None + + # Default filter function to identify tensor parallel parameters + if filter_fn is None: + + def is_tp(name, param): + return ( + any(key in name for key in ["to_q.0", "to_k.0", "to_v.0", "to_out.0", "layer1", "layer2"]) + and "_extra_state" not in name + ) + + filter_fn = is_tp + + # Separate gradients into model parallel and non-model parallel + without_mp_grads_for_norm = [] + with_mp_grads_for_norm = [] + for name, param in self.named_parameters(): + if param.grad is not None: + if filter_fn(name, param): + with_mp_grads_for_norm.append(param.grad.detach()) + else: + without_mp_grads_for_norm.append(param.grad.detach()) + + # Only L2 norm is currently supported + if norm_type != 2.0: + raise NotImplementedError(f"Norm type {norm_type} is not supported. Only L2 norm (2.0) is implemented.") + + # Calculate L2 norm using NVIDIA's multi-tensor applier + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + + # Calculate norm for non-model parallel gradients + without_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if without_mp_grads_for_norm: + without_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [without_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Calculate norm for model parallel gradients + with_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if with_mp_grads_for_norm: + with_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [with_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Square the norms as we'll be summing across model parallel GPUs + total_without_mp_norm = without_mp_grad_norm**2 + total_with_mp_norm = with_mp_grad_norm**2 + + # Sum across all model-parallel GPUs + torch.distributed.all_reduce(total_with_mp_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group) + + # Combine norms from model parallel and non-model parallel gradients + total_norm = (total_with_mp_norm.item() + total_without_mp_norm.item()) ** 0.5 + + return total_norm + + def clip_grad_norm_(self, max_norm: float): + """ + This function performs gradient clipping to prevent exploding gradients. + It calculates the total norm of the gradients, and if it exceeds the + specified max_norm, scales the gradients down proportionally. + + Args: + max_norm (float): The maximum allowed norm for the gradients. + + Returns: + torch.Tensor: The total norm of the gradients before clipping. + + Note: + This implementation uses NVIDIA's multi-tensor applier for efficiency. 
+ """ + # Collect gradients from all parameters that require gradients + grads = [] + for param in self.parameters(): + if param.grad is not None: + grads.append(param.grad.detach()) + + # Calculate the total norm of the gradients + total_norm = self.get_grad_norm() + + # Compute the clipping coefficient + clip_coeff = max_norm / (total_norm + 1.0e-6) + + # Apply gradient clipping if the total norm exceeds max_norm + if clip_coeff < 1.0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + # Apply the scaling to the gradients using multi_tensor_applier for efficiency + multi_tensor_applier(multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff) + + return torch.tensor([total_norm]) + + +def _allreduce_layernorm_grads(model: List[torch.nn.Module]): + """ + All-reduce the following layernorm grads: + - When tensor parallel is enabled, all-reduce grads of QK-layernorm + - When sequence parallel, all-reduce grads of AdaLN, t_embedder, additional_timestamp_embedder, + and affline_norm. + """ + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + if parallel_state.get_tensor_model_parallel_world_size() > 1: + grads = [] + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + + if "to_q.1" in name or "to_k.1" in name: # TP # Q-layernorm # K-layernorm + # grad = param.main_grad + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if sequence_parallel: # TP + SP + if ( + "t_embedder" in name + or "adaLN_modulation" in name + or "additional_timestamp_embedder" in name + or "affline_norm" in name + or "input_hint_block" in name + or "zero_blocks" in name + ): + # grad = param.main_grad + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def finalize_model_grads(model: List[torch.nn.Module]): + """ + All-reduce layernorm grads for tensor/sequence parallelism. + Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py#L99 + """ + + _allreduce_layernorm_grads(model) + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_image.py b/cosmos_predict1/diffusion/training/models/model_image.py new file mode 100644 index 0000000000000000000000000000000000000000..11ff1503c086f828cdc4afdd26567fad96f57fe4 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_image.py @@ -0,0 +1,933 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.module.blocks import FourierFeatures +from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from cosmos_predict1.diffusion.training.functional.loss import create_per_sample_loss_mask +from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh +from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_scheduler +from cosmos_predict1.diffusion.types import DenoisePrediction +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.ema import FastEmaModelUpdater +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate +from cosmos_predict1.utils.model import Model + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +class DiffusionModel(Model): + def __init__(self, config): + super().__init__() + + self.config = config + + # how many sample have been processed + self.sample_counter = 0 + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.warning(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.sde = lazy_instantiate(config.sde) + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + + # 3. vae + with misc.timer("DiffusionModel: set_up_vae"): + self.vae: BaseVAE = lazy_instantiate(config.vae) + assert ( + self.vae.latent_ch == self.state_shape[0] + ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}" + + # 4. Set up loss options, including loss masking, loss reduce and loss scaling + self.loss_masking: Optional[Dict] = config.loss_masking + self.loss_reduce = getattr(config, "loss_reduce", "mean") + assert self.loss_reduce in ["mean", "sum"] + self.loss_scale = getattr(config, "loss_scale", 1.0) + log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}") + log.critical(f"Enable loss masking: {config.loss_mask_enabled}") + + # 5. 
diffusion neural networks part + self.set_up_model() + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key + + def build_model(self) -> torch.nn.ModuleDict: + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @misc.timer("DiffusionModel: set_up_model") + def set_up_model(self): + config = self.config + self.model = self.build_model() + if config.ema.enabled: + with misc.timer("DiffusionModel: instantiate ema"): + config.ema.model = self.model + self.model_ema = lazy_instantiate(config.ema) + config.ema.model = None + else: + self.model_ema = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """ + update the model_ema + """ + if self.config.ema.enabled: + self.model_ema.update_average(self.model, iteration) + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + if self.config.ema.enabled: + self.model_ema.to(dtype=torch.float32) + if hasattr(self.vae, "reset_dtype"): + self.vae.reset_dtype() + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config + if torch.__version__ < "2.3": + log.warning( + "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" + "It's very likely there will be no significant speedup from torch.compile.\n" + "Please use at least 24.04 Pytorch container." + ) + # Increasing cache size. It's required because of the model size and dynamic input shapes resulting in + # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for + # up to 9 different input shapes, as 28*9 < 256. If you have more Blocks or input shapes, and you observe + # graph breaks at each Block (detectable with torch._dynamo.explain) or warnings about + # exceeding cache limit, you may want to increase this size. + # Starting with 24.05 Pytorch container, the default value is 256 anyway. + # You can read more about it in the comments in Pytorch source code under path torch/_dynamo/cache_size.py. + torch._dynamo.config.accumulated_cache_size_limit = 256 + # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs + # at initial iterations, but can result in more specialized and efficient kernels. + # dynamic=True currently throws errors in pytorch 2.3. + self.model.net = torch.compile(self.model.net, dynamic=False, disable=not self.config.use_torch_compile) + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + """ + Compute loss givee epsilon and sigma + + This method is responsible for computing loss give epsilon and sigma. It involves: + 1. Adding noise to the input data using the SDE process. + 2. 
Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + x0_from_data_batch: raw image/video + x0: image/video latent + condition: text condition + epsilon: noise + sigma: noise level + + Returns: + tuple: A tuple containing four elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor 1: kendall loss, + - Tensor 2: MSE loss, + - Tensor 3: EDM loss + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. + + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the mean and stand deviation of the marginal probability distribution. + mean, std = self.sde.marginal_prob(x0, sigma) + # Generate noisy observations + xt = mean + batch_mul(std, epsilon) # corrupted data + # make prediction + model_pred = self.denoise(xt, sigma, condition) + # loss weights for different noise levels + weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma) + # extra weight for each sample, for example, aesthetic weight, camera weight + weights_per_sample = self.get_per_sample_weight(data_batch, x0_from_data_batch.shape[0]) + # extra loss mask for each sample, for example, human faces, hands + loss_mask_per_sample = self.get_per_sample_loss_mask(data_batch, x0_from_data_batch.shape, x0.shape) + pred_mse = (x0 - model_pred.x0) ** 2 * loss_mask_per_sample + edm_loss = batch_mul(pred_mse, weights_per_sigma * weights_per_sample) + if self.config.loss_add_logvar: + kendall_loss = batch_mul(edm_loss, torch.exp(-model_pred.logvar).view(-1)).flatten( + start_dim=1 + ) + model_pred.logvar.view(-1, 1) + else: + kendall_loss = edm_loss.flatten(start_dim=1) + output_batch = { + "x0": x0, + "xt": xt, + "sigma": sigma, + "weights_per_sigma": weights_per_sigma, + "weights_per_sample": weights_per_sample, + "loss_mask_per_sample": loss_mask_per_sample, + "condition": condition, + "model_pred": model_pred, + "mse_loss": pred_mse.mean(), + "edm_loss": edm_loss.mean(), + } + return output_batch, kendall_loss, pred_mse, edm_loss + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step for the diffusion model. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + + Returns: + tuple: A tuple containing two elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor: The computed loss for the training step as a PyTorch Tensor. + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. 
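+
+        Example:
+            Typical use inside a training loop (illustrative only; in practice
+            the trainer drives backward and optimizer steps):
+
+                output_batch, loss = model.training_step(data_batch, iteration)
+                loss.backward()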
+ + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. + x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch) + + # Sample pertubation noise levels and N(0, 1) noises + sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition) + + output_batch, kendall_loss, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + + if self.loss_reduce == "mean": + kendall_loss = kendall_loss.mean() * self.loss_scale + elif self.loss_reduce == "sum": + kendall_loss = kendall_loss.sum(dim=1).mean() * self.loss_scale + else: + raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}") + + return output_batch, kendall_loss + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + if getattr(self.config, "use_dummy_temporal_dim", False): + # When using video DiT model for image, we need to use a dummy temporal dimension. + xt = xt.unsqueeze(2) + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + if getattr(self.config, "use_dummy_temporal_dim", False): + x0_pred = x0_pred.squeeze(2) + eps_pred = eps_pred.squeeze(2) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + return self.vae.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + return self.vae.decode(latent / self.sigma_data) + + def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor: + del condition + batch_size = x0_size[0] + epsilon = torch.randn(x0_size, **self.tensor_kwargs) + return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon + + def get_data_and_condition(self, data_batch: dict[str, torch.Tensor]) -> Tuple[torch.Tensor, CosmosCondition]: + """ + processing data batch draw from data loader and return data and condition that used for denoising task + + Returns: + raw_state (tensor): the image / video data that feed to vae + latent_state (tensor): nosie-free state, the vae latent state + condition (CosmosCondition): condition information for conditional generation. 
Generated from conditioner + """ + raw_state = data_batch[self.input_data_key] + latent_state = self.encode(raw_state) + condition = self.conditioner(data_batch) + return raw_state, latent_state, condition + + def get_per_sample_weight(self, data_batch: dict[str, torch.Tensor], batch_size: int): + r""" + extra weight for each sample, for example, aesthetic weight + Args: + data_batch: raw data batch draw from the training data loader. + batch_size: int, the batch size of the input data + """ + aesthetic_cfg = getattr(self.config, "aesthetic_finetuning", None) + if (aesthetic_cfg is not None) and getattr(aesthetic_cfg, "enabled", False): + sample_weight = data_batch["aesthetic_weight"] + else: + sample_weight = torch.ones(batch_size, **self.tensor_kwargs) + + camera_cfg = getattr(self.config, "camera_sample_weight", None) + if (camera_cfg is not None) and getattr(camera_cfg, "enabled", False): + sample_weight *= 1 + (data_batch["camera_attributes"][:, 1:].sum(dim=1) != 0) * (camera_cfg.weight - 1) + return sample_weight + + def get_per_sample_loss_mask(self, data_batch, raw_x_shape, latent_x_shape): + """ + extra loss mask for each sample, for example, human faces, hands. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + raw_x_shape (tuple): shape of the input data. We need the raw_x_shape for necessary resize operation. + latent_x_shape (tuple): shape of the latent data + """ + if self.config.loss_mask_enabled: + raw_x_shape = [raw_x_shape[0], 1, *raw_x_shape[2:]] + weights = create_per_sample_loss_mask( + self.loss_masking, data_batch, raw_x_shape, torch.get_default_dtype(), "cuda" + ) + return F.interpolate(weights, size=latent_x_shape[2:], mode="bilinear") + + return 1.0 + + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): + """ + Args: + sigma (tensor): noise level + + Returns: + loss weights per sigma noise level + """ + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + def generate_samples(self, batch_size: int, condition: CosmosCondition) -> torch.Tensor: + """ + Generate samples with given condition. It is WITHOUT classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + return self.denoise(x, t, condition).x0 # ODE function + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def generate_cfg_samples( + self, batch_size: int, condition: CosmosCondition, uncondition: CosmosCondition, guidance=1.5 + ) -> torch.Tensor: + """ + Generate samples with with classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + uncondition (CosmosCondition): uncondition information, possibily generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + cond_x0 = self.denoise(x, t, condition).x0 + uncond_x0 = self.denoise(x, t, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. 
+ + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Optional[Tuple] = None, + n_sample: Optional[int] = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + ) -> torch.Tensor: + """ + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + batch_size = n_sample or data_batch[self.input_data_key].shape[0] + state_shape = state_shape or self.state_shape + x_sigma_max = ( + misc.arch_invariant_rand( + (batch_size,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + return self.sampler( + x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max, num_steps=num_steps, solver_option=solver_option + ) + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Current code does nothing. + """ + return {}, torch.tensor(0).to(**self.tensor_kwargs) + + @torch.no_grad() + def forward(self, xt, t, condition: CosmosCondition): + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. 
+ condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + return self.denoise(xt, t, condition) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer = lazy_instantiate(optimizer_config, model=self.model) + scheduler = get_base_scheduler(optimizer, self, scheduler_config) + return optimizer, scheduler + + def state_dict(self) -> Dict[str, Any]: + """ + Returns the current state of the model as a dictionary. + + Returns: + Dict: The current state of the model as a dictionary. + """ + return { + "model": self.model.state_dict(), + "ema": self.model_ema.state_dict() if self.config.ema.enabled else None, + } + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + """ + Loads a state dictionary into the model and optionally its EMA counterpart. + Different from torch strict=False mode, the method will not raise error for unmatched state shape while raise warning. + + Parameters: + state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and + potentially for an EMA version of the model under the keys 'model' and 'ema', respectively. + strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly + those in the model and EMA model (if applicable). Defaults to True. + assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than + matching keys one-by-one. This is typically used when loading parts of state dicts + or using customized loading procedures. Defaults to False. + """ + if strict: + # the converted tpsp checkpoint has "ema" and it is None + if self.config.ema.enabled and state_dict["ema"] is not None: + ema_results: _IncompatibleKeys = self.model_ema.load_state_dict( + state_dict["ema"], strict=strict, assign=assign + ) + reg_results: _IncompatibleKeys = self.model.load_state_dict( + state_dict["model"], strict=strict, assign=assign + ) + if self.config.ema.enabled and state_dict["ema"] is not None: + return _IncompatibleKeys( + ema_results.missing_keys + reg_results.missing_keys, + ema_results.unexpected_keys + reg_results.unexpected_keys, + ) + return reg_results + else: + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.critical("load model in non-strict mode") + if "model" in state_dict: + log.critical(non_strict_load_model(self.model, state_dict["model"]), rank0_only=False) + else: + log.critical(non_strict_load_model(self.model, state_dict), rank0_only=False) + if self.config.ema.enabled and "ema" in state_dict and state_dict["ema"] is not None: + log.critical("load ema model in non-strict mode") + log.critical(non_strict_load_model(self.model_ema, state_dict["ema"]), rank0_only=False) + + def get_ckpt_postfix(self) -> Tuple[str, int, int]: + """Get the checkpoint file postfix. + + Args: + iteration (int): The current iteration number. 
+ + Returns: + postfix (str): The postfix of the checkpoint file. + rank_to_save ema (int), we will not save each ema model in each rank, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + total_ema_num = min(self.config.ema.num, distributed.get_world_size()) + rank = distributed.get_rank() + if rank == 0: + return "", 0, total_ema_num + if self.config.ema.enabled: + if rank < self.config.ema.num: + return f"_RANK{rank}", rank, total_ema_num + return "", 0, total_ema_num # use rank 0 to save the checkpoint + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema.copy_to(self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + +T = TypeVar("T", bound=DiffusionModel) + + +def diffusion_fsdp_class_decorator(base_class: Type[T]) -> Type[T]: + """ + Decorator for the FSDP class for the diffusion model, which handles the FSDP specific logic for the diffusion model. + """ + + class FSDPClass(base_class): + """ + Handle FSDP specific logic for the diffusion model. Including: + - FSDP model initialization + - FSDP model / optimizer save and loading + - Different from the original DiffusionModel, the impl of multi-rank EMA is a bit hacky. \ + We need to make sure sharded model weights for EMA and regular model are the same. + """ + + def __init__(self, config, fsdp_checkpointer: Any): + self.fsdp_checkpointer = fsdp_checkpointer + super().__init__(config) + + def set_up_model(self): + config = self.config + + # 1. build FSDP sharding strategy and device_mesh + strategy = { + "full": ShardingStrategy.FULL_SHARD, + "hybrid": ShardingStrategy.HYBRID_SHARD, + }[config.fsdp.sharding_strategy] + log.critical(f"Using {strategy} sharding strategy for FSDP") + + if config.fsdp.sharding_strategy == "hybrid": + sharding_group_size = getattr(config.fsdp, "sharding_group_size", 8) + device_mesh = hsdp_device_mesh( + sharding_group_size=sharding_group_size, + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + replicate_group = device_mesh.get_group(mesh_dim="replicate") + fsdp_process_group = (shard_group, replicate_group) + else: + device_mesh = hsdp_device_mesh( + sharding_group_size=distributed.get_world_size(), + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + fsdp_process_group = shard_group + + # We piggyback the `device_mesh` to megatron-core's `parallel_state` for global access. + # This is not megatron-core's original API. 
+ parallel_state.fsdp_device_mesh = device_mesh + + def get_wrap_policy(_model): + if not hasattr(_model.net, "fsdp_wrap_block_cls"): + raise ValueError( + "Networks does not have fsdp_wrap_block_cls attribute, please check the net definition" + ) + fsdp_blocks_cls = _model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] + ) + log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") + + log.critical(f"Using wrap policy {config.fsdp.policy}") + if config.fsdp.policy == "size": + min_num_params = getattr(config.fsdp, "min_num_params", 100) + log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") + wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + else: + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + + wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=set(fsdp_blocks_cls), + ) + return wrap_policy + + # 2. build naive pytorch model and load weights if exists + replica_idx, shard_idx = device_mesh.get_coordinate() + # 2.1 handle ema case first, since float32 is more expensive + if config.ema.enabled: + with misc.timer("Creating PyTorch model and loading weights for ema"): + model_ema = self.build_model().float() + model_ema.cuda().eval().requires_grad_(False) + if distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic + self.fsdp_checkpointer.load_model_during_init(model_ema, is_ema=True) + # sync ema model weights from rank0 + with misc.timer("Sync model states for EMA model"): + #! this is IMPORTANT, see the following comment about regular model for details + #! we broadcast the ema model first, since it is fp32 and costs more memory + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + # for ema model with dfiferent rate, we download the model when necessary + if shard_idx == 0 and replica_idx > 0 and replica_idx < config.ema.num: + print("loading ema model in rank", replica_idx) + self.fsdp_checkpointer.load_model_during_init( + model_ema, + is_ema=True, + ema_id=replica_idx, + ) + print("finish loading ema model in rank", replica_idx) + # 2.1.2 create FSDP model for ema model + with misc.timer("Creating FSDP model for EMA model"): + self.model_ema = FSDP( + model_ema, + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + process_group=device_mesh.get_group(mesh_dim=1), + sharding_strategy=ShardingStrategy.FULL_SHARD, + auto_wrap_policy=get_wrap_policy(model_ema), + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + ) + + # extra ema model upate logic to the model + self.model_ema_worker = FastEmaModelUpdater() + s = 0.1 + replica_idx, shard_idx = device_mesh.get_coordinate() + divider = 2**replica_idx if replica_idx < config.ema.num else 1 + if replica_idx < config.ema.num: + if shard_idx == 0: + print(f"EMA: rank {replica_idx}, rate {config.ema.rate / divider}") + s = config.ema.rate / divider + self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + torch.cuda.empty_cache() + + # 2.2 handle regular model + with misc.timer("Creating PyTorch model and loading weights for regular model"): + model = self.build_model().cuda().to(**self.tensor_kwargs) + + if 
distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic and sync later + self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) + + #! overwrite the forward method so that it will invoke the FSDP-specific pre- and post-forward sharding logic + model.forward = super().training_step + #! this is IMPORTANT, though following two lines are identical to sync_module_states=True in FSDP + #! we do it twice so that following line can warm up and avoid OOM in aws 128+ nodes settings + #! qsh hypothesize that it is due to overhead of initialization of nccl network communication; + #! without it, peak mem : reg_model + ema_model + FSDP overhead + nccl communication initialization overhead + #! with it, peak men: reg_model + ema_model + FSDP overhead + #! it is tricky, but it works! + with misc.timer("Sync model states for regular model"): + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + + with misc.timer("Creating FSDP model"): + self.model = FSDP( + model.to(**self.tensor_kwargs), + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + sharding_strategy=strategy, + auto_wrap_policy=get_wrap_policy(model), + process_group=fsdp_process_group, + limit_all_gathers=True, + ) + + if self.config.fsdp.checkpoint: + fsdp_blocks_cls = model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) + if isinstance(fsdp_blocks_cls, (list, tuple, set)) + else [fsdp_blocks_cls] + ) + log.critical(f"Applying FSDP checkpointing with FSDP blocks: {fsdp_blocks_cls}") + apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) + + torch.cuda.empty_cache() + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + del scheduler, optimizer + + if self.config.ema.enabled: + # calculate beta for EMA update + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.ema_exp_coefficient + 1) + self.model_ema_worker.update_average(self.model, self.model_ema, beta=beta) + + def training_step( + self, data_batch: Dict[str, torch.Tensor], iteration: int + ) -> Tuple[Dict[str, torch.Tensor] | torch.Tensor]: + # ! Important!!! + # ! make sure the training step is the same as the forward method~(training_step in the super class) + # ! 
this is necessary to trigger the FSDP-specific pre- and post-forward sharding logic + return self.model(data_batch, iteration) + + def state_dict(self) -> Dict: + raise NotImplementedError( + "FSDPDiffModle does not support state_dict, use state_dict_model and FSDPCheckpointer" + ) + + @misc.timer("FSDP state_dict_model") + def state_dict_model(self) -> Dict: + with FSDP.summon_full_params(self.model): + pass + with FSDP.state_dict_type( + self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + model_state = self.model.state_dict() + if self.config.ema.enabled: + with FSDP.summon_full_params(self.model_ema): + pass + with FSDP.state_dict_type( + self.model_ema, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + ema_model_state = self.model_ema.state_dict() + else: + ema_model_state = None + return { + "model": model_state, + "ema": ema_model_state, + } + + def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: + raise NotImplementedError("FSDPDiffModle does not support load_state_dict, using FSDPCheckpointer") + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) + self.fsdp_checkpointer.load_optim_scheduler_during_init( + self.model, + optimizer, + scheduler, + ) + return optimizer, scheduler + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema_worker.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema_worker.copy_to(src_model=self.model_ema, tgt_model=self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema_worker.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + def get_ckpt_postfix(self) -> Tuple[str, int]: + """Get the checkpoint file postfix. check FSDPCheckpointer for more details + + Args: + iteration (int): The current iteration number. + + Returns: + postfix (str): The postfix of the checkpoint file. + replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ + we will not save each ema model in each GPU, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + mesh_shape = parallel_state.fsdp_device_mesh.shape + total_ema_num = min(self.config.ema.num, mesh_shape[0]) + replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() + if replicate_idx == 0: + return "", 0, shard_idx, total_ema_num + if self.config.ema.enabled: + if replicate_idx < self.config.ema.num: + return f"_RANK{replicate_idx}", replicate_idx, shard_idx, total_ema_num + return "", replicate_idx, shard_idx, total_ema_num + + return FSDPClass + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_multiview.py b/cosmos_predict1/diffusion/training/models/model_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..8cae2d6b0561f69fdc706d579a357b6940674015 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_multiview.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log, misc + + +class MultiviewDiffusionModel(DiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return 
output_batch, kendall_loss, pred_mse, edm_loss + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + x0_fn = self.get_x0_fn_from_batch( + data_batch, guidance, is_negative_prompt=is_negative_prompt, guidance_other=guidance_other + ) + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + return samples + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + guidance_other: Union[float, None] = None, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. 
The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + if guidance_other is not None: + # assume this is for inference time trajectory guidance for now + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + cond_other_x0 = self.denoise(noise_x, sigma, condition_other).x0 + + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + if "guided_image" in data_batch: + assert False, "not supported" + return raw_x0 + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(MultiviewDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_peft.py b/cosmos_predict1/diffusion/training/models/model_peft.py new file mode 100644 index 0000000000000000000000000000000000000000..a25645e1d8f348e357d46539af0f248cacb30db7 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_peft.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Type, TypeVar + +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as VideoDiffusionModel +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils import misc +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + +T = TypeVar("T") + + +def video_peft_decorator(base_class: Type[T]) -> Type[T]: + class PEFTVideoDiffusionModel(base_class): + def __init__(self, config: dict, fsdp_checkpointer=None): + super().__init__(config) + + @misc.timer("PEFTVideoDiffusionModel: set_up_model") + def set_up_model(self): + config = self.config + peft_control_config_parser = LayerControlConfigParser(config=config.peft_control) + peft_control_config = peft_control_config_parser.parse() + self.model = self.build_model() + if peft_control_config and peft_control_config["customization_type"] == CustomizationType.LORA: + add_lora_layers(self.model, peft_control_config) + num_lora_params = setup_lora_requires_grad(self.model) + if num_lora_params == 0: + raise ValueError("No LoRA parameters found. Please check the model configuration.") + if config.ema.enabled: + with misc.timer("PEFTDiffusionModel: instantiate ema"): + config.ema.model = self.model + self.model_ema = lazy_instantiate(config.ema) + config.ema.model = None + else: + self.model_ema = None + + def state_dict_model(self) -> Dict: + return { + "model": self.model.state_dict(), + "ema": self.model_ema.state_dict() if self.model_ema else None, + } + + return PEFTVideoDiffusionModel + + +@video_peft_decorator +class PEFTVideoDiffusionModel(VideoDiffusionModel): + pass + + +@video_peft_decorator +class PEFTExtendDiffusionModel(ExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/module/blocks.py b/cosmos_predict1/diffusion/training/module/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..255ab483fd61706c1cda4e8457418ffbff373dbd --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/blocks.py @@ -0,0 +1,1118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from megatron.core import parallel_state +from torch import nn +from transformer_engine.pytorch.attention import apply_rotary_pos_emb + +from cosmos_predict1.diffusion.module.attention import Attention, GPT2FeedForward +from cosmos_predict1.diffusion.training.tensor_parallel import gather_along_first_dim +from cosmos_predict1.utils import log + + +class SDXLTimesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class SDXLTimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.critical( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches and embedding each + patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + - keep_spatio (bool): If True, the spatial dimensions are kept separate in the output tensor, otherwise, they are flattened. Default: False. + - legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! The legacy model is for backward compatibility. Default: True. + The output shape of the module depends on the `keep_spatio` flag. If `keep_spatio`=True, the output retains the spatial dimensions. + Otherwise, the spatial dimensions are flattened into a single dimension. 
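+
+    Example:
+        Illustrative shape check (the sizes below are made up for the sketch):
+
+            patch_embed = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1,
+                                     in_channels=16, out_channels=768, keep_spatio=True)
+            x = torch.randn(2, 16, 8, 32, 32)   # (B, C, T, H, W)
+            y = patch_embed(x)                  # (B, T, H/2, W/2, out_channels)
+            assert y.shape == (2, 8, 16, 16, 768)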
+ """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + keep_spatio=False, + legacy_patch_emb: bool = True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + assert keep_spatio, "Only support keep_spatio=True" + self.keep_spatio = keep_spatio + self.legacy_patch_emb = legacy_patch_emb + + if legacy_patch_emb: + self.proj = nn.Conv3d( + in_channels, + out_channels, + kernel_size=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + stride=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + bias=bias, + ) + self.out = Rearrange("b c t h w -> b t h w c") + else: + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class ExtraTokenPatchEmbed(PatchEmbed): + def __init__(self, *args, out_channels: int = 768, keep_spatio: bool = False, **kwargs): + assert keep_spatio, "ExtraTokenPatchEmbed only supports keep_spatio=True" + super().__init__(*args, out_channels=out_channels, keep_spatio=keep_spatio, **kwargs) + self.temporal_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + self.spatial_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + + def forward(self, x): + x_B_T_H_W_C = super().forward(x) + B, T, H, W, C = x_B_T_H_W_C.shape + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.temporal_token.repeat(B, 1, H, W, 1), + ], + dim=1, + ) + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.spatial_token.repeat(B, T, H, 1, 1), + ], + dim=3, + ) + return x_B_T_H_W_C + + +class ExpertChoiceMoEGate(nn.Module): + """ + ExpertChoiceMoEGate determines which tokens go + to which experts (and how much to weigh each expert). + + Args: + hidden_size (int): Dimensionality of input features. + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + capacity: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.router = nn.Parameter(torch.empty((self.num_experts, self.hidden_size))) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.router) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D) + Returns: + gating (Tensor): Gating weights of shape (B, E, C), + where E = num_experts, C = capacity (top-k). + dispatch (Tensor): Dispatch mask of shape (B, E, C, S). + index (Tensor): Indices of top-k tokens for each expert, + shape (B, E, C). 
+ """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # token-expert affinity scores + logits = torch.einsum("bsd,de->bse", x, self.router) + affinity = torch.nn.functional.softmax(logits, dim=-1) # (B, S, E) + + # gather topk tokens for each expert + affinity_t = affinity.transpose(1, 2) # (B, E, S) + + # select top-k tokens for each expert + gating, index = torch.topk(affinity_t, k=C, dim=-1) # (B, E, C) + + # one-hot dispatch mask + dispatch = torch.nn.functional.one_hot(index, num_classes=S).float() # (B, E, C, S) + + return gating, dispatch, index + + +class ExpertChoiceMoELayer(nn.Module): + """ + ExpertChoiceMoELayer uses the ExpertChoiceMoEGate to route tokens + to experts, process them, and then combine the outputs. + + Args: + gate_hidden_size (int): Dimensionality of input features. + ffn_hidden_size (int): Dimension of hidden layer in each expert feedforward (e.g., GPT2FeedForward). + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + expert_cls (nn.Module): The class to instantiate each expert. Defaults to GPT2FeedForward. + expert_kwargs (dict): Extra kwargs to pass to each expert class. + """ + + def __init__( + self, + gate_hidden_size: int, + ffn_hidden_size: int, + num_experts: int, + capacity: int, + expert_class: nn.Module = GPT2FeedForward, + expert_kwargs=None, + ): + super().__init__() + if not expert_kwargs: + expert_kwargs = {} + + self.gate_hidden_size = gate_hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.gate = ExpertChoiceMoEGate(gate_hidden_size, num_experts, capacity) + + self.experts = nn.ModuleList( + [expert_class(gate_hidden_size, ffn_hidden_size, **expert_kwargs) for _ in range(num_experts)] + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D). + + Returns: + x_out (Tensor): Output of shape (B, S, D), after dispatching tokens + to experts and combining their outputs. + """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # gating: (B, E, C) + # dispatch: (B, E, C, S) + gating, dispatch, index = self.gate(x) + + # collect input tokens for each expert + x_in = torch.einsum("becs,bsd->becd", dispatch, x) + + # process through each expert + expert_outputs = [self.experts[e](x_in[:, e]) for e in range(E)] + + x_e = torch.stack(expert_outputs, dim=1) # (B, E, C, D) + + # gating: (B, E, C), dispatch: (B, E, C, S), x_e: (B, E, C, d) + # x_out: (B, S, D) + # each token is placed back to its location with weighting + x_out = torch.einsum("becs,bec,becd->bsd", dispatch, gating, x_e) + + return x_out + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. 
+ """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + if self.sequence_parallel: + x_T_B_HW_D = rearrange(x_BT_HW_D, "(b t) hw d -> t b hw d", b=B, t=T) + x_T_B_HW_D = gather_along_first_dim(x_T_B_HW_D, parallel_state.get_tensor_model_parallel_group()) + x_BT_HW_D = rearrange(x_T_B_HW_D, "t b hw d -> (b t) hw d", b=B) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + def forward_with_memory_save( + self, + x_BT_HW_D_before_gate: torch.Tensor, + x_BT_HW_D_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D_before_gate.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + gate_BT_1_D = repeat(gate_L_B_D, "1 b d -> (b t) 1 d", t=T) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_BT_1_D * _x_before_gate + _x = modulate(self.norm_final(previous_block_out), shift_BT_D, scale_BT_D) + return self.linear(_x) + + return torch.utils.checkpoint.checkpoint(_fn, x_BT_HW_D_before_gate, x_BT_HW_D_skip, use_reentrant=False) + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module supports both self-attention within the video frames and cross-attention + with an external context. It's designed to work with flattened spatial dimensions + to accommodate for video input. + + Attributes: + x_dim (int): Dimensionality of the input feature vectors. + context_dim (Optional[int]): Dimensionality of the external context features. + If None, the attention does not utilize external context. + num_heads (int): Number of attention heads. 
+ bias (bool): If true, bias is added to the query, key, value projections. + x_format (str): The shape format of x tenosor. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + x_format: str = "BTHWD", + n_views: int = 1, + ) -> None: + super().__init__() + self.n_views = n_views + self.x_format = x_format + if self.x_format == "BTHWD": + qkv_format = "bshd" + elif self.x_format == "THWBD": + qkv_format = "sbhd" + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_format=qkv_format, + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. + context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + + if self.x_format == "BTHWD": + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x, "b (v t) h w d -> (v b) t h w d", v=self.n_views) + context_B_M_D = rearrange(context, "b (v m) d -> (v b) m d", v=self.n_views) + else: + x_B_T_H_W_D = x + context_B_M_D = context + B, T, H, W, D = x_B_T_H_W_D.shape + x_B_THW_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") + x_B_THW_D = self.attn(x_B_THW_D, context_B_M_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D) + + # reshape it back to video format + x_B_T_H_W_D = rearrange(x_B_THW_D, "b (t h w) d -> b t h w d", h=H, w=W) + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x_B_T_H_W_D, "(v b) t h w d -> b (v t) h w d", v=self.n_views) + return x_B_T_H_W_D + elif self.x_format == "THWBD": + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) + context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) + else: + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) + return x_T_H_W_B_D + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + +def checkpoint_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + DIT Building Block for constructing various types of attention or MLP blocks dynamically based on a specified block type. 
+ + This class instantiates different types of buildig block / attn and MLP based on config, and applies crossponding forward pass during training. + + Attributes: + block_type (str): Type of block to be used ('spatial_sa', 'temporal_sa', 'cross_attn', 'full_attn', 'mlp'). + x_dim (int): Dimensionality of the input features. + context_dim (Optional[int]): Dimensionality of the external context, required for cross attention blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Multiplier for the dimensionality of the MLP hidden layer compared to input. + spatial_win_size (int): Window size for spatial self-attention. + temporal_win_size (int): Window size for temporal self-attention. + bias (bool): Whether to include bias in attention and MLP computations. + mlp_dropout (float): Dropout rate for MLP blocks. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_win_size: int = 1, + temporal_win_size: int = 1, + bias: bool = False, + mlp_dropout: float = 0.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + x_format=self.x_format, + n_views=n_views, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn(x_dim, None, num_heads, bias=bias, x_format=self.x_format) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward_with_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn), "only support VideoAttn impl" + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = 
self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + if self.block.attn.is_selfattn: + return q, k, v + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + return self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." + + if self.block.attn.is_selfattn: + q, k, v = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." 
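+            # Self-attention path: the Q/K/V projections and RoPE above are recomputed under
+            # activation checkpointing; the fused attention op itself runs outside the checkpoint.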
+ softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + else: + softmax_attn_output = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + attn_out = self.block.attn.to_out(softmax_attn_output) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_x_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def x_attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + return self.block.attn.to_out(softmax_attn_output) + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." 
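+        # Cross-attention path: the entire attention computation in x_attn_fn, including the
+        # output projection, is recomputed under activation checkpointing.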
+ + attn_out = torch.utils.checkpoint.checkpoint(x_attn_fn, q, k, v, use_reentrant=False) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_ffn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return self.block.layer1(normalized_x), previous_block_out + + intermediate_output, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, use_reentrant=False + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + return ( + _gate_L_B_D, + torch.utils.checkpoint.checkpoint(_fn2, intermediate_output, use_reentrant=False), + previous_block_out, + ) + + def forward_with_ffn_memory_save_upgrade( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return _fn2(self.block.layer1(normalized_x)), previous_block_out + + output, previous_block_out = torch.utils.checkpoint.checkpoint(_fn, x_before_gate, x_skip, use_reentrant=False) + + return ( + _gate_L_B_D, + output, + previous_block_out, + ) + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + 
x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + if isinstance(self.block, VideoAttn): + if self.block.attn.is_selfattn: + fn = self.forward_with_attn_memory_save + else: + fn = self.forward_with_x_attn_memory_save + else: + # fn = self.forward_with_ffn_memory_save + fn = self.forward_with_ffn_memory_save_upgrade + return fn( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. + """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + if self.x_format == "BTHWD": + shift_B_1_1_1_D, scale_B_1_1_1_D, gate_B_1_1_1_D = ( + shift_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + scale_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + gate_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + ) + if self.block_type in ["spatial_sa", "temporal_sa", "window_attn", "ssa", "tsa", "wa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + crossattn_emb, + crossattn_mask, + ) + elif self.block_type in ["mlp", "ff"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + elif self.x_format == "THWBD": + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + ) + elif self.block_type in 
["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + This class is a wrapper for a list of DITBuildingBlock. + It's not essential, refactor it if needed. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + use_checkpoint: bool = False, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + window_sizes, + spatial_attn_win_size, + temporal_attn_win_size, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=n_views, + ) + ) + self.use_checkpoint = use_checkpoint + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + else: + return self._forward( + x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, extra_per_block_pos_emb + ) + + def _forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x + + def set_memory_save(self, mode: bool = True): + # (qsh) to make fsdp happy! + #! IMPORTANT! 
+ if mode: + self.forward = self.forward_with_memory_save + for block in self.blocks: + block.forward = block.forward_with_memory_save + else: + raise NotImplementedError("Not implemented yet.") + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + for block in self.blocks: + gate_L_B_D, x_before_gate, x_skip = block.forward( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + extra_per_block_pos_emb = None + return gate_L_B_D, x_before_gate, x_skip diff --git a/cosmos_predict1/diffusion/training/module/position_embedding.py b/cosmos_predict1/diffusion/training/module/position_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..83625e8a2c6e59352c2786e9fbd699c7c13e2a36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/position_embedding.py @@ -0,0 +1,932 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
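The `forward_with_memory_save` path above threads `(gate, x_before_gate, x_skip)` between blocks instead of a single hidden state: each block applies the previous block's gate inside its own checkpointed function, so the ungated activation never has to be stored for backward. The following stripped-down sketch shows that hand-off under assumed toy names (`ToyGatedBlock`) and with a learned constant gate in place of the real AdaLN-derived one; it is an illustration of the pattern, not the repository's implementation.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyGatedBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)
        self.gate = nn.Parameter(torch.zeros(1, 1, dim))  # this block's own gate, handed to the caller

    def forward(self, x_before_gate, x_skip, gate):
        def _fn(_x_before_gate, _x_skip):
            # Apply the *previous* block's gate here, inside the checkpoint, so the
            # ungated activation is recomputed during backward instead of stored.
            prev_out = _x_skip + gate * _x_before_gate
            return self.ff(self.norm(prev_out)), prev_out

        out, skip = checkpoint(_fn, x_before_gate, x_skip, use_reentrant=False)
        return self.gate, out, skip


blocks = nn.ModuleList(ToyGatedBlock(16) for _ in range(3))
x = torch.randn(2, 8, 16, requires_grad=True)
gate, out, skip = torch.ones(1, 1, 16), x, torch.zeros_like(x)
for blk in blocks:
    gate, out, skip = blk(out, skip, gate)
final = skip + gate * out  # the last gate is folded in by the head, cf. FinalLayer.forward_with_memory_save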
+ +from typing import Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.diffusion.module.attention import normalize +from cosmos_predict1.diffusion.module.timm import trunc_normal_ +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size_h, + grid_size_w, + grid_size_t, + spatial_interpolation_scale, + temporal_interpolation_scale, + concat=True, +): + grid_h = np.arange(grid_size_h, dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(grid_size_w, dtype=np.float32) / spatial_interpolation_scale + grid_t = np.arange(grid_size_t, dtype=np.float32) / temporal_interpolation_scale + + grid = np.meshgrid(grid_w, grid_h, grid_t, indexing="ij") + grid = np.stack(grid, axis=0) + grid = grid.reshape(3, 1, grid_size_h, grid_size_w, grid_size_t) + + if concat: + per_axis = embed_dim // 3 + per_axis = (per_axis // 2) * 2 # make it even (for sin/cos split) + dim_h, dim_w = per_axis, per_axis + dim_t = embed_dim - dim_h - dim_w + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, grid[0]) # (H*W, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, grid[1]) # (H*W, D/3) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, grid[2]) # (H*W, D/3) + + return np.concatenate([emb_h, emb_w, emb_t], axis=1) # (H*W*T, D) + else: + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[0]) # (H*W) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[1]) # (H*W) + emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[2]) # (H*W) + + return emb_h + emb_w + emb_t # (H*W*T, D) + + +class VideoPositionEmb(nn.Module): + def __init__(self): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. 
+ """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, VideoRopePosition3DEmb): + seq_dim = 0 + else: + seq_dim = 1 + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class SinCosPosEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + is_learnable: bool = False, + interpolation: Literal["crop", "resize", "crop_resize"] = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + init_length_for_resize: int = 16, + **kwargs, + ): + """ + Args: + interpolation (str): "crop", "resize", "crop_resize". "crop" means we crop the positional embedding to the length of the input sequence. "resize" means we resize the positional embedding to the length of the input sequence. "crop_resize" (inference only) means we first crop the positional embedding to init_length_for_resize, then resize it to the length of the input sequence. + init_length_for_resize (int): used when interpolation is "crop_resize", where we "resize" embedding during inference for model trained with "crop". We first "crop" the pos_embed to this length (used during training), then run the "resize", default 16 + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.init_length_for_resize = init_length_for_resize + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + if self.interpolation == "crop_resize": + pos_embed_crop = self.pos_embed[:, : self.init_length_for_resize, :H, :W] # B,T,H,W,C + _, t, h, w, c = pos_embed_crop.shape + + pos_embed_crop_resize_t = rearrange( + F.interpolate( + rearrange(pos_embed_crop, "1 t h w c -> 1 (c h w) t"), + size=(T), + mode="linear", + ), + "1 (c h w) t -> 1 t h w c", + c=c, + h=h, + w=w, + ) + pos_embed_crop_resize = rearrange( + F.interpolate( + rearrange(pos_embed_crop_resize_t, "1 t h w c -> 1 (c t) h w"), + size=(H, W), + mode="bilinear", + ), + "1 (c t) h w -> 1 t h w c", + c=c, + ) + return pos_embed_crop_resize + + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class SinCosPosEmb_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty video + is_learnable: 
bool = False, + interpolation: str = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + if self.interpolation == "crop": + param = get_3d_sincos_pos_embed( + model_channels, + len_h, + len_w, + len_t * int(max_fps / min_fps), + spatial_interpolation_scale, + temporal_interpolation_scale, + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnableEmb3D(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.pos_embed = nn.Parameter(torch.zeros(1, len_t, len_h, len_w, model_channels)) + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnableEmb3D_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty 
video + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + + if self.interpolation == "crop": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t * int(max_fps / min_fps), len_h, len_w, model_channels) + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t, len_h, len_w, model_channels) + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class VideoRopePositionEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(len_h * len_w * len_t, dtype=torch.float)) + + self.register_buffer( + "dim_range", torch.arange(0, head_dim, 2)[: (head_dim // 2)].float().cuda() / head_dim, persistent=False + ) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], ntk_factor: float = 1.0): + theta = 10000.0 * ntk_factor + + # original_dtype = self.dim_range.dtype + freq = 1.0 / (theta ** self.dim_range.float()) + _, T, H, W, _ = B_T_H_W_C + length = T * H * W + emb_L_D = torch.outer(self.seq[:length], freq) + return rearrange(torch.cat([emb_L_D, emb_L_D], dim=-1), "l d -> l 1 1 d").float() + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.max_t = len_t + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim 
== dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + self._dim_h = dim_h + self._dim_t = dim_t + + def reset_parameters(self) -> None: + if self.dim_spatial_range.device == torch.device("meta"): + return + + dim_h = self._dim_h + dim_t = self._dim_t + + self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) + + self.dim_spatial_range = ( + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h + ) + self.dim_temporal_range = ( + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t + ) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." 
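+            # Single-frame (image) batch: the temporal position is just index 0, so no FPS rescaling is applied.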
+ half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class SinCosPosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H), + ], + dim=-1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
+ """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) + self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) + self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def reset_parameters(self): + if self.pos_emb_h.device == torch.device("meta"): + return + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class MultiviewVideoPositionEmb(nn.Module): + def __init__( + self, + ): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. 
+ """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + seq_dim = 1 + embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float() + # rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float() + else: + seq_dim = 1 + embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views) + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views) + else: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float() + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class MultiviewVideoRopePosition3DEmb(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.n_views = n_views + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embedding_for_batch( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. 
+ """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert uniform_fps # only support uniform fps now + + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." + half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return em_T_H_W_D + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. The camera view dimension is merged in the T dimension + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. 
+ """ + + B, T, H, W, C = B_T_H_W_C + + single_view_B_T_H_W_C = (B, T // self.n_views, H, W, C) + em_T_H_W_D = torch.cat( + [ + self.generate_embedding_for_batch( + single_view_B_T_H_W_C, + fps=fps, + h_ntk_factor=h_ntk_factor, + w_ntk_factor=w_ntk_factor, + t_ntk_factor=t_ntk_factor, + ) + for item in range(self.n_views) + ], + dim=0, + ) + + return em_T_H_W_D + # return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class MultiviewSinCosPosEmbAxis(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + self.n_views = n_views + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + single_view_T = T // self.n_views + + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:single_view_T] + emb = torch.cat( + [ + torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=single_view_T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=single_view_T, h=H), + ], + dim=-1, + ) + for _ in range(self.n_views) + ], + 1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") diff --git a/cosmos_predict1/diffusion/training/module/pretrained_vae.py b/cosmos_predict1/diffusion/training/module/pretrained_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a925b9ecaf7f6b952c1a3edd24236962ebbc74 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/pretrained_vae.py @@ -0,0 +1,805 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch.nn.modules import Module + +from cosmos_predict1.diffusion.training.module.pretrained_vae_base import JITVAE, BaseVAE, StateDictVAE +from cosmos_predict1.utils import log + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + @property + def is_chunk_overlap(self): + return False + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. 
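+    Illustrative example with the defaults below (pixel_chunk_duration=17, temporal_compress_factor=8):
+    each 17-frame pixel chunk maps to (17 - 1) // 8 + 1 = 3 latent frames, so a 34-frame clip is encoded
+    chunk by chunk into 2 * 3 = 6 latent frames, and the chunks are pushed through the tokenizer in
+    groups of at most max_enc_batch_size / max_dec_batch_size to bound peak memory.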
+ """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> None: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. 
+ """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) + super(BasePretrainedVideoTokenizer, self).__init__(enc_fp, dec_fp, name, mean_std_fp, latent_ch, False, is_bf16) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class VideoStateDictTokenizer(BasePretrainedVideoTokenizer, StateDictVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from state_dict checkpoint + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + vae: torch.nn.Module, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_bf16: bool = True, + 
spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) + super(BasePretrainedVideoTokenizer, self).__init__( + enc_fp, dec_fp, vae, name, mean_std_fp, latent_ch, is_image=False, is_bf16=is_bf16 + ) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class VideoJITVAEChunkWiseTokenizer(VideoJITTokenizer): + """ + Do temporal chunk wise encoding and decoding. + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + spatial_compression_factor: int, + latent_ch: int = 16, + is_bf16: bool = True, + full_duration: int = 121, + chunk_duration: int = 49, + temporal_compression_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution="720", + overlap_size: int = 9, + ): + self._latent_chunk_duration = ( + chunk_duration - 1 + ) // temporal_compression_factor + 1 # need to set before super init + self._latent_full_duration = (full_duration - 1) // temporal_compression_factor + 1 + super().__init__( + enc_fp=enc_fp, + dec_fp=dec_fp, + name=name, + mean_std_fp=mean_std_fp, + latent_ch=latent_ch, + is_bf16=is_bf16, + pixel_chunk_duration=chunk_duration, + temporal_compression_factor=temporal_compression_factor, + max_enc_batch_size=max_enc_batch_size, + max_dec_batch_size=max_dec_batch_size, + spatial_resolution=spatial_resolution, + spatial_compression_factor=spatial_compression_factor, + ) + self.overlap_size = overlap_size + self.full_duration = full_duration + # make sure full_duration is divisible by chunk_duration with pre-set overlap size + assert (full_duration - overlap_size) % (chunk_duration - overlap_size) == 0 + + @property + def latent_chunk_duration(self) -> int: + return self._latent_chunk_duration + + @property + def latent_full_duration(self) -> int: + return self._latent_full_duration + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.full_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.full_duration}" + return num_pixel_frames // self.full_duration * self.latent_full_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_full_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_full_duration}" + return num_latent_frames // self.latent_full_duration * self.full_duration + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + # This is a hack impl, should be improved later + return state + + def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: + # This is a hack impl, should be improved later + return latent + + def _impl_encode(self, state: torch.Tensor) -> torch.Tensor: + in_dtype = state.dtype + + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if 
isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + + assert state.shape[2] == self.full_duration + + # Calculate the number of overlapping windows/chunks + # Each window has a duration of self.pixel_chunk_duration frames + # The overlap between consecutive windows is self.overlap_size frames + num_windows = (T - self.pixel_chunk_duration) // (self.pixel_chunk_duration - self.overlap_size) + # Calculate the total number of frames covered by the windows + num_windowed_frames = self.pixel_chunk_duration + num_windows * (self.pixel_chunk_duration - self.overlap_size) + + assert num_windowed_frames == T # only handle case where number frames can be separated equally + # Prepare a list to hold overlapping chunks of the input state + pack_list = [state[:, :, : self.pixel_chunk_duration]] + [ + state[ + :, + :, + (ii + 1) + * (self.pixel_chunk_duration - self.overlap_size) : (ii + 1) + * (self.pixel_chunk_duration - self.overlap_size) + + self.pixel_chunk_duration, + ] + for ii in range(num_windows) + ] + + latent = self._impl_encode(torch.cat(pack_list, dim=0)) + latent = rearrange(latent, "(n b) c t h w -> n b c t h w", b=B) + # Calculate the overlap size in the latent space, accounting for any temporal compression + # For example, if the network downsamples temporally by a factor of 4, adjust the overlap accordingly + overlap_latent = (self.overlap_size - 1) // self.temporal_compression_factor + 1 + # Concatenate the latent representations from each chunk/window + # For the first chunk, include all latent frames + # For subsequent chunks, exclude the overlapping latent frames at the beginning + out = torch.cat([latent[0]] + [latent[i, :, :, overlap_latent:] for i in range(1, len(latent))], dim=2) + return out + + @torch.no_grad() + def maybe_pad_latent(self, latent: torch.Tensor) -> tuple[torch.Tensor, int]: + """Since the decoder expect the latent to be window_size + (window_size - decode_overlap_size) * N, we need to pad the latent to match the expected size + Args: + latent (torch.Tensor): [B, C, T, H, W] + Returns: + latent: torch.Tensor, the padded latent + padding_t: int, the number of padding latent t + """ + + # Calculate the overlap size and window size in the latent space, considering any temporal compression + decode_overlap_size = (self.overlap_size - 1) // self.temporal_compression_factor + 1 + # Calculate the number of windows/chunks for decoding + window_size = (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + B, C, current_latent_t, H, W = latent.shape + + if current_latent_t < window_size: + # If the current latent tensor is smaller than the window size, pad it to the window size + target_latent_t = window_size + else: + # Calculate the target latent frame number for decoding + target_latent_t = window_size + math.ceil( + (current_latent_t - window_size) / (window_size - decode_overlap_size) + ) * (window_size - decode_overlap_size) + + padding_t = target_latent_t - current_latent_t + if padding_t != 0: + log.info( + f"Padding latent from {current_latent_t} to {target_latent_t} for decoding purpose. 
current window_size: {window_size}, decode_overlap_size: {decode_overlap_size}"
+            )
+            padding = latent.new_zeros(B, C, padding_t, H, W)
+            latent = torch.cat([latent, padding], dim=2).contiguous()
+        return latent, padding_t
+
+    @torch.no_grad()
+    def decode(self, state: torch.Tensor) -> torch.Tensor:
+        state, padding_t = self.maybe_pad_latent(state)
+        B, C, num_latents, H, W = state.shape
+
+        # Calculate the overlap size and window size in the latent space, accounting for temporal compression
+        decode_overlap_size = (self.overlap_size - 1) // self.temporal_compression_factor + 1
+        window_size = (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1
+
+        # Calculate the number of windows/chunks for decoding
+        num_windows = (num_latents - window_size) // (window_size - decode_overlap_size) + 1
+        decoded_frames = []
+        # Start decoding with the initial window of latent frames
+        current_state = state[:, :, :window_size]
+        for i in range(num_windows):
+            # Decode the current window to get the reconstructed frames
+            window_frames = super().decode(current_state)
+            decoded_frames.append(window_frames)
+            # Re-encode the overlapping frames at the end of the decoded window to obtain the last latent frame
+            # This is necessary due to the causal first-frame design
+            last_latent = self._impl_encode(window_frames[:, :, -self.overlap_size : -self.overlap_size + 1])[:, :, 0:1]
+            # Calculate the start and end indices for the next chunk of latent frames
+            start_idx = window_size + i * (window_size - decode_overlap_size) - decode_overlap_size + 1
+            end_idx = start_idx + window_size - 1
+            # Prepare the next state by concatenating the last latent frame with the next chunk of latent frames
+            current_state = torch.cat([last_latent, state[:, :, start_idx:end_idx]], dim=2)
+        # Remove the overlapping frames (self.overlap_size frames, 9 by default) from all windows except the first one.
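+        # Illustrative example (values chosen to satisfy the divisibility assert in __init__, not the defaults):
+        # with chunk_duration=49, overlap_size=9, temporal_compression_factor=8 and 17 latent frames,
+        # window_size = 7 and decode_overlap_size = 2, so num_windows = (17 - 7) // (7 - 2) + 1 = 3.
+        # Each window decodes to 49 pixel frames; windows 2 and 3 then drop their first 9 frames,
+        # giving 49 + 40 + 40 = 129 frames in total.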
+ for i in range(1, num_windows): + decoded_frames[i] = decoded_frames[i][:, :, self.overlap_size :] + video_tensor = torch.cat(decoded_frames, dim=2) + return video_tensor + + @property + def is_chunk_overlap(self): + return True + + +class DebugMeanStdVideoJITVAE(VideoJITTokenizer): + """ + A class for one + """ + + def register_mean_std(self, mean_std_fp: str) -> None: + target_shape = [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + # latent_mean.to(self.dtype).reshape(*target_shape), + torch.zeros(*target_shape, dtype=self.dtype), + persistent=False, + ) + self.register_buffer( + "latent_std", + torch.ones(*target_shape, dtype=self.dtype), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return JITVAE.encode(self, state) + return super().encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + if T == 1: + return JITVAE.decode(self, latent) + return super().decode(latent) + + +class DebugMeanStdVideoJITVAEChunkWiseTokenizer(VideoJITVAEChunkWiseTokenizer): + def register_mean_std(self, mean_std_fp: str) -> None: + target_shape = [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + # latent_mean.to(self.dtype).reshape(*target_shape), + torch.zeros(*target_shape, dtype=self.dtype), + persistent=False, + ) + self.register_buffer( + "latent_std", + torch.ones(*target_shape, dtype=self.dtype), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return JITVAE.encode(self, state) + return super().encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + if T == 1: + return JITVAE.decode(self, latent) + return super().decode(latent) + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. 
+ """ + del args, kwargs + self.image_vae.reset_dtype() + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + self.image_vae.encoder = self.video_vae.encoder + self.image_vae.decoder = self.video_vae.decoder + + +class JointImageVideoStateDictTokenizer(JointImageVideoTokenizer): + """ + Copy of ImageVideoVAE1 that uses plain torch.nn.Module instead of JITed one so + that it can be used witch torch.compile() + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + + assert isinstance(image_vae, StateDictVAE) + assert isinstance(video_vae, VideoStateDictTokenizer) + # a hack to make the image_vae and video_vae share the same encoder and decoder + + # nn.Module + del self.image_vae.vae + # Just method + del self.image_vae.encoder + # Just method + del self.image_vae.decoder + + self.image_vae.vae = self.video_vae.vae + self.image_vae.encoder = self.video_vae.encoder + self.image_vae.decoder = self.video_vae.decoder + + +class DummyJointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + name: str = "dummy_joint_image_video", + pixel_ch: int = 3, + latent_ch: int = 16, + pixel_chunk_duration: int = 17, + latent_chunk_duration: int = 3, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + spatial_resolution: str = "720", + ): + self.pixel_ch = pixel_ch + self._pixel_chunk_duration = pixel_chunk_duration + self._latent_chunk_duration = latent_chunk_duration + self._spatial_compression_factor = spatial_compression_factor + self._temporal_compression_factor = temporal_compression_factor + self._spatial_resolution = spatial_resolution + super().__init__(latent_ch, name) + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self._temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + 
+ @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self._latent_chunk_duration + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + state_B_T_C_H_W = F.interpolate( + rearrange(state, "b c t h w -> b t c h w"), + size=(self.latent_ch, H // self.spatial_compression_factor, W // self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + return rearrange(state_B_T_C_H_W, "b t c h w -> b c t h w").contiguous() + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + num_frames = T // self.pixel_chunk_duration * self.latent_chunk_duration + + state_B_C_T_H_W = F.interpolate( + state, + size=(self.latent_ch, H // self.spatial_compression_factor, W // self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + state_B_H_W_T_C = rearrange(state_B_C_T_H_W, "b c t h w -> b h w t c") + state_B_H_W_T_C = F.interpolate( + state_B_H_W_T_C, + size=(W // self.spatial_compression_factor, num_frames, self.latent_ch), + mode="trilinear", + align_corners=False, + ) + return rearrange(state_B_H_W_T_C, "b h w t c -> b c t h w").contiguous() + + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + latent_B_T_C_H_W = F.interpolate( + rearrange(latent, "b c t h w -> b t c h w"), + size=(self.pixel_ch, H * self.spatial_compression_factor, W * self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + return rearrange(latent_B_T_C_H_W, "b t c h w -> b c t h w").contiguous() + + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + num_frames = T * self.pixel_chunk_duration // self.latent_chunk_duration + + latent_B_H_W_T_C = rearrange(latent, "b c t h w -> b h w t c") + latent_B_H_W_T_C = F.interpolate( + latent_B_H_W_T_C, + size=(W * self.spatial_compression_factor, num_frames, self.pixel_ch), + mode="trilinear", + align_corners=False, + ) + latent_B_C_T_H_W = rearrange(latent_B_H_W_T_C, "b h w t c -> b c t h w") + + state = F.interpolate( + latent_B_C_T_H_W, + size=(num_frames, H * self.spatial_compression_factor, W * self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + + return state.contiguous() diff --git a/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py b/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py new file mode 100644 index 0000000000000000000000000000000000000000..08570cbdcc9b8d9875ee907a28001aa5193d8fc5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py @@ 
-0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + +from cosmos_predict1.utils.distributed import rank0_first +from cosmos_predict1.utils.misc import load_from_s3_with_cache + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
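+    The normalization used by the subclasses below is simply
+        latent = (encoder(x) - latent_mean) / latent_std
+    on encode, and the inverse, latent * latent_std + latent_mean, is applied before running the decoder.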
+ """ + + def __init__( + self, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.mean_std_fp = mean_std_fp + self.name = name + + self.backend_args = None + + self.register_mean_std(mean_std_fp) + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = torch.jit.load(enc_fp, map_location="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. + """ + self.decoder = torch.jit.load(dec_fp, map_location="cuda") + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class StateDictVAE(BasePretrainedImageVAE): + """ + A Variational Autoencoder (VAE) that loads pre-trained weights into + provided encoder and decoder components from a remote store, handles data type conversions, + and normalization using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The encoder with weights loaded from storage. + decoder (Module): The decoder with weights loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + vae (Module): Instance of VAE with not loaded weights + name (str): Name of the model, used for differentiating cache file paths. + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + vae: torch.nn.Module, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + + self.load_encoder_and_decoder(enc_fp, dec_fp, vae) + + def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None: + """ + Load the encoder from the remote store. + + Args: + - vae_fp (str): File path to the vae's state dict file on the remote store. + - vae (str): VAE module into which weights will be loaded. 
+ """ + state_dict_enc = load_from_s3_with_cache( + enc_fp, + f"vae/{self.name}_enc.jit", + easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, + backend_args=self.backend_args, + ) + + state_dict_dec = load_from_s3_with_cache( + dec_fp, + f"vae/{self.name}_dec.jit", + easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, + backend_args=self.backend_args, + ) + + jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict() + jit_weights_state_dict = { + k: v + for k, v in jit_weights_state_dict.items() + # Global variables captured by JIT + if k + not in ( + "encoder.patcher.wavelets", + "encoder.patcher._arange", + "decoder.unpatcher.wavelets", + "decoder.unpatcher._arange", + ) + } + + vae.load_state_dict(jit_weights_state_dict) + vae.eval() + for param in vae.parameters(): + param.requires_grad = False + vae.to(self.dtype) + + self.vae = vae + self.encoder = self.vae.encode + self.decoder = self.vae.decode + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.vae.to(self.dtype) + + +class SDVAE(BaseVAE): + def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: + super().__init__(channel=4, name="sd_vae") + self.dtype = torch.bfloat16 + self.register_buffer( + "scale", + torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), + persistent=False, + ) + self.register_buffer( + "bias", + -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, + persistent=False, + ) + self.batch_size = batch_size + self.count_std = count_std + self.is_downsample = is_downsample + self.load_vae() + self.reset_dtype() + + def reset_dtype(self, *args, **kwargs): + del args, kwargs + self.vae.to(self.dtype) + + @rank0_first + def load_vae(self) -> None: + os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" + os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + import diffusers + + vae_name = "stabilityai/sd-vae-ft-mse" + try: + vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) + except: # noqa: E722 + # Could not load the model from cache; try without local_files_only. 
+ vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) + self.vae = vae.eval().requires_grad_(False) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + state : pixel range [-1, 1] + """ + if self.is_downsample: + _h, _w = state.shape[-2:] + state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) + in_dtype = state.dtype + state = state.to(self.dtype) + state = (state + 1.0) / 2.0 + latent_dist = self.vae.encode(state)["latent_dist"] + mean, std = latent_dist.mean, latent_dist.std + if self.count_std: + latent = mean + torch.randn_like(mean) * std + else: + latent = mean + latent = latent * self.scale + latent = latent + self.bias + return latent.to(in_dtype) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + in_dtype = latent.dtype + latent = latent.to(self.dtype) + latent = latent - self.bias + latent = latent / self.scale + latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) + if self.is_downsample: + _h, _w = latent.shape[-2:] + latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) + return latent.to(in_dtype) * 2 - 1.0 + + @property + def spatial_compression_factor(self) -> int: + return 8 diff --git a/cosmos_predict1/diffusion/training/modules/edm_sde.py b/cosmos_predict1/diffusion/training/modules/edm_sde.py new file mode 100644 index 0000000000000000000000000000000000000000..3d08a8229f03c9fdd6a8d905ad4543fe5fe5238a --- /dev/null +++ b/cosmos_predict1/diffusion/training/modules/edm_sde.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
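+# Note (illustrative): EDMSDE.sample_t below draws sigma = exp(z) with z ~ Normal(p_mean, p_std),
+# sampled via the inverse CDF of uniform draws. With the defaults (p_mean=-1.2, p_std=1.2) the median
+# sigma is exp(-1.2) ~= 0.30; sigma_max and sigma_min are stored on the object but are not applied
+# inside sample_t.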
+ +from statistics import NormalDist + +import numpy as np +import torch + + +class EDMSDE: + def __init__( + self, + p_mean: float = -1.2, + p_std: float = 1.2, + sigma_max: float = 80.0, + sigma_min: float = 0.002, + ): + self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + self.sigma_max = sigma_max + self.sigma_min = sigma_min + + def sample_t(self, batch_size: int) -> torch.Tensor: + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + return torch.exp(log_sigma) + + def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """This is trivial in the base class, but may be used by derived classes in a more interesting way""" + return x0, sigma diff --git a/cosmos_predict1/diffusion/training/networks/general_dit.py b/cosmos_predict1/diffusion/training/networks/general_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6fb610208ea6c108bd951818ea1a74b760715e --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit.py @@ -0,0 +1,1022 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +It allows us easy to switch building blocks used and their order. Its instantiation includes +* transformer on fully flattened tokens +* factored spatial and temporal attention +* factored non-overlap spatial and temporal attention +* mixing of above attention types + +Limitations: + +* In favor of simplicity and cleanness, many ops are not fused and we can do better +* such as combining mutiple adaln MLPs into one inside one transformer block. 
+* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy + +Purpose: +* A prototype for testing different attention types and their combinations +* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies +""" + +from collections.abc import Container +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.module.attention import get_normalization +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.module.blocks import ( + DITBuildingBlock, + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + SDXLTimestepEmbedding, + SDXLTimesteps, +) +from cosmos_predict1.diffusion.training.module.position_embedding import ( + LearnableEmb3D, + LearnableEmb3D_FPS_Aware, + LearnablePosEmbAxis, + SinCosPosEmb, + SinCosPosEmb_FPS_Aware, + SinCosPosEmbAxis, + VideoRopePosition3DEmb, + VideoRopePositionEmb, +) +from cosmos_predict1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim +from cosmos_predict1.utils import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + Attributes: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple of int): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means + full attention, cross attention, and MLP in sequence in one transformer block. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of residual blocks per resolution in the transformer. + num_heads (int): Number of heads in the multi-head self-attention layers. + spatial_attn_win_size (int): Window size for the spatial attention mechanism. + temporal_attn_win_size (int): Window size for the temporal attention mechanism. + mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. + use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) + use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. + crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. + use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. + pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). + pos_emb_learnable (bool): Specifies if positional embeddings are learnable. + pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. 
+ block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. + legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. + rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. + rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. + rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. + Note: + block_config support block type: + * spatial_sa, ssa: spatial self attention + * temporal_sa, tsa: temporal self attention + * cross_attn, ca: cross attention + * full_attn: full attention on all flatten tokens + * mlp, ff: feed forward block + * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. + + Example: + >>> # full attention, cross attention, and MLP + >>> option1_block_config = 'FA-CA-MLP' + >>> model_1 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option1_block_config + ) + >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' + >>> model_2 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option2_block_config + ) + >>> # option3 model + >>> model_3 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=2, + block_config=option2_block_config + ) + >>> # Process input tensor through the model + >>> output = model(input_tensor) + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + window_block_indexes: list = [], # index for window attention block + window_sizes: list = [], # window size for window attention block in the order of T, H, W + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + mlp_ratio: float = 4.0, + use_memory_save: bool = False, + use_checkpoint: bool = False, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, # 1 for getty video + max_fps: int = 30, # 120 for getty video but let's use 30 + additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + layer_mask: list = None, # whether or not a layer is used. 
For controlnet encoder + legacy_patch_emb: bool = True, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "learnable", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.additional_timestamp_channels = additional_timestamp_channels + self.affline_emb_norm = affline_emb_norm + self.legacy_patch_emb = legacy_patch_emb + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.cp_group = None + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + SDXLTimesteps(model_channels), + SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + self.block_config = block_config + self.use_memory_save = use_memory_save + self.use_checkpoint = use_checkpoint + + assert ( + len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" + ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" + + layer_mask = [False] * num_blocks if layer_mask is None else layer_mask + assert ( + len(layer_mask) == num_blocks + ), f"Layer mask length {len(layer_mask)} does not match num_blocks {num_blocks}" + for idx in range(num_blocks): + if layer_mask[idx]: + continue + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + window_sizes=( + window_sizes if idx in window_block_indexes else [] + ), # There will be bug if using "WA-CA-MLP" + mlp_ratio=mlp_ratio, + spatial_attn_win_size=spatial_attn_win_size, + temporal_attn_win_size=temporal_attn_win_size, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + use_checkpoint=use_checkpoint, + ) + + self.build_decode_head() + self.build_additional_timestamp_embedder() + if self.affline_emb_norm: + log.critical("Building affine embedding normalization layer") + self.affline_norm = 
get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.init_weights() + + if self.use_memory_save: + log.critical("Using checkpointing to save memory! only verified in 14B base model training!") + for block in self.blocks.values(): + block.set_memory_save() + + def init_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Tensor parallel + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + self.initialize_tensor_parallel_weights() + + def initialize_tensor_parallel_weights(self): + """ + Initialize weights for tensor parallel layers. + + This function performs the following steps: + 1. Retrieves the tensor parallel rank. + 2. Saves the current random state. + 3. Sets a new random seed based on the tensor parallel rank. + 4. Initializes weights for attention and MLP layers in each block. + 5. Restores the original random state. + + The use of different random seeds for each rank ensures + unique initializations across parallel processes. 
+ """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + # Save the current random state + rng_state = torch.get_rng_state() + + # Set a new random seed based on the tensor parallel rank + torch.manual_seed(tp_rank) + + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + # Initialize weights for attention layers + torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) + elif layer.block_type in ["mlp", "ff"]: + # Initialize weights for MLP layers + torch.nn.init.xavier_uniform_(layer.block.layer1.weight) + torch.nn.init.xavier_uniform_(layer.block.layer2.weight) + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + # Restore the original random state + torch.set_rng_state(rng_state) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_additional_timestamp_embedder(self): + if self.additional_timestamp_channels: + self.additional_timestamp_embedder = nn.ModuleDict() + for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): + log.critical( + f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" + ) + self.additional_timestamp_embedder[cond_name] = nn.Sequential( + SDXLTimesteps(cond_emb_channels), + SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), + ) + + def prepare_additional_timestamp_embedder(self, **kwargs): + condition_concat = [] + + for cond_name, embedder in self.additional_timestamp_embedder.items(): + condition_concat.append(embedder(kwargs[cond_name])[0]) + embedding = torch.cat(condition_concat, dim=1) + if embedding.shape[1] < self.model_channels: + embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) + return embedding + + def build_pos_embed(self): + if self.pos_emb_cls == "sincos": + cls_type = SinCosPosEmb + elif self.pos_emb_cls == "learnable": + cls_type = LearnableEmb3D + elif self.pos_emb_cls == "sincos_fps_aware": + cls_type = SinCosPosEmb_FPS_Aware + elif self.pos_emb_cls == "learnable_fps_aware": + cls_type = LearnableEmb3D_FPS_Aware + elif self.pos_emb_cls == "rope": + cls_type = VideoRopePositionEmb + elif self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + 
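+        # The selected class is instantiated below from a shared kwargs dict; see
+        # cosmos_predict1/diffusion/training/module/position_embedding.py for which fields each variant uses.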
log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
+ """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
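+        # High-level flow: (1) patchify the input and attach positional embeddings,
+        # (2) build the affine embedding from the timestep (plus optional additional conditions),
+        # (3) reshape x to the block x_format and, when enabled, scatter it for sequence parallelism.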
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward_blocks_regular( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + features = [] + for name, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + # Extract features + block_idx = int(name.split("block")[-1]) + if block_idx in feature_indices: + B, C, T, H, W = original_shape + H = H // self.patch_spatial + W = W // self.patch_spatial + T = T // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x_feat + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + features.append(x_B_T_H_W_D) + + if x_ctrl is not None and name in x_ctrl: + x = x + x_ctrl[name] + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward_blocks_memory_save( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + x_before_gate = 0 
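+        # Memory-save path: activations are kept flattened as (T*H*W, B, D); each block consumes and
+        # returns a (gate, pre-gate activation, skip stream) triple, and the final triple is consumed
+        # by final_layer.forward_with_memory_save below.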
+ x_skip = rearrange(x, "T H W B D -> (T H W) B D") + assert self.blocks["block0"].x_format == "THWBD" + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") + else: + extra_per_block_pos_emb = None + gate_L_B_D = 1.0 + + features = [] + for name, block in self.blocks.items(): + gate_L_B_D, x_before_gate, x_skip = block( + x_before_gate, + x_skip, + gate_L_B_D, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_per_block_pos_emb, + ) + + # Extract features. + # Convert the block index in the memory save mode to the block index in the regular mode. + block_idx = int(name.split("block")[-1]) - 1 + if block_idx in feature_indices: + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + H = H_before_patchify // self.patch_spatial + W = W_before_patchify // self.patch_spatial + T = T_before_patchify // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x_skip + x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) + + features.append(x_B_T_H_W_D) + + new_name = f"block{block_idx}" + if x_ctrl is not None and new_name in x_ctrl: + x_ctrl_ = x_ctrl[new_name] + x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") + x_skip = x_skip + x_ctrl_ + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + x_THW_B_D_before_gate = x_before_gate + x_THW_B_D_skip = x_skip + + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + x_BT_HW_D_before_gate = rearrange( + x_THW_B_D_before_gate, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + x_BT_HW_D_skip = rearrange( + x_THW_B_D_skip, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + + x_BT_HW_D = self.final_layer.forward_with_memory_save( + x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, + x_BT_HW_D_skip=x_BT_HW_D_skip, + gate_L_B_D=gate_L_B_D, + emb_B_D=affline_emb_B_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
+ x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. 
+ return [] + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + if self.use_memory_save: + return self.forward_blocks_memory_save( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + return self.forward_blocks_regular( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + @property + def fsdp_wrap_block_cls(self): + return DITBuildingBlock + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + self.pos_embedder.enable_context_parallel(cp_group) + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.enable_context_parallel(cp_group) + # Loop through the model to set up context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. 
+        for block in self.blocks.values():
+            for layer in block.blocks:
+                if layer.block_type in ["mlp", "ff"]:
+                    continue
+                elif layer.block_type in ["cross_attn", "ca"]:
+                    continue
+                else:
+                    layer.block.attn.attn_op.cp_group = None
+                    layer.block.attn.attn_op.cp_ranks = None
+                    layer.block.attn.attn_op.cp_stream = None
+
+        log.debug("[CP] Disable context parallelism.")
+
+    def enable_sequence_parallel(self):
+        self._set_sequence_parallel(True)
+
+    def disable_sequence_parallel(self):
+        self._set_sequence_parallel(False)
+
+    def _set_sequence_parallel(self, status: bool):
+        self.sequence_parallel = status
+        self.final_layer.sequence_parallel = status
+        for block in self.blocks.values():
+            for layer in block.blocks:
+                if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]:
+                    layer.block.attn.to_q[0].sequence_parallel = status
+                    layer.block.attn.to_k[0].sequence_parallel = status
+                    layer.block.attn.to_v[0].sequence_parallel = status
+                    layer.block.attn.to_out[0].sequence_parallel = status
+                    layer.block.attn.attn_op.sequence_parallel = status
+                elif layer.block_type in ["mlp", "ff"]:
+                    layer.block.layer1.sequence_parallel = status
+                    layer.block.layer2.sequence_parallel = status
+                else:
+                    raise ValueError(f"Unknown block type {layer.block_type}")
+
+    @property
+    def is_context_parallel_enabled(self):
+        return self.cp_group is not None
diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_action.py b/cosmos_predict1/diffusion/training/networks/general_dit_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dd4db422844da3b422512096e58836c93117e33
--- /dev/null
+++ b/cosmos_predict1/diffusion/training/networks/general_dit_action.py
@@ -0,0 +1,515 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+A general implementation of an adaLN-modulated ViT-like (DiT) transformer for video processing.
+It allows us to easily switch the building blocks used and their order. Its instantiation includes
+* transformer on fully flattened tokens
+* factored spatial and temporal attention
+* factored non-overlapping spatial and temporal attention
+* mixing of the above attention types
+
+Limitations:
+
+* In favor of simplicity and cleanliness, many ops are not fused and we can do better
+* such as combining multiple adaLN MLPs into one inside one transformer block.
+* we use reshape heavily, which may be inefficient when it incurs unnecessary CUDA memory copies
+
+Purpose:
+* A prototype for testing different attention types and their combinations
+* Ideally, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies
+"""
+
+from collections.abc import Container
+from typing import List, Optional, Tuple
+
+import torch
+from einops import rearrange
+from megatron.core import parallel_state
+from torch import nn
+
+from cosmos_predict1.diffusion.module.timm import Mlp
+from cosmos_predict1.diffusion.training.conditioner import DataType
+from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp
+from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT
+from cosmos_predict1.diffusion.training.tensor_parallel import scatter_along_first_dim
+from cosmos_predict1.utils import log
+
+
+class ActionConditionalGeneralDIT(GeneralDIT):
+    """
+    ActionConditionalGeneralDIT is a subclass of GeneralDIT that takes `action` as an additional condition.
+    The action embedding is added to the timestep embedding.
+    """
+
+    def forward_before_blocks(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        crossattn_emb: torch.Tensor,
+        action: Optional[torch.Tensor] = None,
+        crossattn_mask: Optional[torch.Tensor] = None,
+        fps: Optional[torch.Tensor] = None,
+        image_size: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        scalar_feature: Optional[torch.Tensor] = None,
+        data_type: Optional[DataType] = DataType.VIDEO,
+        latent_condition: Optional[torch.Tensor] = None,
+        latent_condition_sigma: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (B, C, T, H, W) tensor of spatial-temp inputs
+            timesteps: (B, ) tensor of timesteps
+            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
+            crossattn_mask: (B, N) tensor of cross-attention masks
+        """
+        del kwargs
+        assert isinstance(
+            data_type, DataType
+        ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
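+        # Note: this override currently mirrors GeneralDIT.forward_before_blocks and does not itself
+        # consume `action`; the action embedding is applied in
+        # ActionConditionalVideoExtendGeneralDIT.forward_before_blocks further below.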
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. 
+            return []
+
+        inputs = self.forward_before_blocks(
+            x=x,
+            timesteps=timesteps,
+            crossattn_emb=crossattn_emb,
+            action=action,
+            crossattn_mask=crossattn_mask,
+            fps=fps,
+            image_size=image_size,
+            padding_mask=padding_mask,
+            scalar_feature=scalar_feature,
+            data_type=data_type,
+            latent_condition=latent_condition,
+            latent_condition_sigma=latent_condition_sigma,
+            condition_video_augment_sigma=condition_video_augment_sigma,
+            **kwargs,
+        )
+        x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
+            inputs["x"],
+            inputs["affline_emb_B_D"],
+            inputs["crossattn_emb"],
+            inputs["crossattn_mask"],
+            inputs["rope_emb_L_1_1_D"],
+            inputs["adaln_lora_B_3D"],
+            inputs["original_shape"],
+        )
+        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
+        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+            assert (
+                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
+            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
+
+        if self.use_memory_save:
+            return self.forward_blocks_memory_save(
+                x,
+                affline_emb_B_D,
+                crossattn_emb,
+                crossattn_mask,
+                rope_emb_L_1_1_D,
+                adaln_lora_B_3D,
+                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+                feature_indices,
+                original_shape,
+                x_ctrl,
+                return_features_early,
+            )
+
+        return self.forward_blocks_regular(
+            x,
+            affline_emb_B_D,
+            crossattn_emb,
+            crossattn_mask,
+            rope_emb_L_1_1_D,
+            adaln_lora_B_3D,
+            extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+            feature_indices,
+            original_shape,
+            x_ctrl,
+            return_features_early,
+        )
+
+
+class ActionConditionalVideoExtendGeneralDIT(ActionConditionalGeneralDIT):
+    """
+    ActionConditionalVideoExtendGeneralDIT is a subclass of ActionConditionalGeneralDIT that takes `action` as a condition.
+    The action embedding is added to the timestep embedding.
+ """ + + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + assert hasattr(self, "model_channels"), "model_channels attribute is missing" + self.action_embedder_B_D = Mlp( + in_features=7, + hidden_features=self.model_channels * 4, + out_features=self.model_channels, + act_layer=lambda: nn.GELU(approximate="tanh"), + drop=0, + ) + self.action_embedder_B_3D = Mlp( + in_features=7, + hidden_features=self.model_channels * 4, + out_features=self.model_channels * 3, + act_layer=lambda: nn.GELU(approximate="tanh"), + drop=0, + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + # log.critical(f"hit video case, video_cond_bool: {video_cond_bool}, condition_video_indicator: {condition_video_indicator.flatten()}, condition_video_input_mask: {condition_video_input_mask.shape}, {condition_video_input_mask[:,:,:,0,0]}", rank0_only=False) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + return super().forward( + x=x, + 
timesteps=timesteps, + crossattn_emb=crossattn_emb, + action=action, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + # Add action conditioning + assert action is not None, "Action is required for action-conditional training" + if action is not None: + action = action[:, 0, :] # Since we are now training on 1 frame, we only need the first frame action. 
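+            # Project the 7-dim action vector to model_channels (added to the timestep embedding) and to
+            # 3 * model_channels (added to the adaLN-LoRA embedding) using the MLPs built in __init__.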
+ action_embedding_B_D = self.action_embedder_B_D(action) + action_embedding_B_3D = self.action_embedder_B_3D(action) + timesteps_B_D = timesteps_B_D + action_embedding_B_D + adaln_lora_B_3D = adaln_lora_B_3D + action_embedding_B_3D + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py b/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py new file mode 100644 index 0000000000000000000000000000000000000000..3d344e9b49dbab908611354d8c9f76d4d006bc36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
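+
+"""
+Video-extension (long video generation, `lvg`) variant of GeneralDIT.
+
+VideoExtendGeneralDIT takes an extra input channel for the condition video mask
+(`in_channels` defaults to 16 + 1) and optionally embeds the augmentation sigma of the
+conditioning frames when `add_augment_sigma_embedding` is enabled.
+"""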
+ +from typing import Optional + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.module.blocks import SDXLTimestepEmbedding, SDXLTimesteps +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.diffusion.training.tensor_parallel import scatter_along_first_dim +from cosmos_predict1.utils import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + SDXLTimesteps(self.model_channels), + SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def init_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().init_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + # log.critical(f"hit video case, 
video_cond_bool: {video_cond_bool}, condition_video_indicator: {condition_video_indicator.flatten()}, condition_video_input_mask: {condition_video_input_mask.shape}, {condition_video_input_mask[:,:,:,0,0]}", rank0_only=False) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
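+        # Same flow as GeneralDIT.forward_before_blocks, except that an optional augment-sigma
+        # embedding is added to the affine embedding when add_augment_sigma_embedding is enabled.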
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py b/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2641d62a41cedd22ffab3949f531ba743adcae --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
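+
+# Video-extension variant of the multiview DiT: reserves one extra input channel for the video condition
+# mask, splits that mask per view along the temporal dimension when context parallelism is enabled, and
+# concatenates it (plus an optional pose condition) onto the latent input before the regular forward pass.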
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log + + +class VideoExtendMultiviewGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py b/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..62b16e37a50ec959373f91970bb11209e983710a --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py @@ -0,0 +1,460 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.module.blocks import GeneralDITTransformerBlock, PatchEmbed +from cosmos_predict1.diffusion.training.module.position_embedding import ( + MultiviewSinCosPosEmbAxis, + MultiviewVideoRopePosition3DEmb, +) +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewGeneralDIT(GeneralDIT): + def __init__( + self, + *args, + n_views: int = 3, + view_condition_dim: int = 3, + traj_condition_dim: int = 0, + concat_view_embedding: bool = True, + concat_traj_embedding: bool = False, + add_repeat_frame_embedding: bool = False, + **kwargs, + ): + self.n_views = n_views + self.view_condition_dim = view_condition_dim + self.concat_view_embedding = concat_view_embedding + self.traj_condition_dim = traj_condition_dim + self.concat_traj_embedding = concat_traj_embedding + self.add_repeat_frame_embedding = add_repeat_frame_embedding + + super().__init__(*args, **kwargs) + # reinit self.blocks + del self.blocks + self.blocks = nn.ModuleDict() + + layer_mask = [False] * self.num_blocks if kwargs["layer_mask"] is None else kwargs["layer_mask"] + assert ( + len(layer_mask) == self.num_blocks + ), f"Layer mask length {len(layer_mask)} does not match num_blocks {self.num_blocks}" + for idx in range(self.num_blocks): + if layer_mask[idx]: + continue + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=self.model_channels, + context_dim=kwargs["crossattn_emb_channels"], + num_heads=self.num_heads, + block_config=self.block_config, + window_sizes=( + kwargs["window_sizes"] if idx in kwargs["window_block_indexes"] else [] + ), # There will be bug if using "WA-CA-MLP" + mlp_ratio=kwargs["mlp_ratio"], + spatial_attn_win_size=kwargs["spatial_attn_win_size"], + temporal_attn_win_size=kwargs["temporal_attn_win_size"], + x_format=self.block_x_format, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + n_views=self.n_views, + ) + self.view_embeddings = nn.Embedding(n_views, view_condition_dim) # Learnable embedding layer + + if self.concat_traj_embedding: + self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer + if self.add_repeat_frame_embedding: + self.repeat_frame_embedding = nn.Linear(1, view_condition_dim) # Learnable embedding layer + + self.init_weights() + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + view_condition_dim, + traj_condition_dim, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + 
self.view_condition_dim, + self.traj_condition_dim, + ) + if self.concat_view_embedding: + in_channels = in_channels + view_condition_dim if view_condition_dim > 0 else in_channels + + if self.concat_traj_embedding: + in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = MultiviewVideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + n_views=self.n_views, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = MultiviewSinCosPosEmbAxis(**kwargs) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + trajectory = kwargs.get("trajectory", None) + frame_repeat = kwargs.get("frame_repeat", None) + + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
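+        # Note: the sequence-parallel branch below relies on megatron's parallel_state and on
+        # scatter_along_first_dim, which do not appear to be imported at module level in this file;
+        # import them locally here so that code path does not fail with a NameError.
+        from megatron.core import parallel_state
+        from cosmos_predict1.diffusion.training.tensor_parallel import scatter_along_first_dim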
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + trajectory=trajectory, + frame_repeat=frame_repeat, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + trajectory: Optional[torch.Tensor] = None, + frame_repeat: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
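+            - Learnable per-view embeddings (and, when enabled, trajectory and frame-repeat embeddings)
+                are concatenated to the input along the channel dimension before patch embedding.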
+ """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + + view_indices = torch.arange(self.n_views).to(x_B_C_T_H_W.device) # View indices [0, 1, ..., V-1] + view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] + view_embedding = rearrange(view_embedding, "V D -> D V") + view_embedding = view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) # Shape: [1, D, V, 1, 1, 1] + + if self.add_repeat_frame_embedding: + if frame_repeat is None: + frame_repeat = ( + torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) + .to(view_embedding.device) + .to(view_embedding.dtype) + ) + frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) + frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") + view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) + + x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) + view_embedding = view_embedding.expand( + x_B_C_V_T_H_W.shape[0], + view_embedding.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + if self.concat_traj_embedding: + traj_emb = self.traj_embeddings(trajectory) + traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + traj_emb = traj_emb.expand( + x_B_C_V_T_H_W.shape[0], + traj_emb.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) + else: + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) + + x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb + + +class VideoExtendGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/training/tensor_parallel.py b/cosmos_predict1/diffusion/training/tensor_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c756c38e53d2c71f3ab3fa2b08859fdf1bb96bc5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/tensor_parallel.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
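+
+# Autograd-aware collectives for tensor/sequence parallelism: AllGather gathers shards along the first
+# dimension and hands each rank its own gradient slice on backward, while Scatter keeps only the local
+# shard and all-gathers gradients on backward. Rough usage, assuming torch.distributed is initialized
+# (see the __main__ block at the bottom for a runnable sketch, e.g. under torchrun):
+#
+#   full  = gather_along_first_dim(local_shard, process_group)   # (N * world_size, ...)
+#   shard = scatter_along_first_dim(full, process_group)         # back to (N, ...) on this rank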
+ +import os + +import torch +import torch.distributed as dist +from torch.autograd import Function + + +class AllGather(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.rank = process_group.rank() + + gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, tensor.contiguous(), process_group) + return torch.cat(gathered_tensors, dim=0) + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + rank = ctx.rank + + # Split the gradient tensor + grad_chunks = grad_output.chunk(world_size) + + # Select the gradient chunk for the current rank + grad_input = grad_chunks[rank] + return grad_input, None + + +def gather_along_first_dim(tensor, process_group): + return AllGather.apply(tensor, process_group) + + +class Scatter(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.process_group = process_group + rank = process_group.rank() + + # Split the tensor + tensor_chunks = tensor.chunk(world_size) + + # Select the tensor chunk for the current rank + return tensor_chunks[rank] + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + process_group = ctx.process_group + + # Gather the gradient tensor + gathered_grads = [torch.zeros_like(grad_output) for _ in range(world_size)] + dist.all_gather(gathered_grads, grad_output.contiguous(), process_group) + return torch.cat(gathered_grads, dim=0), None + + +def scatter_along_first_dim(tensor, process_group): + return Scatter.apply(tensor, process_group) + + +if __name__ == "__main__": + # Torch global setup for distributed training + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Create a tensor with gradients + x = torch.randn(10, 1, requires_grad=True, device="cuda") + + # Perform all_gather with gradient support + y = gather_along_first_dim(x, dist.group.WORLD) + print(f"{y.shape=}") + y = scatter_along_first_dim(y, dist.group.WORLD) + print(f"{y.shape=}") + + # Use the result in your computation + loss = y.sum() + loss.backward() + + # x.grad now contains the gradients + print(x.grad) diff --git a/cosmos_predict1/diffusion/training/train.py b/cosmos_predict1/diffusion/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a65eece39e5cc993918dd53be44d87b8c70d28a0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/train.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
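+
+# Training entrypoint: loads a python LazyConfig from --config, applies "path.key=value" overrides from the
+# remaining CLI arguments, instantiates the model (wiring in the FSDP checkpointer when fsdp_enabled) and
+# the dataloaders, then hands everything to the trainer. --dryrun only resolves and prints the YAML config
+# without launching training.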
+ +import argparse +import importlib +import os + +import torch.distributed as dist +from loguru import logger as logging +from omegaconf import OmegaConf + +from cosmos_predict1.diffusion.config.config import Config +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig +from cosmos_predict1.utils.parallel_state_helper import is_tp_cp_pp_rank0 + + +@misc.timer("instantiate model") +def instantiate_model(config: Config, trainer) -> None: + misc.set_random_seed(seed=config.trainer.seed, by_rank=False) + config.model_obj.config = config.model + if getattr(config.model, "fsdp_enabled", False): + assert config.trainer.distributed_parallelism == "fsdp", "FSDP model is only supported with FSDP trainer" + log.critical("FSDP enabled") + config.model_obj.fsdp_checkpointer = trainer.checkpointer + model = instantiate(config.model_obj) + config.model_obj.fsdp_checkpointer = None + else: + model = instantiate(config.model_obj) + config.model_obj.config = None + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + return model + + +def destroy_distributed(): + log.info("Destroying distributed environment...") + if dist.is_available() and dist.is_initialized(): + try: + dist.destroy_process_group() + except ValueError as e: + print(f"Error destroying default process group: {e}") + + +@logging.catch(reraise=True) +def launch(config: Config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + trainer = config.trainer.type(config) + # # Setup the miscellaneous stuff for reproducibility. + # log_reproducible_setup(config, args) + # Create the model + model = instantiate_model(config, trainer) + model.on_model_init_end() + # Create the dataloaders. + if args.mp0_only_dl: + log.critical( + "Using only tp_cp_pp_rank0 dataloader for faster dataloading! Make sure val dl is mock and mock data has same keys as real data." + ) + raise NotImplementedError( + "mp0_only_dl is not implemented correctly! Please revisit this code and propose a more robust impl that raise error timely! It does not do necessary check before training to confirm it can work with image / video data. Current impl is problematic for image training." + ) + if is_tp_cp_pp_rank0() or not args.mp0_only_dl: + dataloader_train = instantiate(config.dataloader_train) + else: + dataloader_train = instantiate(config.dataloader_val) + dataloader_val = instantiate(config.dataloader_val) + # Start training + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + destroy_distributed() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Training") + parser.add_argument( + "--config", + default="cosmos_predict1/diffusion/posttrain/config/config.py", + help="Path to the config file", + ) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. 
Useful for debugging the config.", + ) + parser.add_argument( + "--mp0_only_dl", + action="store_true", + help="Use only model parallel rank 0 dataloader for faster dataloading! Make sure mock data has same keys as real data.", + ) + args = parser.parse_args() + config_module = get_config_module(args.config) + config = importlib.import_module(config_module).make_config() + config = override(config, args.opts) + if args.dryrun: + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) + print(f"{config.job.path_local}/config.yaml") + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/diffusion/training/trainer.py b/cosmos_predict1/diffusion/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e68de41c528512404a4c43ade569faaace3caa08 --- /dev/null +++ b/cosmos_predict1/diffusion/training/trainer.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.diffusion.training.utils.checkpointer import MultiRankCheckpointer +from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer +from cosmos_predict1.utils.trainer import Trainer as BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config): + super(Trainer, self).__init__(config) + if config.trainer.distributed_parallelism == "ddp": + self.checkpointer = MultiRankCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + elif config.trainer.distributed_parallelism == "fsdp": + self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}") diff --git a/cosmos_predict1/diffusion/training/utils/checkpointer.py b/cosmos_predict1/diffusion/training/utils/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..069ed708b9247a1b53c7e06d34c159f4d6d9d2e1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/checkpointer.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
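+
+# DDP checkpointing utilities: MultiRankCheckpointer extends the base Checkpointer with multi-EMA support,
+# where the first `total_ema_num` ranks each save their own postfixed "iter_XXXXXXXXX{postfix}.pt" file
+# from a background thread and load() resolves the same postfix; non_strict_load_model() loads a state_dict
+# while tolerating missing keys, unexpected keys, shape mismatches and TransformerEngine "_extra_state"
+# entries.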
+ +from __future__ import annotations + +import os +import threading +from typing import List, NamedTuple, Tuple + +import torch + +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer as BaseCheckpointer +from cosmos_predict1.utils.model import Model + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +class MultiRankCheckpointer(BaseCheckpointer): + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + # checkpoint_file = f"iter_{iteration:09}.pt" + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + save_ranks = list(range(total_ema_num)) + for _rank in save_ranks: + if distributed.get_rank() == _rank: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). 
+ scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt") + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt") + resume = self.load_training_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. + log.info("- Loading the model...") + log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume)) + if resume: + iteration = state_dict["iteration"] + assert optimizer and scheduler + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + iteration = 0 + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + return iteration + + +# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. 
type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. + continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) diff --git a/cosmos_predict1/diffusion/training/utils/fsdp_helper.py b/cosmos_predict1/diffusion/training/utils/fsdp_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4c1f6150e6d71c9fc626867ae11541e83c0134 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/fsdp_helper.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
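+
+# FSDP helpers: apply_fsdp_checkpointing() wraps the listed transformer block classes with non-reentrant
+# activation checkpointing, possible_fsdp_scope() manually runs FSDP's pre-/post-forward hooks so the flat
+# parameters are unsharded inside the context, and hsdp_device_mesh() builds a ("replicate", "shard")
+# device mesh for hybrid-sharded (HSDP) training.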
+ +from __future__ import annotations + +from contextlib import contextmanager +from functools import partial + +import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._runtime_utils import ( + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, +) +from torch.distributed.utils import _p_assert + +from cosmos_predict1.utils import distributed, log + + +def apply_fsdp_checkpointing(model, list_block_cls): + """apply activation checkpointing to model + returns None as model is updated directly + """ + log.critical("--> applying fdsp activation checkpointing...") + non_reentrant_wrapper = partial( + checkpoint_wrapper, + # offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + + def check_fn(submodule): + result = False + for block_cls in list_block_cls: + if isinstance(submodule, block_cls): + result = True + break + return result + + apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) + + +@contextmanager +def possible_fsdp_scope( + model: torch.nn.Module, +): + enabled = isinstance(model, FSDP) + if enabled: + assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" + handle = model._handle + args, kwargs = [0], dict(dummy=0) + with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): + args, kwargs = _root_pre_forward(model, model, args, kwargs) + unused = None + args, kwargs = _pre_forward( + model, + handle, + _pre_forward_unshard, + model._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == model.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{model.compute_device} but got {handle.flat_param.device}", + ) + try: + yield None + finally: + if enabled: + output = {"output": 1} + _post_forward(model, handle, _post_forward_reshard, model, unused, output) + + +def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None): + """ + Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. + + This function requires explicit sizes for replica and sharding groups to accommodate models + whose GPU fit is unknown, providing flexibility in distributed training setups. + + Args: + replica_group_size (int): The size of each replica group. Must be provided to ensure + the model fits within the available resources. + sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to + ensure the correct distribution of model parameters. + device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" + with the local rank as the device index. + + Returns: + A device mesh object compatible with FSDP. + + Raises: + ValueError: If replica_group_size or sharding_group_size are not provided, or if the + world size is not evenly divisible by the sharding group size. + RuntimeError: If a valid device mesh cannot be created. 
+ + Usage: + If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: + Sharding_Group_Size = 4 + Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups + >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) + >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) + """ + + # world_size = int(os.getenv("WORLD_SIZE", "1")) + world_size = distributed.get_world_size() + if sharding_group_size is None: + sharding_group_size = min(world_size, 8) + sharding_group_size = min(sharding_group_size, world_size) + if replica_group_size is None: + replica_group_size = world_size // sharding_group_size + + device = device or "cuda" + + if world_size % sharding_group_size != 0: + raise ValueError( + f"World size {world_size} is not evenly divisible by " f"sharding group size {sharding_group_size}." + ) + + if (world_size // sharding_group_size) % replica_group_size != 0: + raise ValueError( + f"The calculated number of replica groups is not evenly divisible by " + f"replica_group_size {replica_group_size}." + ) + + device_mesh = init_device_mesh( + device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard") + ) + if device_mesh is None: + raise RuntimeError("Failed to create a valid device mesh.") + + log.critical( + f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}" + ) + + return device_mesh diff --git a/cosmos_predict1/diffusion/training/utils/inference_long_video.py b/cosmos_predict1/diffusion/training/utils/inference_long_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb222a733666a710cbb98900e2055b3a231cc1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/inference_long_video.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from typing import Tuple, Union + +import einops +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as transforms_F +from matplotlib import pyplot as plt + +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.easy_io import easy_io + +"""This file contain functions needed for long video generation, +* function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20` + +""" + + +@contextmanager +def switch_config_for_inference(model): + """For extend model inference, we need to make sure the condition_location is set to "first_n" and apply_corruption_to_condition_region is False. + This context manager changes the model configuration to the correct settings for inference, and then restores the original settings when exiting the context. 
+ Args: + model (ExtendDiffusionModel): video generation model + """ + # Store the current condition_location + current_condition_location = model.config.conditioner.video_cond_bool.condition_location + if current_condition_location != "first_n" and current_condition_location != "first_and_last_1": + current_condition_location = "first_n" + current_apply_corruption_to_condition_region = ( + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region + ) + try: + log.info( + "Change the condition_location to 'first_n' for inference, and apply_corruption_to_condition_region to False" + ) + # Change the condition_location to "first_n" for inference + model.config.conditioner.video_cond_bool.condition_location = current_condition_location + if current_apply_corruption_to_condition_region == "gaussian_blur": + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "clean" + elif current_apply_corruption_to_condition_region == "noise_with_sigma": + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "noise_with_sigma_fixed" + # Yield control back to the calling context + yield + finally: + # Restore the original condition_location after exiting the context + log.info( + f"Restore the original condition_location {current_condition_location}, apply_corruption_to_condition_region {current_apply_corruption_to_condition_region}" + ) + model.config.conditioner.video_cond_bool.condition_location = current_condition_location + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = ( + current_apply_corruption_to_condition_region + ) + + +def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None): + """Debug function to display a latent tensor as a grid of images. + Args: + tensor (torch.Tensor): tensor in shape BCTHW + nrow (int): number of images per row + show_norm (bool): whether to display the norm of the tensor + save_fig_path (str): path to save the visualization + + """ + log.info( + f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}" + ) + tensor = tensor.float().cpu().detach() + tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow) # .numpy() + # display the grid + tensor_mean = tensor.mean(-1) + tensor_norm = tensor.norm(dim=-1) + log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}") + plt.figure(figsize=(20, 20)) + plt.imshow(tensor_mean) + plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}") + if save_fig_path: + os.makedirs(os.path.dirname(save_fig_path), exist_ok=True) + log.info(f"save to {os.path.abspath(save_fig_path)}") + plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0) + plt.show() + if show_norm: + plt.figure(figsize=(20, 20)) + plt.imshow(tensor_norm) + plt.show() + + +def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None): + """Debug function to display a tensor as a grid of images. 
+ Args: + tensor (torch.Tensor): tensor in shape BCTHW + nrow (int): number of images per row + save_fig_path (str): path to save the visualization + """ + log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}") + assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong" + tensor = tensor.float().cpu().detach() + tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w") + # use torchvision to save the tensor as a grid of images + grid = torchvision.utils.make_grid(tensor, nrow=nrow) + if save_fig_path is not None: + os.makedirs(os.path.dirname(save_fig_path), exist_ok=True) + log.info(f"save to {os.path.abspath(save_fig_path)}") + torchvision.utils.save_image(tensor, save_fig_path) + # display the grid + plt.figure(figsize=(20, 20)) + plt.imshow(grid.permute(1, 2, 0)) + plt.show() + + +def compute_num_frames_condition(model: "ExtendDiffusionModel", num_of_latent_overlap: int, downsample_factor=8) -> int: + """This function computes the number of condition pixel frames given the number of latent frames to overlap. + Args: + model (ExtendDiffusionModel): Video generation model + num_of_latent_overlap (int): Number of latent frames to overlap + downsample_factor (int): Downsample factor for temporal reduce + Returns: + int: Number of condition frames in output space + """ + # Access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly + vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer + + # Check if the VAE is causal (default to True if attribute not found) + if getattr(vae, "is_casual", True): + # For causal model + num_frames_condition = num_of_latent_overlap // vae.latent_chunk_duration * vae.pixel_chunk_duration + if num_of_latent_overlap % vae.latent_chunk_duration == 1: + num_frames_condition += 1 + elif num_of_latent_overlap % vae.latent_chunk_duration > 1: + num_frames_condition += 1 + (num_of_latent_overlap % vae.latent_chunk_duration - 1) * downsample_factor + else: + num_frames_condition = num_of_latent_overlap * downsample_factor + + return num_frames_condition + + +def read_video_or_image_into_frames_BCTHW( + input_path: str, + input_path_format: str = None, + H: int = None, + W: int = None, + normalize: bool = True, + max_frames: int = -1, + also_return_fps: bool = False, +) -> torch.Tensor: + """Read video or image from file and convert it to tensor. The frames will be normalized to [-1, 1]. 
+ Args: + input_path (str): path to the input video or image, end with .mp4 or .png or .jpg + H (int): height to resize the video + W (int): width to resize the video + Returns: + torch.Tensor: video tensor in shape (1, C, T, H, W), range [-1, 1] + """ + log.info(f"Reading video from {input_path}") + + loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None) + if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): + frames = np.array(loaded_data) # HWC, [0,255] + if frames.shape[-1] > 3: # RGBA, set the transparent to white + # Separate the RGB and Alpha channels + rgb_channels = frames[..., :3] + alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] + + # Create a white background + white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB + + # Blend the RGB channels with the white background based on the alpha channel + frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( + np.uint8 + ) + frames = [frames] + fps = 0 + else: + frames, meta_data = loaded_data + fps = int(meta_data.get("fps")) + if max_frames != -1: + frames = frames[:max_frames] + input_tensor = np.stack(frames, axis=0) + input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") + if normalize: + input_tensor = input_tensor / 128.0 - 1.0 + input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW + log.info(f"Raw data shape: {input_tensor.shape}") + if H is not None and W is not None: + input_tensor = transforms_F.resize( + input_tensor, + size=(H, W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) + if normalize: + input_tensor = input_tensor.to("cuda") + log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") + if also_return_fps: + return input_tensor, fps + return input_tensor + + +def create_condition_latent_from_input_frames( + model: ExtendDiffusionModel, + input_frames: torch.Tensor, + num_frames_condition: int = 25, +): + """Create condition latent for video generation. It will take the last num_frames_condition frames from the input frames as condition latent. 
+ Args: + model (ExtendDiffusionModel): Video generation model + input_frames (torch.Tensor): Video tensor in shape (1,C,T,H,W), range [-1, 1] + num_frames_condition (int): Number of condition frames + Returns: + torch.Tensor: Condition latent in shape B,C,T,H,W + """ + B, C, T, H, W = input_frames.shape + # Dynamically access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly + vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer + num_frames_encode = vae.pixel_chunk_duration # Access pixel_chunk_duration from the VAE + log.info( + f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" + ) + + log.info( + f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" + ) + + assert ( + input_frames.shape[2] >= num_frames_condition + ), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}" + assert ( + num_frames_encode >= num_frames_condition + ), f"num_frames_encode should be larger than num_frames_condition, got {num_frames_encode}, {num_frames_condition}" + + # Put the conditional frames at the beginning of the video, and pad the end with zeros + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + condition_frames_first = input_frames[:, :, :num_frames_condition] + condition_frames_last = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) + else: + condition_frames = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) + + log.info( + f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" + ) + if hasattr(model, "n_views"): + encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW + latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) + latent = torch.cat([latent1, latent2], dim=2) # BCTHW + else: + latent = model.encode(encode_input_frames) + return latent, encode_input_frames + + +def get_condition_latent( + model: ExtendDiffusionModel, + conditioned_image_or_video_path: str, + num_of_latent_condition: int = 4, + state_shape: list[int] = None, + input_path_format: str = None, + frame_index: int = 0, + frame_stride: int = 1, +): + if state_shape is None: + state_shape = model.state_shape + if num_of_latent_condition == 0: + log.info("No condition latent needed, return empty latent") + condition_latent = ( + torch.zeros( + [ + 1, + ] + + state_shape + ) + .to(torch.bfloat16) + .cuda() + ) + return condition_latent, None + + H, W = ( + state_shape[-2] * model.vae.spatial_compression_factor, + state_shape[-1] * model.vae.spatial_compression_factor, + ) + input_frames = read_video_or_image_into_frames_BCTHW( + conditioned_image_or_video_path, + input_path_format=input_path_format, + H=H, + 
W=W, + ) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + start_frame = frame_index * frame_stride + end_frame = (frame_index + 1) * frame_stride + input_frames = torch.cat( + [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 + ).contiguous() # BCTHW + + num_frames_condition = compute_num_frames_condition( + model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor + ) + + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition) + condition_latent = condition_latent.to(torch.bfloat16) + return condition_latent, input_frames + + +def generate_video_from_batch_with_loop( + model: ExtendDiffusionModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + condition_latent: torch.Tensor, + # hyper-parameters for inference + num_of_loops: int, + num_of_latent_overlap_list: list[int], + guidance: float, + num_steps: int, + seed: int, + add_input_frames_guidance: bool = False, + augment_sigma_list: list[float] = None, + data_batch_list: Union[None, list[dict]] = None, + visualize: bool = False, + save_fig_path: str = None, + skip_reencode: int = 0, + return_noise: bool = False, +) -> Tuple[np.array, list, list, torch.Tensor] | Tuple[np.array, list, list, torch.Tensor, torch.Tensor]: + """Generate video with loop, given data batch. The condition latent will be updated at each loop. + Args: + model (ExtendDiffusionModel) + state_shape (list): shape of the state tensor + is_negative_prompt (bool): whether to use negative prompt + + data_batch (dict): data batch for video generation + condition_latent (torch.Tensor): condition latent in shape BCTHW + + num_of_loops (int): number of loops to generate video + num_of_latent_overlap_list (list[int]): list number of latent frames to overlap between clips, different clips can have different overlap + guidance (float): The guidance scale to use during sample generation; defaults to 5.0. + num_steps (int): number of steps for diffusion sampling + seed (int): random seed for sampling + add_input_frames_guidance (bool): whether to add image guidance, default is False + augment_sigma_list (list): list of sigma value for the condition corruption at different clip, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed". default is None + + data_batch_list (list): list of data batch for video generation, used when num_of_loops >= 1, to support multiple prompts in auto-regressive generation. default is None + visualize (bool): whether to visualize the latent and grid, default is False + save_fig_path (str): path to save the visualization, default is None + + skip_reencode (int): whether to skip re-encode the input frames, default is 0 + return_noise (bool): whether to return the initial noise used for sampling, used for ODE pairs generation. 
Default is False + Returns: + np.array: generated video in shape THWC, range [0, 255] + list: list of condition latent, each in shape BCTHW + list: list of sample latent, each in shape BCTHW + torch.Tensor: initial noise used for sampling, shape BCTHW (if return_noise is True) + """ + + if data_batch_list is None: + data_batch_list = [data_batch for _ in range(num_of_loops)] + if visualize: + assert save_fig_path is not None, "save_fig_path should be set when visualize is True" + + # Generate video with loop + condition_latent_list = [] + decode_latent_list = [] # list collect the latent token to be decoded at the end + sample_latent = [] + grid_list = [] + + augment_sigma_list = ( + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region_sigma_value + if augment_sigma_list is None + else augment_sigma_list + ) + + for i in range(num_of_loops): + num_of_latent_overlap_i = num_of_latent_overlap_list[i] + num_of_latent_overlap_i_plus_1 = ( + num_of_latent_overlap_list[i + 1] + if i + 1 < len(num_of_latent_overlap_list) + else num_of_latent_overlap_list[-1] + ) + if condition_latent.shape[2] < state_shape[1]: + # Padding condition latent to state shape + log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}") + b, c, t, h, w = condition_latent.shape + condition_latent = torch.cat( + [ + condition_latent, + condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), + ], + dim=2, + ).contiguous() + log.info(f"after padding, condition latent shape {condition_latent.shape}") + log.info(f"Generate video loop {i} / {num_of_loops}") + if visualize: + log.info(f"Visualize condition latent {i}") + visualize_latent_tensor_bcthw( + condition_latent[:, :, :4].float(), + nrow=4, + save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"), + ) # BCTHW + + condition_latent_list.append(condition_latent) + + if i < len(augment_sigma_list): + condition_video_augment_sigma_in_inference = augment_sigma_list[i] + log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}") + else: + condition_video_augment_sigma_in_inference = augment_sigma_list[-1] + assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported" + + sample = model.generate_samples_from_batch( + data_batch_list[i], + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed + i, + condition_latent=condition_latent, + num_condition_t=num_of_latent_overlap_i, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + return_noise=return_noise, + ) + + if return_noise: + sample, noise = sample + + if visualize: + log.info(f"Visualize sampled latent {i} 4-8 frames") + visualize_latent_tensor_bcthw( + sample[:, :, 4:8].float(), + nrow=4, + save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"), + ) # BCTHW + + diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i] + log.info( + f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}" + ) + + sample_latent.append(sample) + T = condition_latent.shape[2] + assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be < T, get {num_of_latent_overlap_i}, {T}" + + if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + assert skip_reencode, "skip_reencode should be turned on when 
sample_tokens_start_from_p_or_i is True" + if i == 0: + decode_latent_list.append(sample) + else: + decode_latent_list.append(sample[:, :, num_of_latent_overlap_i:]) + else: + # Interpolator mode. Decode the first and last as an image. + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + grid_BCTHW_1 = (1.0 + model.decode(sample[:, :, :-1, ...])).clamp(0, 2) / 2 # [B, 3, T-1, H, W], [0, 1] + grid_BCTHW_2 = (1.0 + model.decode(sample[:, :, -1:, ...])).clamp(0, 2) / 2 # [B, 3, 1, H, W], [0, 1] + grid_BCTHW = torch.cat([grid_BCTHW_1, grid_BCTHW_2], dim=2) # [B, 3, T, H, W], [0, 1] + else: + grid_BCTHW = (1.0 + model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1] + + if visualize: + log.info(f"Visualize grid {i}") + visualize_tensor_bcthw( + grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png") + ) + grid_np_THWC = ( + (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + ) # THW3, range [0, 255] + + # Post-process the output: cut the conditional frames from the output if it's not the first loop + num_cond_frames = compute_num_frames_condition( + model, num_of_latent_overlap_i_plus_1, downsample_factor=model.tokenizer.temporal_compression_factor + ) + if i == 0: + new_grid_np_THWC = grid_np_THWC # First output, dont cut the conditional frames + else: + new_grid_np_THWC = grid_np_THWC[ + num_cond_frames: + ] # Remove the conditional frames from the output, since it's overlapped with previous loop + grid_list.append(new_grid_np_THWC) + + # Prepare the next loop: re-compute the condition latent + if hasattr(model, "n_views"): + grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views) + condition_frame_input = grid_BCTHW[:, :, -num_cond_frames:] * 2 - 1 # BCTHW, range [0, 1] to [-1, 1] + if skip_reencode: + # Use the last num_of_latent_overlap latent token as condition latent + log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token") + condition_latent = sample[:, :, -num_of_latent_overlap_i_plus_1:] + else: + # Re-encode the condition frames to get the new condition latent + condition_latent, _ = create_condition_latent_from_input_frames( + model, condition_frame_input, num_frames_condition=num_cond_frames + ) # BCTHW + condition_latent = condition_latent.to(torch.bfloat16) + + # save videos + if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + # decode all video together + decode_latent_list = torch.cat(decode_latent_list, dim=2) + grid_BCTHW = (1.0 + model.decode(decode_latent_list)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1] + video_THWC = ( + (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + ) # THW3, range [0, 255] + else: + video_THWC = np.concatenate(grid_list, axis=0) # THW3, range [0, 255] + + if return_noise: + return video_THWC, condition_latent_list, sample_latent, noise + return video_THWC, condition_latent_list, sample_latent diff --git a/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py b/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1d38076f2ecd368bcf26b992f71b4a5c5bb9fe81 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from collections import defaultdict +from typing import Union + +from loguru import logger +from omegaconf import DictConfig, ListConfig + +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils.validator import Float, Int, OneOf + + +class LayerControlConfigParser: + """ + Parses a config to select layers, blocks, and subblocks to apply LoRA, PEFT, and other finegrained post-training techniques. + A base model is first loaded then edits (i.e. LoRA, unfreeze, etc.) are applied to the model. Currently, only LoRA is supported for to_q, to_k, to_v, to_out attention layers. + See: cosmos_predict1/diffusion/training/utils/peft/lora_config.py and LoRA diffusion post-training for an example of how to create and use a LoRA config. + The input config is a dictionary with the following keys: + - enabled: whether to apply the PEFT + - customization_type: default/global type of PEFT to apply (LoRA, unfreeze, etc.) + - rank: default/global LoRA rank + - scale: default/global LoRA scale + - edits: a list of model edits to apply. + - blocks: a regex to select the blocks to apply the edit to: eg: r'\b(0|1|25|26)\b' + - block_edit: a list of subblocks to apply the edit to: eg: ["FA[to_q, to_v]", "CA[to_q, to_v]"]. + Subblock names correspond to FA (Full-Attention), CA (Cross-Attention), FL (FinalLayer), and MLP modules as defined in general_dit.py, + and the layers (i.e to_q, to_k, to_v, etc.) are defined in corresponding modules in attention.py. + - customization_type: type of PEFT to apply for the edit (LoRA, unfreeze, etc.) - overrides the global customization_type if provided + - rank: LoRA rank - overrides the global rank for target blocks and subblocks if provided + - scale: LoRA scale - overrides the global scale for target blocks and subblocks if provided + """ + + SUBBLOCK_PATTERN = r"^(?P.+?)\[(?P[^\]]+)\]$" # determines the subblock type (i.e. "FA[...]") + LAYER_PATTERN = r"^(?P.+?)(?::(?P.+?))?(?::(?P[\d\.]+))?$" # determines the layer details (i.e. to_q:8:0.6 or to_q) + FINAL_LAYER_NAME = "final_layer" + DEFAULT_ALLOWED_TYPES = { # subblock type to layer types + "FA": {"to_q", "to_k", "to_v", "to_out", "ada1", "ada2"}, + "CA": {"to_q", "to_k", "to_v", "to_out", "ada1", "ada2"}, + "MLP": {"l1", "l2", "ada1", "ada2"}, + } + + DEFAULT_VALUE_CONSTRAINTS = ( + { # field to allowed ranges. these ranges are not prescriptive and can be adjusted as needed. 
+ "blocks": {"min": 0, "max": 27}, + "rank": {"min": 1, "max": 512}, + "scale": {"min": 1e-5, "max": 64}, + } + ) + ALLOWED_TYPES_FINAL_LAYER = {"FL": {"l1", "ada1", "ada2"}} + + def __init__(self, config: Union[str, dict] = {}, allowed_types: dict = None, value_constraints: dict = None): + self.config = self._config_to_dict(config) + self.enabled = str(self.config.get("enabled", "False")).lower() in ( + "true", + "1", + "yes", + ) # if not set, assume disabled + if self.enabled and not self.config.get("customization_type", ""): + raise AttributeError("Must specify a top-level customization_type.") + self.default_customization_type = CustomizationType.from_value(self.config.get("customization_type", "")) + self.default_rank = self.config.get("rank", None) + self.default_scale = self.config.get("scale", None) + + self.allowed_types = allowed_types or self.DEFAULT_ALLOWED_TYPES + self.value_constraints = value_constraints or self.DEFAULT_VALUE_CONSTRAINTS + logger.info( + f"Creating layers config with allowed subblock + layer types: \n{self.allowed_types} and value constraints: \n{self.value_constraints}" + ) + self.allowed_types_final_layer = self.ALLOWED_TYPES_FINAL_LAYER + + self._set_validators() + + self.all_blocks_str = ( + ",".join( + str(i) + for i in range( + self.value_constraints.get("blocks").get("min"), self.value_constraints.get("blocks").get("max") + 1 + ) + ) + + "," + + self.FINAL_LAYER_NAME + ) + + self.edits_per_block = defaultdict(lambda: None) + + def _set_validators(self): + """ + Sets validators for blocks, subblocks, rank, and scale. + + Raises: + AttributeError: If value constraints are not properly defined. + """ + self.subblock_validator = OneOf(default="", options=self.allowed_types.keys()) + self.final_layer_validator = OneOf(default="", options=self.allowed_types_final_layer.keys()) + self.rank_validator = None + self.scale_validator = None + try: + self.rank_validator = Int( + default=0, + min=self.value_constraints.get("rank").get("min"), + max=self.value_constraints.get("rank").get("max"), + ) + self.scale_validator = Float( + default=0, + min=self.value_constraints.get("scale").get("min"), + max=self.value_constraints.get("scale").get("max"), + ) + except AttributeError: + raise AttributeError( + "Value Constraints dictionary must contain 'blocks', 'rank', and 'scale' attributes with 'min' and 'max' attributes for each" + ) + + def _config_to_dict(self, config): + """ + Convert the given config into a dictionary if provided as a string. + + Args: + config (Union[str, dict]): The configuration as a JSON string or dictionary. + + Returns: + dict: The configuration as a dictionary. + + Raises: + ValueError: If the JSON string is invalid. + TypeError: If the config is not a string or dictionary. + """ + if isinstance(config, str): + try: + config = json.loads(config) + except json.JSONDecodeError: + raise ValueError("Invalid JSON string provided") + elif not isinstance(config, (dict, DictConfig)): + raise TypeError(f"Config should be either a JSON string or a dictionary, but got {type(config)}") + return config + + def _parse_blocks_regex(self, regex): + """ + Parse the 'blocks' regex and return a set of matching block numbers. + Allowed block numbers: defined in value_constraints, plus 'final_layer' + + Args: + regex (str): The regex pattern to match block numbers. + + Returns: + set: A set of block numbers that match the regex. + + Raises: + ValueError: If the regex pattern is invalid or matches invalid block numbers. 
+ Exception: If 'final_layer' is defined with other blocks. + """ + try: + block_matches = re.findall(regex, self.all_blocks_str) + block_numbers = set() + for match in block_matches: + match = match.strip() + if match == "final_layer": + block_numbers.add(match) + else: + try: + block_numbers.add(int(match)) + except ValueError: + raise ValueError(f"Invalid match found: '{match}' is neither an integer nor 'final_layer'.") + except re.error as e: + raise ValueError(f"Invalid regex pattern provided: {regex}. Error: {e}") + + # as final_layer contains a different block type than other blocks, must be defined separately + if "final_layer" in block_numbers and len(block_numbers) > 1: + raise Exception(f"Block 'final_layer' must be defined separately, but got: {block_numbers}") + + return block_numbers + + def _parse_subblocks( + self, + block_edit: list | ListConfig, + customization_type: str, + rank: int, + scale: float, + is_final_layer: bool = False, + ): + """Generate a dictionary of edits config by subblock. + + Args: + block_edit (list): List of representing subblocks to apply the edit to (i.e ["FA[to_q, to_v]", "CA[to_q, to_v]"]) + customization_type (str): The type of PEFT to apply. + rank (int): The LoRA rank. + scale (float): The LoRA scale. + is_final_layer (bool): Indicates if this edit is for the final layer. + + Returns: + defaultdict: A dictionary of subblock edits configs. + + Raises: + TypeError: If block_edit is not a list. + AttributeError: If subblock format is incorrect or layer format is invalid. + ValueError: If rank and scale values are not provided. + """ + sb_dict = defaultdict(lambda: None) + + if not isinstance(block_edit, (list, ListConfig)): + raise TypeError(f"Config 'block_edits' field must be a list, but got {type(block_edit)}") + + if is_final_layer: # final layer has different allowed layer names + subblock_validator = self.final_layer_validator + allowed_types = self.allowed_types_final_layer + else: + subblock_validator = self.subblock_validator + allowed_types = self.allowed_types + + for subblock in block_edit: + sb_name = None + params_list = None + try: + sb_match = re.match(self.SUBBLOCK_PATTERN, subblock) + sb_name = subblock_validator.validate(sb_match.group("subblock")) + params_str = sb_match.group("parameters") + params_list = params_str.replace(" ", "").split(",") + except AttributeError: + raise AttributeError("Incorrect sub-block format: must be [...]") + layer_validator = OneOf(default="", options=allowed_types.get(sb_name)) + + # for each parameter in the subblock config + layers_dict = defaultdict(lambda: None) + for param in params_list: + try: + layer_match = re.match(self.LAYER_PATTERN, param) + layer_name = layer_validator.validate(layer_match.group("layer")) + layer_rank = layer_match.group("rank") or rank or self.default_rank + layer_scale = layer_match.group("scale") or scale or self.default_scale + if not layer_rank or not layer_scale: + raise ValueError( + "Rank and scale values must be provided at default, sub-block, or layer level." 
+ ) + layer_rank = self.rank_validator.validate(layer_rank) + layer_scale = self.scale_validator.validate(layer_scale) + + layers_dict[layer_name] = {"activate": True, "lora_rank": layer_rank, "lora_scale": layer_scale} + layers_dict["customization_type"] = customization_type or self.default_customization_type + sb_dict[sb_name] = dict(layers_dict) + except AttributeError: + raise AttributeError("Layer format must be :[:] (where is optional)") + + if sb_dict: + sb_dict["customization_type"] = customization_type or self.default_customization_type + return sb_dict + + def parse(self): + """ + Parse the loaded config into a dictionary of edit configs by block number. + + Returns: + dict: A dictionary of edit configs applied to each block. + + Raises: + Exception: If more than one edit is specified for a block. + """ + if not self.enabled: + return {} + + # for each edit in the config + for edit in self.config.get("edits", []): + blocks = self._parse_blocks_regex(edit["blocks"]) # get the blocks affected by edit + logger.info(f"Applying edits for blocks {blocks}") + block_edit = edit.get("block_edit", []) + customization_type = CustomizationType.from_value(edit.get("customization_type", "")) + rank = edit.get("rank", None) + scale = edit.get("scale", None) + is_final_layer = blocks == set([self.FINAL_LAYER_NAME]) + # get subblock config + sb_dict = self._parse_subblocks( + block_edit=block_edit, + customization_type=customization_type, + rank=rank, + scale=scale, + is_final_layer=is_final_layer, + ) + + # for each block in the edit + for block in blocks: + if sb_dict: + if self.edits_per_block[block]: + raise Exception(f"More than one edit specified for block {block}") + self.edits_per_block[block] = dict(sb_dict) + if self.edits_per_block: + self.edits_per_block["customization_type"] = self.default_customization_type + return dict(self.edits_per_block) diff --git a/cosmos_predict1/diffusion/training/utils/optim_instantiate.py b/cosmos_predict1/diffusion/training/utils/optim_instantiate.py new file mode 100644 index 0000000000000000000000000000000000000000..96f45a65ffd11fd472efd34287bb867f69379986 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/optim_instantiate.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import torch +from torch import nn + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.fused_adam import FusedAdam + + +def get_regular_param_group(net: nn.Module): + """ + seperate the parameters of the network into two groups: decay and no_decay. + based on nano_gpt codebase. 
+ """ + param_dict = {pn: p for pn, p in net.named_parameters()} + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + return decay_params, nodecay_params + + +def get_base_optimizer( + model: nn.Module, + lr: float, + weight_decay: float, + optim_type: str = "adamw", + sharding: bool = False, + **kwargs, +) -> torch.optim.Optimizer: + net_decay_param, net_nodecay_param = get_regular_param_group(model) + + num_decay_params = sum(p.numel() for p in net_decay_param) + num_nodecay_params = sum(p.numel() for p in net_nodecay_param) + net_param_total = num_decay_params + num_nodecay_params + log.critical(f"total num parameters : {net_param_total:,}") + + param_group = [ + { + "params": net_decay_param + net_nodecay_param, + "lr": lr, + "weight_decay": weight_decay, + }, + ] + + if optim_type == "adamw": + opt_cls = torch.optim.AdamW + elif optim_type == "fusedadam": + opt_cls = FusedAdam + else: + raise ValueError(f"Unknown optimizer type: {optim_type}") + + return opt_cls(param_group, **kwargs) + + +def get_base_scheduler( + optimizer: torch.optim.Optimizer, + model: nn.Module, + scheduler_config: dict, +): + net_scheduler = hydra.utils.instantiate(scheduler_config) + net_scheduler.model = model + + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=[ + net_scheduler.schedule, + ], + ) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py b/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..269ac32205117d820bb0f24f3dcac9e976439299 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange +from torch.utils.checkpoint import checkpoint +from transformer_engine.pytorch.attention import apply_rotary_pos_emb + +from cosmos_predict1.diffusion.module.attention import Attention +from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType + +try: + from megatron.core import parallel_state + + USE_MEGATRON = True +except ImportError: + USE_MEGATRON = False + + +def enable_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Enable LoRA for the attention block based on the peft_control dictionary. + + Args: + attn (Attention): The attention block to configure. + peft_control (dict): Dictionary containing PEFT configuration. 
+ """ + attn.peft_lora_enabled = False + if peft_control: + try: + if peft_control["customization_type"] == CustomizationType.LORA: + attn.peft_lora_enabled = True + else: + raise Exception(f"Unsupported Customization type {peft_control['customization_type']}") + except KeyError as e: + raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") + + +def configure_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Configure LoRA for the attention block based on the peft_control dictionary. + + Args: + attn (Attention): The attention block to configure. + peft_control (dict): Dictionary containing PEFT configuration. + """ + try: + attn.q_lora_enabled = peft_control.get("to_q", {}).get("activate", False) + attn.k_lora_enabled = peft_control.get("to_k", {}).get("activate", False) + attn.v_lora_enabled = peft_control.get("to_v", {}).get("activate", False) + attn.out_lora_enabled = peft_control.get("to_out", {}).get("activate", False) + if attn.q_lora_enabled: + attn.q_lora_rank = peft_control["to_q"]["lora_rank"] + attn.q_lora_scale = float(peft_control["to_q"]["lora_scale"]) + if attn.k_lora_enabled: + attn.k_lora_rank = peft_control["to_k"]["lora_rank"] + attn.k_lora_scale = float(peft_control["to_k"]["lora_scale"]) + if attn.v_lora_enabled: + attn.v_lora_rank = peft_control["to_v"]["lora_rank"] + attn.v_lora_scale = float(peft_control["to_v"]["lora_scale"]) + if attn.out_lora_enabled: + attn.out_lora_rank = peft_control["to_out"]["lora_rank"] + attn.out_lora_scale = float(peft_control["to_out"]["lora_scale"]) + except KeyError as e: + raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") + except ValueError as e: + raise ValueError(f"Could not convert string to float: {e}") + + +def cal_qkv_lora( + self, + x: torch.Tensor, + context: torch.Tensor = None, + mask: torch.Tensor = None, + rope_emb: torch.Tensor = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + """ + Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv. + + Args: + x (torch.Tensor): Input tensor. + context (torch.Tensor, optional): Context tensor + mask (torch.Tensor, optional): Mask tensor + rope_emb (torch.Tensor, optional): Rotary positional embedding + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices. + """ + + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + + if self.peft_lora_enabled: + try: + if self.q_lora_enabled: + q_lora = self.to_q_lora(x) + q = q + self.q_lora_scale * q_lora + if self.k_lora_enabled: + k_lora = self.to_k_lora(context) + k = k + self.k_lora_scale * k_lora + if self.v_lora_enabled: + v_lora = self.to_v_lora(context) + v = v + self.v_lora_scale * v_lora + except AttributeError as e: + raise AttributeError(f"lora enabled, but missing class attribute {e.args[0]} of Attention block") + + q, k, v = map( + lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb): + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) + return q, k, v + + q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False) + + return q, k, v + + +def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + mask (torch.Tensor, optional): Mask tensor. + + Returns: + torch.Tensor: The attention output. + """ + if self.backend == "transformer_engine": + seq_dim = self.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] + out = self.to_out(attn_out) + + if self.peft_lora_enabled and self.out_lora_enabled: + try: + out_lora = self.to_out_lora(attn_out) + out = out + self.out_lora_scale * out_lora + except AttributeError as e: + raise AttributeError(f"l1 lora enabled, but missing class attribute {e.args[0]} of FeedForward block") + + return out + elif self.backend == "torch": + attn_out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V] + attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)") + out = self.to_out(attn_out) + + if self.peft_lora_enabled and self.out_lora_enabled: + try: + out_lora = self.to_out_lora(attn_out) + out = out + self.out_lora_scale * out_lora + except AttributeError as e: + raise AttributeError(f"l1 lora enabled, but missing class attribute {e.args[0]} of FeedForward block") + + return out + else: + raise ValueError(f"Backend {self.backend} not found") + + +def build_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Configure, build and add LoRA layers to the attention block. + + Args: + attn (Attention): The attention block to add LoRA layers to. + peft_control (dict): Dictionary containing PEFT configuration. 
+ """ + enable_attn_lora(attn, peft_control) + configure_attn_lora(attn, peft_control) + if attn.peft_lora_enabled: + query_dim = attn.query_dim + inner_dim = attn.inner_dim + context_dim = attn.context_dim + tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None + + if attn.tp_size == 1: + if attn.q_lora_enabled: + attn.to_q_lora = LoRALinearLayer(query_dim, inner_dim, rank=attn.q_lora_rank, linear=True) + if attn.k_lora_enabled: + attn.to_k_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.k_lora_rank, linear=True) + if attn.v_lora_enabled: + attn.to_v_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.v_lora_rank, linear=True) + if attn.out_lora_enabled: + attn.to_out_lora = LoRALinearLayer(inner_dim, query_dim, rank=attn.out_lora_rank, linear=True) + else: + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + if attn.q_lora_enabled: + attn.to_q_lora = TELoRALinearLayer( + query_dim, + inner_dim, + rank=attn.q_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.k_lora_enabled: + attn.to_k_lora = TELoRALinearLayer( + context_dim, + inner_dim, + rank=attn.k_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.v_lora_enabled: + attn.to_v_lora = TELoRALinearLayer( + context_dim, + inner_dim, + rank=attn.v_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.out_lora_enabled: + attn.to_out_lora = TELoRALinearLayer( + inner_dim, + query_dim, + rank=attn.out_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="row", + ) + attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__) + attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py b/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cc52b1a55234eb4b3b6b0cfa84acf9f0be42b0af --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: + pytest -s cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py +""" + +import copy + +import pytest +import torch +import torch.nn as nn +from einops import rearrange, repeat +from loguru import logger + +from cosmos_predict1.diffusion.config.base.net import FADITV2Config +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, get_all_lora_params +from cosmos_predict1.utils.lazy_config import instantiate + + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + dummy_net = copy.deepcopy(FADITV2Config) + dummy_net.num_blocks = 2 + dummy_net.model_channels = 256 + dummy_net.num_heads = 8 + self.net = instantiate(dummy_net).cuda() + + +@pytest.fixture() +def block1_peft_control(): + """ + This config has the following edits for the following blocks: + Block 0: FA, CA edits for ALL sub-blocks + """ + config = { + "enabled": "True", + "edits": [ + { + "blocks": "\\b\\d*([1])\\b", + "customization_type": "LoRA", + "rank": 8, + "scale": 0.6, + "block_edit": [ + "FA[to_q:8:0.8, to_k:16:1.2, to_v:4:64, to_out:8]", + "CA[to_q:16, to_k:16, to_v:4, to_out:32]", + ], + }, + ], + "customization_type": "LoRA", + "rank": 8, + "scale": 0.8, + } + config_parser = LayerControlConfigParser(config) + return config_parser.parse() + + +def test_model_without_lora(): + model = DummyModel() + lora_params = get_all_lora_params(model) + actual = len(lora_params) + expected = 0 + assert actual == expected, f"Expected {expected} LoRA layers, got {actual}" + + +def test_model_with_lora(block1_peft_control): + model = DummyModel() + add_lora_layers(model, block1_peft_control) + lora_params = get_all_lora_params(model) + actual = len(lora_params) + expected = 16 + assert actual == expected, f"Expected {expected} LoRA layers, got {actual}" + + +def test_model_cal_qkv_lora_matches_base_version_at_init(block1_peft_control): + model = DummyModel() + # isolate a single attention layer + block_idx = 1 + attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + x = torch.rand(2, 16, 256).cuda() # batch size, seq len, embed size + + q_base, k_base, v_base = attn.cal_qkv(x) + add_lora_layers(model, block1_peft_control) + model.cuda() + q_lora, k_lora, v_lora = attn.cal_qkv(x) + + assert torch.allclose(q_base, q_lora) + assert torch.allclose(k_base, k_lora) + assert torch.allclose(v_base, v_lora) + + +def test_model_cal_qkv_lora_with_non_zero_lora(block1_peft_control): + model = DummyModel() + block_idx = 1 + self_attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + cross_attn = model.net.blocks[f"block{block_idx}"].blocks[1].block.attn + # Set q_norm and k_norm to Identity + for attn in [self_attn, cross_attn]: + attn.to_q[0].weight.data.fill_(0.1) + attn.to_k[0].weight.data.fill_(0.1) + attn.to_v[0].weight.data.fill_(0.1) + attn.to_q[1] = nn.Identity() # Set normalization to Identity + attn.to_k[1] = nn.Identity() + attn.to_v[1] = nn.Identity() + attn.to_q[1].cuda() + attn.to_k[1].cuda() + attn.to_v[1].cuda() + + q_base, k_base, v_base = {}, {}, {} + x = torch.ones(2, 16, 256).cuda() # batch size, seq len, embed size + cross_attn_context = torch.ones(2, 16, 1024).cuda() + context_dim = {"FA": 256, "CA": 1024} + input_context = {"FA": x, "CA": cross_attn_context} + + # Compute base qkv for both self and cross attention + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + q_base[attn_name], 
k_base[attn_name], v_base[attn_name] = attn.cal_qkv(x, input_context[attn_name]) + # add lora layers + add_lora_layers(model, block1_peft_control) + model.cuda() + + # compute lora qkv with non-zero lora weights + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + attn.to_q_lora.net[0].weight.data.fill_(0.1) + attn.to_q_lora.net[1].weight.data.fill_(0.2) + + attn.to_k_lora.net[0].weight.data.fill_(0.3) + attn.to_k_lora.net[1].weight.data.fill_(0.4) + + attn.to_v_lora.net[0].weight.data.fill_(0.5) + attn.to_v_lora.net[1].weight.data.fill_(0.6) + + q_lora, k_lora, v_lora = attn.cal_qkv(x, input_context[attn_name]) + + # Compare with expected lora qkv + self_attn_q_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_q", {}).get("lora_scale") + ) + self_attn_q_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_q", {}).get("lora_rank") + ) + q_lora_diff = 256 * 0.1 * self_attn_q_lora_rank * 0.2 + + self_attn_k_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_k", {}).get("lora_scale") + ) + self_attn_k_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_k", {}).get("lora_rank") + ) + k_lora_diff = context_dim[attn_name] * 0.3 * self_attn_k_lora_rank * 0.4 + + self_attn_v_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_v", {}).get("lora_scale") + ) + self_attn_v_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_v", {}).get("lora_rank") + ) + v_lora_diff = context_dim[attn_name] * 0.5 * self_attn_v_lora_rank * 0.6 + + expected_q_lora = q_base[attn_name] + self_attn_q_lora_scale * q_lora_diff + expected_k_lora = k_base[attn_name] + self_attn_k_lora_scale * k_lora_diff + expected_v_lora = v_base[attn_name] + self_attn_v_lora_scale * v_lora_diff + logger.info(f"attn_name: {attn_name}, q_lora: {q_lora.shape}, expected_q_lora: {expected_q_lora.shape}") + assert torch.allclose( + q_lora, expected_q_lora, rtol=1e-2 + ), f"q_lora: {q_lora[0, 0, 0, :2]}, expected_q_lora: {expected_q_lora[0, 0, 0, :2]}" + assert torch.allclose( + k_lora, expected_k_lora, rtol=1e-2 + ), f"k_lora: {k_lora[0, 0, 0, :2]}, expected_k_lora: {expected_k_lora[0, 0, 0, :2]}" + assert torch.allclose( + v_lora, expected_v_lora, rtol=1e-2 + ), f"v_lora: {v_lora[0, 0, 0, :2]}, expected_v_lora: {expected_v_lora[0, 0, 0, :2]}" + + +def test_model_cal_attn_lora_matches_base_version_at_init(block1_peft_control): + model = DummyModel() + q = torch.rand(2, 16, 8, 32).cuda() + k = torch.rand(2, 16, 8, 32).cuda() + v = torch.rand(2, 16, 8, 32).cuda() + + # isolate a single attention layer + block_idx = 1 + attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + attn_output_base = attn.cal_attn(q, k, v) # [2, 16, 256] + + add_lora_layers(model, block1_peft_control) + model.cuda() + attn_output_lora = attn.cal_attn(q, k, v) + + assert torch.allclose(attn_output_base, attn_output_lora) + + +def test_model_cal_attn_lora_with_non_zero_output_lora(block1_peft_control): + model = DummyModel() + block_idx = 1 + self_attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + cross_attn = model.net.blocks[f"block{block_idx}"].blocks[1].block.attn + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + # Overwrite attn_op to return ones of shape [2, 16, 256] and output_dropout to be Identity + class OnesAttnOp(nn.Module): + def forward(self, *args, **kwargs): + return torch.ones([2, 16, 
256]).cuda() + + attn.attn_op = OnesAttnOp() + attn.to_out[0].weight.data.fill_(0.1) + attn.to_out[1] = nn.Identity() # Remove dropout + + # Compute base attn output + q = torch.rand(2, 16, 8, 32).cuda() + k = torch.rand(2, 16, 8, 32).cuda() + v = torch.rand(2, 16, 8, 32).cuda() + attn_output_base = attn.cal_attn(q, k, v) + + # Add lora layers + add_lora_layers(model, block1_peft_control) + model.cuda() + # Set lora weights to non-zero + attn.to_out_lora.net[0].weight.data.fill_(0.1) + attn.to_out_lora.net[1].weight.data.fill_(0.2) + + # Compute lora attn output + attn_output_lora = attn.cal_attn(q, k, v) + + # Compare with expected lora attn output + output_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_out", {}).get("lora_scale") + ) + output_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_out", {}).get("lora_rank") + ) + + expected_attn_output_lora = attn_output_base + output_lora_scale * 256 * 0.1 * output_lora_rank * 0.2 + assert torch.allclose( + attn_output_lora, expected_attn_output_lora, rtol=1e-2 + ), f"attn_output_lora: {attn_output_lora[0, 0, :2]}, expected_attn_output_lora: {expected_attn_output_lora[0, 0, :2]}" diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_config.py b/cosmos_predict1/diffusion/training/utils/peft/lora_config.py new file mode 100644 index 0000000000000000000000000000000000000000..896d32f396a36c126323170462f452691d5f1be0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_config.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1): + """ + Get a LoRA configuration for the Self-Attention (FA) and Cross-Attention (CA) blocks in the model. + This LoRA configuration is used to inject LoRA parameters into the model. + + Args: + first_nblocks (int): The number of blocks to apply LoRA to. + rank (int): The rank of the LoRA matrices. + """ + blocks_regex = r"\b(" + "|".join([str(i) for i in range(first_nblocks)]) + r")\b" + return dict( + enabled=True, + customization_type="LoRA", + rank=rank, + scale=scale, + edits=[ + dict( + blocks=blocks_regex, + customization_type="LoRA", + rank=rank, + scale=scale, + block_edit=[ + "FA[to_q, to_v]", + "CA[to_q, to_v]", + ], + ) + ], + ) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_net.py b/cosmos_predict1/diffusion/training/utils/peft/lora_net.py new file mode 100644 index 0000000000000000000000000000000000000000..da7dbe2224613b4a52a4cd79707261c1b707a20b --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_net.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import transformer_engine as te +from megatron.core import parallel_state +from torch import nn + +from cosmos_predict1.utils import log + + +class LoRALinearLayer(nn.Module): + """ + ported from + https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470. + """ + + def __init__(self, in_features, out_features, rank=4, linear=False): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + if linear: + down = nn.Linear(in_features, rank, bias=False) + up = nn.Linear(rank, out_features, bias=False) + else: + down = nn.Conv1d(in_features, rank, 1, bias=False) + up = nn.Conv1d(rank, out_features, 1, bias=False) + + nn.init.normal_(down.weight, std=1 / rank) + nn.init.zeros_(up.weight) + self.net = nn.Sequential(down, up) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.net[0].weight.dtype + + up_hidden_states = self.net(hidden_states.to(dtype)) + + return up_hidden_states.to(orig_dtype) + + +class TELoRALinearLayer(nn.Module): + """ + ported from + https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470. 
+ """ + + def __init__(self, in_features, out_features, rank, linear, tp_size, tp_group, sequence_parallel, parallel_mode): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + if linear: + down = te.pytorch.Linear( + in_features, + rank, + bias=False, + tp_size=1, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=None, + ) + up = te.pytorch.Linear( + rank, + out_features, + bias=False, + tp_size=tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=parallel_mode, + ) + else: + down = te.pytorch.Conv1d( + in_features, + rank, + 1, + bias=False, + tp_size=1, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=None, + ) + up = te.pytorch.Conv1d( + rank, + out_features, + 1, + bias=False, + tp_size=tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=parallel_mode, + ) + tp_rank = parallel_state.get_tensor_model_parallel_rank() + # Create generator + gen = torch.Generator(device=down.weight.device) + # Save the current random state + gen_state = gen.get_state() + + # Set constant seed for non-tp layers + log.info(f"rank {tp_rank}: setting seed to 0") + gen.manual_seed(0) + nn.init.normal_(down.weight, std=1 / rank, generator=gen) + # Set a new random seed based on the tensor parallel rank + gen.manual_seed(tp_rank) + log.info(f"rank {tp_rank}: setting seed to {tp_rank}") + nn.init.zeros_(up.weight) + # Restore the original random state + gen.set_state(gen_state) + + self.net = nn.Sequential(down, up) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.net[0].weight.dtype + up_hidden_states = self.net(hidden_states.to(dtype)) + + return up_hidden_states.to(orig_dtype) diff --git a/cosmos_predict1/diffusion/training/utils/peft/peft.py b/cosmos_predict1/diffusion/training/utils/peft/peft.py new file mode 100644 index 0000000000000000000000000000000000000000..2540514d325976b8f84501c02e0e1ac043c9349d --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/peft.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils import log +from cosmos_predict1.utils.misc import count_params + + +def get_all_lora_params(model): + """ + Get all LoRA weight parameters in the model + """ + lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name] + lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()] + log.info(f"Found {len(lora_params)} LoRA weight matrices") + return lora_params + + +def setup_lora_requires_grad(model): + """ + Freeze all model parameters except LoRA parameters. + """ + num_param = count_params(model, verbose=True) + log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing") + lora_params = get_all_lora_params(model) + num_lora_param = sum([p.numel() for _, p in lora_params]) + log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M") + if num_lora_param > 0: + log.info("Freezing all parameters") + model.requires_grad_(False) + log.info("Unfreezing LoRA parameters") + for name, param in lora_params: + # log.info(f"Unfreezing loRA : {name}") + param.requires_grad_(True) + num_param = count_params(model, verbose=True) + log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing") + return num_lora_param + + +def add_lora_layers(model, peft_control_config): + for i, block_name in enumerate(model.net.blocks): + block = model.net.blocks[block_name] + peft_control = peft_control_config.get(i, {}) + for j, subblock in enumerate(block.blocks): + block_type = subblock.block_type + peft_control_subblock = peft_control.get(block_type.upper(), {}) + customization_type = peft_control_subblock.get("customization_type", None) + if customization_type == CustomizationType.LORA: + if block_type.upper() in ["CA", "FA"]: + build_attn_lora(subblock.block.attn, peft_control_subblock) diff --git a/cosmos_predict1/diffusion/types.py b/cosmos_predict1/diffusion/types.py new file mode 100644 index 0000000000000000000000000000000000000000..0459d88d423794a97fb62ee3275c57c9d1c2007a --- /dev/null +++ b/cosmos_predict1/diffusion/types.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class LabelImageCondition: + label: torch.Tensor + + def get_classifier_free_guidance_condition(self) -> LabelImageCondition: + return LabelImageCondition(torch.zeros_like(self.label)) + + +@dataclass +class DenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty diff --git a/cosmos_predict1/diffusion/utils/customization/customization_manager.py b/cosmos_predict1/diffusion/utils/customization/customization_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc6d8446b7c443512e9fa97174a11d66253adb4 --- /dev/null +++ b/cosmos_predict1/diffusion/utils/customization/customization_manager.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum + + +class CustomizationType(Enum): + LORA = 1 + REPLACE = 2 + + @classmethod + def from_value(cls, value): + """Convert both int and str to the corresponding enum.""" + if isinstance(value, str): + value = value.lower() + if value == "lora": + return cls.LORA + elif value == "replace": + return cls.REPLACE + elif value == "": + return None + else: + raise ValueError("Customization type must be lora or replace") + raise TypeError("CustomizationType must be specified as a string.") diff --git a/cosmos_predict1/tokenizer/__init__.py b/cosmos_predict1/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/inference/__init__.py b/cosmos_predict1/tokenizer/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
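A quick sanity check of the `from_value` helper above. Note that its docstring mentions both int and str, but the implementation accepts only strings: non-strings raise `TypeError`, unknown strings raise `ValueError`, and the empty string maps to `None`:

from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType

assert CustomizationType.from_value("LoRA") is CustomizationType.LORA      # case-insensitive
assert CustomizationType.from_value("replace") is CustomizationType.REPLACE
assert CustomizationType.from_value("") is None                            # empty string -> None

try:
    CustomizationType.from_value(1)   # ints are rejected despite the docstring
except TypeError:
    pass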
diff --git a/cosmos_predict1/tokenizer/inference/image_cli.py b/cosmos_predict1/tokenizer/inference/image_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..5da9961c150a4ae3d0b81c81d430308c76c231a4 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/image_cli.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to run ImageTokenizer on plain images based on torch.jit. + +Usage: + python3 -m cosmos_predict1.tokenizer.inference.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_predict1.tokenizer.inference.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --mode torch \ + --tokenizer_type CI8x8 \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np +from loguru import logger as logging + +from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer +from cosmos_predict1.tokenizer.inference.utils import ( + get_filepaths, + get_output_filepath, + read_image, + resize_image, + write_image, +) +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.") + parser.add_argument( + "--image_pattern", + type=str, + default="path/to/images/*.jpg", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + default=None, + choices=[ + "CI8x8-360p", + "CI16x16-360p", + "DI8x8-360p", + "DI16x16-360p", + ], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. None, by default.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision. 
Default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input image will be be outputed too.", + ) + args = parser.parse_args() + return args + + +logging.info("Initializes args ...") +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type is None: + logging.error("'torch' backend requires the tokenizer_type to be specified.") + sys.exit(1) + + +def _run_eval() -> None: + """Invokes the evaluation pipeline.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.") + return + + if args.mode == "torch": + _type = args.tokenizer_type.replace("-", "_") + _config = TokenizerConfigs[_type].value + else: + _config = None + + logging.info( + f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." + ) + autoencoder = ImageTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=_config, + device=args.device, + dtype=args.dtype, + ) + + filepaths = get_filepaths(args.image_pattern) + logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.") + + for filepath in filepaths: + logging.info(f"Reading image {filepath} ...") + image = read_image(filepath) + image = resize_image(image, short_size=args.short_size) + batch_image = np.expand_dims(image, axis=0) + + logging.info("Invoking the autoencoder model in ... ") + output_image = autoencoder(batch_image)[0] + + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + logging.info(f"Outputing {output_filepath} ...") + write_image(output_filepath, output_image) + + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_image(input_filepath, image) + + +@logging.catch(reraise=True) +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/inference/image_lib.py b/cosmos_predict1/tokenizer/inference/image_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..9929d871a2d5862b0ced9e4e937f28033124ba62 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/image_lib.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
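The CLI above is a thin loop over the `ImageTokenizer` defined next; the same reconstruction can be done programmatically in JIT mode. The checkpoint paths and filenames below are placeholders:

import numpy as np

from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer
from cosmos_predict1.tokenizer.inference.utils import read_image, write_image

autoencoder = ImageTokenizer(
    checkpoint_enc="checkpoints/<tokenizer-name>/encoder.jit",  # placeholder path
    checkpoint_dec="checkpoints/<tokenizer-name>/decoder.jit",  # placeholder path
    device="cuda",
    dtype="bfloat16",
)

image = read_image("input.jpg")                           # HxWx3 uint8, range [0..255]
reconstruction = autoencoder(image[np.newaxis, ...])[0]   # same layout and range
write_image("reconstruction.jpg", reconstruction)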
+ +"""A library for image tokenizers inference.""" + +from typing import Any + +import numpy as np +import torch + +from cosmos_predict1.tokenizer.inference.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_image_batch, + tensor2numpy, + unpad_image_batch, +) + + +class ImageTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of image tensors after embedding into a latent. + + Args: + input_tensor: The input image Bx3xHxW layout, range [-1..1]. + Returns: + The reconstructed tensor, layout Bx3xHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Decodes an image from a provided latent embedding. + + Args: + input_latent: The continuous latent Bx16xhxw for CI, + or the discrete indices Bxhxw for DI. + Returns: + The output tensor in Bx3xHxW, range [-1..1]. + """ + return self._dec_model(input_latent) + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes an image into a latent embedding or code. + + Args: + input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. + Returns: + For continuous image (CI) tokenizer, the tuple contains: + - The latent embedding, Bx16x(h)x(w), where the compression + rate is (H/h x W/w), and channel dimension of 16. + For discrete image (DI) tokenizer, the tuple contains: + - The indices, Bx(h)x(w), from a codebook of size 64K, which + corresponds to FSQ levels of (8,8,8,5,5,5). + - The discrete code, Bx6x(h)x(w), where the compression rate is + again (H/h x W/w), and channel dimension of 6. + """ + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def forward(self, image: np.ndarray) -> np.ndarray: + """Reconstructs an image using a pre-trained tokenizer. + + Args: + image: The input image BxHxWxC layout, range [0..255]. + Returns: + The reconstructed image in range [0..255], layout BxHxWxC. 
+ """ + padded_input_image, crop_region = pad_image_batch(image) + input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_image = tensor2numpy(output_tensor) + return unpad_image_batch(padded_output_image, crop_region) diff --git a/cosmos_predict1/tokenizer/inference/utils.py b/cosmos_predict1/tokenizer/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53feab8e282207a4810a32ce0f2d7ceb29273623 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/utils.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for the inference libraries.""" + +import os +from glob import glob +from typing import Any + +import mediapy as media +import numpy as np +import torch + +from cosmos_predict1.tokenizer.networks import TokenizerModels + +_DTYPE, _DEVICE = torch.bfloat16, "cuda" +_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) +_SPATIAL_ALIGN = 16 +_TEMPORAL_ALIGN = 8 + + +def load_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + full_model.load_state_dict(ckpts.state_dict(), strict=True) + return full_model.eval().to(device) + + +def load_encoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + encoder_model = full_model.encoder_jit() + encoder_model.load_state_dict(ckpts.state_dict(), strict=True) + return encoder_model.eval().to(device) + + +def load_decoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. 
+ """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + decoder_model = full_model.decoder_jit() + decoder_model.load_state_dict(ckpts.state_dict(), strict=True) + return decoder_model.eval().to(device) + + +def _load_pytorch_model( + jit_filepath: str = None, tokenizer_config: str = None, device: str = "cuda" +) -> torch.nn.Module: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + tokenizer_name = tokenizer_config["name"] + model = TokenizerModels[tokenizer_name].value(**tokenizer_config) + ckpts = torch.jit.load(jit_filepath, map_location=device) + return model, ckpts + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + model = torch.jit.load(jit_filepath, map_location=device) + return model.eval().to(device) + + +def save_jit_model( + model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, + jit_filepath: str = None, +) -> None: + """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. + + Args: + model: JIT compiled model loaded onto `config.checkpoint.jit.device`. + jit_filepath: The filepath to the JIT-compiled model. + """ + torch.jit.save(model, jit_filepath) + + +def get_filepaths(input_pattern) -> list[str]: + """Returns a list of filepaths from a pattern.""" + filepaths = sorted(glob(str(input_pattern))) + return list(set(filepaths)) + + +def get_output_filepath(filepath: str, output_dir: str = None) -> str: + """Returns the output filepath for the given input filepath.""" + output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" + output_filepath = f"{output_dir}/{os.path.basename(filepath)}" + os.makedirs(output_dir, exist_ok=True) + return output_filepath + + +def read_image(filepath: str) -> np.ndarray: + """Reads an image from a filepath. + + Args: + filepath: The filepath to the image. + + Returns: + The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. + """ + image = media.read_image(filepath) + # convert the grey scale image to RGB + # since our tokenizers always assume 3-channel RGB image + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + # convert RGBA to RGB + if image.shape[-1] == 4: + image = image[..., :3] + return image + + +def read_video(filepath: str) -> np.ndarray: + """Reads a video from a filepath. + + Args: + filepath: The filepath to the video. + Returns: + The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. + """ + video = media.read_video(filepath) + # convert the grey scale frame to RGB + # since our tokenizers always assume 3-channel video + if video.ndim == 3: + video = np.stack([video] * 3, axis=-1) + # convert RGBA to RGB + if video.shape[-1] == 4: + video = video[..., :3] + return video + + +def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes an image to have the short side of `short_size`. + + Args: + image: The image to resize, layout HxWxC, of any range. + short_size: The size of the short side. 
+ Returns: + The resized image. + """ + if short_size is None: + return image + height, width = image.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_image(image, shape=(height_new, width_new)) + + +def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes a video to have the short side of `short_size`. + + Args: + video: The video to resize, layout TxHxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized video. + """ + if short_size is None: + return video + height, width = video.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_video(video, shape=(height_new, width_new)) + + +def write_image(filepath: str, image: np.ndarray): + """Writes an image to a filepath.""" + return media.write_image(filepath, image) + + +def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: + """Writes a video to a filepath.""" + return media.write_video(filepath, video, fps=fps) + + +def numpy2tensor( + input_image: np.ndarray, + dtype: torch.dtype = _DTYPE, + device: str = _DEVICE, + range_min: int = -1, +) -> torch.Tensor: + """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. + + Args: + input_image: A batch of images in range [0..255], BxHxWx3 layout. + Returns: + A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. + """ + ndim = input_image.ndim + indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] + image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F + if range_min == -1: + image = 2.0 * image - 1.0 + return torch.from_numpy(image).to(dtype).to(device) + + +def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: + """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. + + Args: + input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. + Returns: + A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. + """ + if range_min == -1: + input_tensor = (input_tensor.float() + 1.0) / 2.0 + ndim = input_tensor.ndim + output_image = input_tensor.clamp(0, 1).cpu().numpy() + output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) + return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) + + +def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]: + """Pads a batch of images to be divisible by `spatial_align`. + + Args: + batch: The batch of images to pad, layout BxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. 
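+        Example:
+            A 1x360x640x3 batch with spatial_align=16 is zero-padded to
+            1x368x640x3 and the returned crop_region is [4, 0, 364, 640].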
+ """ + height, width = batch.shape[1:3] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + crop_region = [ + height_to_pad >> 1, + width_to_pad >> 1, + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + return batch, crop_region + + +def pad_video_batch( + batch: np.ndarray, + temporal_align: int = _TEMPORAL_ALIGN, + spatial_align: int = _SPATIAL_ALIGN, +) -> tuple[np.ndarray, list[int]]: + """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. + + Zero pad spatially. Reflection pad temporally to handle causality better. + Args: + batch: The batch of videos to pad., layout BxFxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. + """ + num_frames, height, width = batch.shape[-4:-1] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + align = temporal_align + frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 + + crop_region = [ + frames_to_pad >> 1, + height_to_pad >> 1, + width_to_pad >> 1, + num_frames + (frames_to_pad >> 1), + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + batch = np.pad( + batch, + ( + (0, 0), + (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), + (0, 0), + (0, 0), + (0, 0), + ), + mode="edge", + ) + return batch, crop_region + + +def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads video with `crop_region`. + + Args: + batch: A batch of numpy videos, layout BxFxHxWxC. + crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy video, layout BxFxHxWxC. + """ + assert len(crop_region) == 6, "crop_region should be len of 6." + f1, y1, x1, f2, y2, x2 = crop_region + return batch[..., f1:f2, y1:y2, x1:x2, :] + + +def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads image with `crop_region`. + + Args: + batch: A batch of numpy images, layout BxHxWxC. + crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy image, layout BxHxWxC. + """ + assert len(crop_region) == 4, "crop_region should be len of 4." + y1, x1, y2, x2 = crop_region + return batch[..., y1:y2, x1:x2, :] diff --git a/cosmos_predict1/tokenizer/inference/video_cli.py b/cosmos_predict1/tokenizer/inference/video_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..51e17903c6b6b825c78c2e40760ea88aff51141f --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/video_cli.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit. + +Usage: + python3 -m cosmos_predict1.tokenizer.inference.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_predict1.tokenizer.inference.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --mode=torch \ + --tokenizer_type=CV \ + --temporal_compression=4 \ + --spatial_compression=8 \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np +from loguru import logger as logging + +from cosmos_predict1.tokenizer.inference.utils import ( + get_filepaths, + get_output_filepath, + read_video, + resize_video, + write_video, +) +from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") + parser.add_argument( + "--video_pattern", + type=str, + default="path/to/videos/*.mp4", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + default=None, + choices=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CV4x8x8-360p", + "DV4x8x8-360p", + ], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. 
None, by default.", + ) + parser.add_argument( + "--temporal_window", + type=int, + default=17, + help="The temporal window to operate at a time.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision, default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--output_fps", + type=float, + default=24.0, + help="Output frames-per-second (FPS).", + ) + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input video will be be outputted too.", + ) + args = parser.parse_args() + return args + + +logging.info("Initializes args ...") +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type is None: + logging.error("`torch` backend requires `--tokenizer_type` to be specified.") + sys.exit(1) + + +def _run_eval() -> None: + """Invokes JIT-compiled CausalVideoTokenizer on an input video.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.") + return + + if args.mode == "torch": + _type = args.tokenizer_type.replace("-", "_") + _config = TokenizerConfigs[_type].value + else: + _config = None + + logging.info( + f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." + ) + autoencoder = CausalVideoTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=_config, + device=args.device, + dtype=args.dtype, + ) + + logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...") + filepaths = get_filepaths(args.video_pattern) + logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.") + + for filepath in filepaths: + logging.info(f"Reading video {filepath} ...") + video = read_video(filepath) + video = resize_video(video, short_size=args.short_size) + + logging.info("Invoking the autoencoder model in ... ") + batch_video = video[np.newaxis, ...] + output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0] + logging.info("Constructing output filepath ...") + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + logging.info(f"Outputing {output_filepath} ...") + write_video(output_filepath, output_video, fps=args.output_fps) + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_video(input_filepath, video, fps=args.output_fps) + + +@logging.catch(reraise=True) +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/inference/video_lib.py b/cosmos_predict1/tokenizer/inference/video_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe8d2ec1b52e01c4823a1f660133a529762744d --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/video_lib.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A library for Causal Video Tokenizer inference.""" + +from typing import Any + +import numpy as np +import torch +from tqdm import tqdm + +from cosmos_predict1.tokenizer.inference.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_video_batch, + tensor2numpy, + unpad_video_batch, +) + + +class CausalVideoTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of video tensors after embedding into a latent. + + Args: + video: The input video Bx3xTxHxW layout, range [-1..1]. + Returns: + The reconstructed video, layout Bx3xTxHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes a numpy video into a CausalVideo latent or code. + + Args: + input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1]. + Returns: + For causal continuous video (CV) tokenizer, the tuple contains: + - The latent embedding, Bx16x(t)x(h)x(w), where the compression + rate is (T/t x H/h x W/w), and channel dimension of 16. + For causal discrete video (DV) tokenizer, the tuple contains: + 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which + is formed by FSQ levels of (8,8,8,5,5,5). + 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate + is again (T/t x H/h x W/w), and channel dimension of 6. + """ + assert input_tensor.ndim == 5, "input video should be of 5D." + + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Encodes a numpy video into a CausalVideo latent. + + Args: + input_latent: The continuous latent Bx16xtxhxw for CV, + or the discrete indices Bxtxhxw for DV. + Returns: + The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1]. + """ + assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete." 
+ return self._dec_model(input_latent) + + def forward( + self, + video: np.ndarray, + temporal_window: int = 17, + ) -> np.ndarray: + """Reconstructs video using a pre-trained CausalTokenizer autoencoder. + Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer + in a sliding manner with a `temporal_window` size. + + Args: + video: The input video BxTxHxWx3 layout, range [0..255]. + temporal_window: The length of the temporal window to process, default=25. + Returns: + The reconstructed video in range [0..255], layout BxTxHxWx3. + """ + assert video.ndim == 5, "input video should be of 5D." + num_frames = video.shape[1] # can be of any length. + output_video_list = [] + for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)): + # Input video for the current window. + start, end = idx * temporal_window, (idx + 1) * temporal_window + input_video = video[:, start:end, ...] + + # Spatio-temporally pad input_video so it's evenly divisible. + padded_input_video, crop_region = pad_video_batch(input_video) + input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_video = tensor2numpy(output_tensor) + output_video = unpad_video_batch(padded_output_video, crop_region) + + output_video_list.append(output_video) + return np.concatenate(output_video_list, axis=1) diff --git a/cosmos_predict1/tokenizer/modules/__init__.py b/cosmos_predict1/tokenizer/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a15e15e4bfa8b0a33140ce78830be28defb91238 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/__init__.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
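A short usage sketch of the sliding-window `forward` above; checkpoint paths are placeholders, and the window count follows `(T - 1) // temporal_window + 1`:

import numpy as np

from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer

tokenizer = CausalVideoTokenizer(
    checkpoint_enc="checkpoints/<tokenizer-name>/encoder.jit",  # placeholder path
    checkpoint_dec="checkpoints/<tokenizer-name>/decoder.jit",  # placeholder path
)

video = np.zeros((1, 33, 360, 640, 3), dtype=np.uint8)   # BxTxHxWx3, range [0..255]
reconstruction = tokenizer(video, temporal_window=17)    # same layout and range
# With T=33 and temporal_window=17 the loop runs (33 - 1) // 17 + 1 = 2 windows,
# each padded spatially/temporally, autoencoded, unpadded, then concatenated.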
+ +from enum import Enum + +from cosmos_predict1.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution +from cosmos_predict1.tokenizer.modules.layers2d import Decoder, Encoder +from cosmos_predict1.tokenizer.modules.layers3d import DecoderBase, DecoderFactorized, EncoderBase, EncoderFactorized +from cosmos_predict1.tokenizer.modules.quantizers import FSQuantizer, LFQuantizer, ResidualFSQuantizer, VectorQuantizer + + +class EncoderType(Enum): + Default = Encoder + + +class DecoderType(Enum): + Default = Decoder + + +class Encoder3DType(Enum): + BASE = EncoderBase + FACTORIZED = EncoderFactorized + + +class Decoder3DType(Enum): + BASE = DecoderBase + FACTORIZED = DecoderFactorized + + +class ContinuousFormulation(Enum): + VAE = GaussianDistribution + AE = IdentityDistribution + + +class DiscreteQuantizer(Enum): + VQ = VectorQuantizer + LFQ = LFQuantizer + FSQ = FSQuantizer + RESFSQ = ResidualFSQuantizer diff --git a/cosmos_predict1/tokenizer/modules/distributions.py b/cosmos_predict1/tokenizer/modules/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..2347f7453611d9fea87d0f530bd8e54f02c3f39e --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/distributions.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The distribution modes to use for continuous image tokenizers.""" + +import torch + + +class IdentityDistribution(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, parameters): + return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) + + +class GaussianDistribution(torch.nn.Module): + def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): + super().__init__() + self.min_logvar = min_logvar + self.max_logvar = max_logvar + + def sample(self, mean, logvar): + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + def forward(self, parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) + return self.sample(mean, logvar), (mean, logvar) diff --git a/cosmos_predict1/tokenizer/modules/layers2d.py b/cosmos_predict1/tokenizer/modules/layers2d.py new file mode 100644 index 0000000000000000000000000000000000000000..5770bcf62f45468568fd3c99e22d9d4c9582d38f --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers2d.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The model definition for Continuous 2D layers + +Adapted from: https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py + +[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors] +https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE +""" + +import math + +import numpy as np + +# pytorch_diffusion + derived encoder decoder +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.patching import Patcher, UnPatcher +from cosmos_predict1.tokenizer.modules.utils import Normalize, nonlinearity + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** 
(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # calculate the number of downsample operations + self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_downsamples <= self.num_resolutions + ), f"we can only downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level < self.num_downsamples: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level < self.num_downsamples: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: int, + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + # calculate the number of upsample operations + self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level >= (self.num_resolutions - self.num_upsamples): + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level >= (self.num_resolutions - self.num_upsamples): + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher(h) + return h diff --git a/cosmos_predict1/tokenizer/modules/layers2d_test.py b/cosmos_predict1/tokenizer/modules/layers2d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1fac5a0015b842febd68f3ee9e153741fec9febf --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers2d_test.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
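The Encoder/Decoder above split the overall spatial compression between the (un)patcher and the strided conv stages. A small worked example of that bookkeeping; the values are chosen for illustration only, not taken from a shipped config:

import math

resolution, patch_size, spatial_compression = 512, 4, 16
channels_mult = [2, 4, 4]  # num_resolutions = 3

num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
assert num_downsamples == 2                    # patching covers 4x, convs cover the remaining 4x
assert num_downsamples <= len(channels_mult)   # mirrors the assert in Encoder.__init__

# Latent spatial size: 512 / 16 = 32, reached as
# 512 -> 128 (Patcher, 4x) -> 64 -> 32 (two stride-2 Downsample stages).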
+ +"""The test for model definition of 2D layers + +PYTHONPATH=$PWD pytest -v cosmos_predict1/tokenizer/modules/layers2d_test.py +""" +import os + +import numpy as np +import pytest +import torch +from torchvision.transforms import CenterCrop + +from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer +from cosmos_predict1.tokenizer.inference.utils import read_image +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + +# test configs +TEST_CONFIGS = [ + ("CI8x8-360p", "checkpoints/Cosmos-Tokenize1-CI8x8-360p"), + ("CI16x16-360p", "checkpoints/Cosmos-Tokenize1-CI16x16-360p"), + ("DI8x8-360p", "checkpoints/Cosmos-Tokenize1-DI8x8-360p"), + ("DI16x16-360p", "checkpoints/Cosmos-Tokenize1-DI16x16-360p"), +] + + +@pytest.fixture(scope="module") +def image_tensor(): + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "test_data", "image.png") + print(f"image_path: {image_path}") + image = read_image(image_path) + + assert image.shape[0] >= 512, "Image height should be at least 512 pixels" + assert image.shape[1] >= 512, "Image width should be at least 512 pixels" + assert image.shape[2] == 3, "Image should have 3 channels" + + input_tensor = CenterCrop(512)( + torch.from_numpy(image[np.newaxis, ...]).to("cuda").to(torch.bfloat16).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0 + ) + return input_tensor + + +@pytest.mark.parametrize("config", TEST_CONFIGS) +def test_tokenizer(config, image_tensor): + name, model_id = config + continuous = name.startswith(("C", "c")) + [ + spatial_compression, + ] = list(map(int, name[2:].split("x")[:1])) + print(f"\nTesting tokenizer: {model_id}") + print(f"spatial_compression={spatial_compression}") + + _config = TokenizerConfigs[name.replace("-", "_")].value + autoencoder = ImageTokenizer( + checkpoint_enc=f"{model_id}/encoder.jit", + checkpoint_dec=f"{model_id}/decoder.jit", + tokenizer_config=_config, + device="cuda", + dtype="bfloat16", + ) + + try: + # Test shape check + reconstructed_tensor = auto_shape_check(image_tensor, autoencoder, spatial_compression, continuous) + finally: + # Cleanup + del autoencoder + del reconstructed_tensor + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def auto_shape_check(input_tensor, autoencoder, spatial_compression, continuous): + if continuous: + (latent,) = autoencoder.encode(input_tensor) + torch.testing.assert_close(latent.shape, (1, 16, 512 // spatial_compression, 512 // spatial_compression)) + reconstructed_tensor = autoencoder.decode(latent) + else: + (indices, codes) = autoencoder.encode(input_tensor) + torch.testing.assert_close(indices.shape, (1, 512 // spatial_compression, 512 // spatial_compression)) + torch.testing.assert_close(codes.shape, (1, 6, 512 // spatial_compression, 512 // spatial_compression)) + reconstructed_tensor = autoencoder.decode(indices) + + torch.testing.assert_close(reconstructed_tensor.shape, input_tensor.shape) + return reconstructed_tensor diff --git a/cosmos_predict1/tokenizer/modules/layers3d.py b/cosmos_predict1/tokenizer/modules/layers3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4d12c37240d6f3c5e1d38870fa4b9099c5167b2e --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers3d.py @@ -0,0 +1,949 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D +from cosmos_predict1.tokenizer.modules.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) + +_LEGACY_NUM_GROUPS = 32 + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalUpsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = self.conv(x) + return x[..., int(time_factor - 1) :, :, :] + + +class CausalDownsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + time_stride=2, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", 
value=0) + x = replication_pad(x) + x = self.conv(x) + return x + + +class CausalHybridUpsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_up: bool = True, + temporal_up: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0) + if temporal_up + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1) + if spatial_up + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_up or temporal_up + else nn.Identity() + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_down: bool = True, + temporal_down: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0) + if spatial_down + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0) + if temporal_down + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_down or temporal_down + else nn.Identity() + ) + + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. 
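+ # conv3 is a 1x1x1 channel-mixing conv applied after the strided-conv and average-pool branches are summed.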
+ x = self.conv3(x) + return x + + +class CausalResnetBlock3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = 
torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderBase(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # downsampling + self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. 
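+ # Falls back to _LEGACY_NUM_GROUPS (32) when the config does not specify num_groups.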
+ num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = CausalDownsample3d(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def patcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.patcher(x) + x = batch2time(x, batch_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + else: + # temporal downsample (last level) + time_factor = 1 + 1 * (hs[-1].shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + hs[-1] = replication_pad(hs[-1]) + hs.append( + F.avg_pool3d( + hs[-1], + kernel_size=[time_factor, 1, 1], + stride=[2, 1, 1], + ) + ) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderBase(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
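+ # Mirrors the encoder's Patcher: conv_out emits out_channels * patch_size**2 channels, which UnPatcher folds back into pixels.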
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = CausalUpsample3d(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.unpatcher(x) + x = batch2time(x, batch_size) + + return x + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # temporal upsample (last level) + time_factor = 1.0 + 1.0 * (h.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + h = h.repeat_interleave(int(time_factor), dim=2) + h = h[..., int(time_factor - 1) :, :, :] + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
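+ # Patcher3D folds patch_size**3 voxels into the channel dimension, so conv_in receives in_channels * patch_size**3 channels.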
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d( + in_channels, + channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d( + z_channels, + z_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + 
attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed + # in the encoder should correspond to the layer index in + # reverse order where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. 
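+ # i_level_reverse counts levels from the deepest layer, so decoder level i_level mirrors encoder level (num_resolutions - 1 - i_level).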
+ i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/cosmos_predict1/tokenizer/modules/layers3d_test.py b/cosmos_predict1/tokenizer/modules/layers3d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..711e279afa811123485196af2355dd6bc6b27a36 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers3d_test.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The test for model definition of 3D layers + +PYTHONPATH=$PWD pytest -v cosmos_predict1/tokenizer/modules/layers3d_test.py +""" +import os + +import numpy as np +import pytest +import torch +from torchvision.transforms import CenterCrop + +from cosmos_predict1.tokenizer.inference.utils import read_video +from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + +# test configs +TEST_CONFIGS = [ + ("CV8x8x8-720p", "checkpoints/Cosmos-Tokenize1-CV8x8x8-720p"), + ("DV8x16x16-720p", "checkpoints/Cosmos-Tokenize1-DV8x16x16-720p"), + ("CV4x8x8-360p", "checkpoints/Cosmos-Tokenize1-CV4x8x8-360p"), + ("DV4x8x8-360p", "checkpoints/Cosmos-Tokenize1-DV4x8x8-360p"), +] + + +@pytest.fixture(scope="module") +def video_tensor(): + video_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "test_data", "video.mp4") + print(f"video_path: {video_path}") + video = read_video(video_path) + + assert video.shape[0] >= 17, "Video length should be at least 17 frames" + assert video.shape[1] >= 512, "Video height should be at least 512 pixels" + assert video.shape[2] >= 512, "Video width should be at least 512 pixels" + assert video.shape[3] == 3, "Video should have 3 channels" + + input_tensor = CenterCrop(512)( + torch.from_numpy(video[np.newaxis, ...])[:, :17].to("cuda").to(torch.bfloat16).permute(0, 4, 1, 2, 3) + / 255.0 + * 2.0 + - 1.0 + ) + return input_tensor + + +@pytest.mark.parametrize("config", TEST_CONFIGS) +def test_tokenizer(config, video_tensor): + name, model_id = config + continuous = name.startswith(("C", "c")) + temporal_compression, spatial_compression = list(map(int, name[2:].split("x")[:2])) + print(f"\nTesting tokenizer: {model_id}") + print(f"temporal_compression={temporal_compression}") + print(f"spatial_compression={spatial_compression}") + print(f"checkpoint_enc=checkpoints/{os.path.basename(model_id)}/encoder.jit") + print(f"checkpoint_dec=checkpoints/{os.path.basename(model_id)}/decoder.jit") + + _config = TokenizerConfigs[name.replace("-", "_")].value + autoencoder = CausalVideoTokenizer( + checkpoint_enc=f"checkpoints/{os.path.basename(model_id)}/encoder.jit", + checkpoint_dec=f"checkpoints/{os.path.basename(model_id)}/decoder.jit", + tokenizer_config=_config, + device="cuda", + dtype="bfloat16", + ) + + try: + # Test shape check + reconstructed_tensor = auto_shape_check( + video_tensor, autoencoder, temporal_compression, spatial_compression, continuous + ) + finally: + # Cleanup + del autoencoder + del reconstructed_tensor + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def auto_shape_check(input_tensor, autoencoder, temporal_compression, spatial_compression, continuous): + if continuous: + (latent,) = autoencoder.encode(input_tensor) + torch.testing.assert_close( + latent.shape, + (1, 16, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + reconstructed_tensor = autoencoder.decode(latent) + else: + (indices, codes) = autoencoder.encode(input_tensor) + torch.testing.assert_close( + indices.shape, + (1, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + torch.testing.assert_close( + codes.shape, + (1, 6, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + reconstructed_tensor = autoencoder.decode(indices) + + torch.testing.assert_close(reconstructed_tensor.shape, input_tensor.shape) + return reconstructed_tensor diff 
--git a/cosmos_predict1/tokenizer/modules/patching.py b/cosmos_predict1/tokenizer/modules/patching.py new file mode 100644 index 0000000000000000000000000000000000000000..028df019cdf9bf1126682e144f09c107da834908 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/patching.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The patcher and unpatcher implementation for 2D and 3D data. + +The idea of the Haar wavelet is to compute the LL, LH, HL, HH components as two 1D convolutions: +one on the rows and one on the columns. +For example, for a 1D signal [a, b], the low-freq component is [a + b] / 2 and the high-freq component is [a - b] / 2. +We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. +For the H component, we can use a 1D convolution with kernel [1, -1] and stride 2. +Although in principle the additional Haar wavelet is typically applied only over the LL component, here we apply it to all + components, as we need to support downsampling by more than 2x. +For example, 4x downsampling can be done by a 2x Haar followed by an additional 2x Haar, and the shapes would be: + [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = True + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purposes. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable.
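+ Note: the number of DWT/rearrange steps is int(log2(patch_size)), so patch_size is expected to be a power of two.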
+ """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange( + x, + "b c (h p1) (w p2) -> b (c p1 p2) h w", + p1=self.patch_size, + p2=self.patch_size, + ).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", + patch_size * torch.ones([1], dtype=torch.int32), + persistent=_PERSISTENT, + ) + + def _dwt(self, x, wavelet, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, "haar", rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
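+ # Grouped transposed convolutions undo the forward DWT: first along H, then along W (reverse of Patcher._dwt).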
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
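+ # Drop the (patch_size - 1) leading frames introduced by replicating the first frame in Patcher3D.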
+ return x diff --git a/cosmos_predict1/tokenizer/modules/quantizers.py b/cosmos_predict1/tokenizer/modules/quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..70a1e0c95c9e1143eb278d1c4a554073939af437 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/quantizers.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import reduce +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.utils import default, entropy, pack_one, rearrange, round_ste, unpack_one + +_PERSISTENT = True + + +class ResidualFSQuantizer(nn.Module): + """Residual Finite Scalar Quantization + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + indices_stack = [] + residual = x + quantized_out = 0 + loss_out = 0 + for i, layer in enumerate(self.layers): + quant_indices, z, loss = layer(residual) + indices_stack.append(quant_indices) + residual = residual - z.detach() + quantized_out = quantized_out + z + loss_out = loss_out + loss + self.residual = residual + indices = torch.stack(indices_stack, dim=1) + return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype) + + def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor: + quantized_out = 0 + for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)): + quantized_out += layer.indices_to_codes(indices) + return quantized_out + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Code adapted from Jax version in Appendix A.1. 
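+ Each latent channel is bounded and rounded to one of levels[i] integer positions, so the implicit codebook has prod(levels) entries.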
+ + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.bfloat16) + self.persistent = ignore_kwargs.get("persistent_quantizer", _PERSISTENT) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=self.persistent) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=self.persistent) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=self.persistent) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 
1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) + + +class VectorQuantizer(nn.Module): + """Improved version over VectorQuantizer. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + + Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/ + taming/modules/vqvae/quantize.py + + [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer] + https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + beta: float = 0.25, + remap: str = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + use_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.n_e = num_embeddings + self.e_dim = embedding_dim + self.beta = beta + self.legacy = legacy + self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = num_embeddings + + self.sane_index_shape = sane_index_shape + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", + z_flattened, + rearrange(self.embedding.weight, "n d -> d n"), + ) + ) + + encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device) + encodings.scatter_(1, encoding_indices, 1) + z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape) + min_encodings = None + + z_q, z = self.norm(z_q), self.norm(z) + + # compute loss for embedding + commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True) + emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True) + if not self.legacy: + loss = self.beta * emb_loss + commit_loss + else: + loss = emb_loss + self.beta * commit_loss + + # preserve gradients + z_q = z + (z_q - z).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = encoding_indices.squeeze(1).reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(encoding_indices.squeeze(1)) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return ( + z_q, + loss, + ( + encoding_indices.squeeze(1), + min_encodings, + commit_loss.mean().detach(), + self.beta * emb_loss.mean().detach(), + perplexity.mean().detach(), + ), + ) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = 
z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class LFQuantizer(nn.Module): + """Lookup-Free Quantization + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/lookup_free_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + *, + codebook_size: int, + codebook_dim: int, + embed_dim: Optional[int] = None, # if None, use codebook_dim + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + default_temp: float = 0.01, + entropy_loss: bool = False, + **ignore_kwargs, + ): + """Lookup-Free Quantization + + Args: + codebook_size (int): The number of entries in the codebook. + codebook_dim (int): The number of bits in each code. + embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None. + entropy_loss_weight (float, optional): Whether to use entropy loss. Defaults to 0.1. + commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25. + default_temp (float, optional): The temprature to use. Defaults to 0.01. + entropy_loss (bool, optional): Flag for entropy loss. Defaults to False. + """ + super().__init__() + self.entropy_loss = entropy_loss + self.codebook_dim = codebook_dim + self.default_temp = default_temp + self.entrop_loss_weight = entropy_loss_weight + self.commitment_loss_weight = commitment_loss_weight + embed_dim = embed_dim or codebook_dim + + has_projections = embed_dim != codebook_dim + self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity() + logging.info(f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}") + + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + if entropy_loss: + assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim" + self.codebook_size = codebook_size + + self.register_buffer( + "mask", + 2 ** torch.arange(codebook_dim - 1, -1, -1), + persistent=_PERSISTENT, + ) + self.register_buffer("zero", torch.tensor(0.0), persistent=_PERSISTENT) + + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = 2 * bits - 1.0 + + self.register_buffer("codebook", codebook, persistent=_PERSISTENT) # [codebook_size, codebook_dim] + + def forward(self, z: torch.Tensor, temp: float = None) -> torch.Tensor: + temp = temp or self.default_temp + + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + z = self.project_in(z) + + # split out number of codebooks + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + # quantization + original_input = z + + codebook_value = torch.ones_like(z) + z_q = torch.where(z > 0, codebook_value, -codebook_value) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # commit loss + commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3]) + + z_q = rearrange(z_q, "b n c d -> b n (c d)") + z_q = self.project_out(z_q) + + # reshape + z_q = unpack_one(z_q, ps, "b * d") + z_q = rearrange(z_q, "b ... 
d -> b d ...") + + loss = self.commitment_loss_weight * commit_loss + + # entropy loss (eq-5) + if self.entropy_loss: + # indices + indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") + indices = unpack_one(indices, ps, "b * c") + indices = rearrange(indices, "... 1 -> ...") + + distance = -2 * torch.einsum( + "... i d, j d -> ... i j", + original_input, + self.codebook.to(original_input.dtype), + ) + prob = (-distance / temp).softmax(dim=-1) + per_sample_entropy = entropy(prob).mean(dim=[1, 2]) + avg_prob = reduce(prob, "... c d -> c d", "mean") + codebook_entropy = entropy(avg_prob).mean() + entropy_aux_loss = per_sample_entropy - codebook_entropy + + loss += self.entrop_loss_weight * entropy_aux_loss + + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + ( + indices, + self.commitment_loss_weight * commit_loss.mean().detach(), + self.entrop_loss_weight * entropy_aux_loss.mean().detach(), + self.entrop_loss_weight * per_sample_entropy.mean().detach(), + self.entrop_loss_weight * codebook_entropy.mean().detach(), + ), + ) + else: + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + self.commitment_loss_weight * commit_loss.mean().detach(), + ) + + +class InvQuantizerJit(nn.Module): + """Use for decoder_jit to trace quantizer in discrete tokenizer""" + + def __init__(self, quantizer): + super().__init__() + self.quantizer = quantizer + + def forward(self, indices: torch.Tensor): + codes = self.quantizer.indices_to_codes(indices) + return codes.to(self.quantizer.dtype) diff --git a/cosmos_predict1/tokenizer/modules/utils.py b/cosmos_predict1/tokenizer/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..578bf2fa3f15e0dbe05054d30fda380f6d93e53f --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/utils.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared utilities for the networks module.""" + +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) diff --git a/cosmos_predict1/tokenizer/networks/__init__.py b/cosmos_predict1/tokenizer/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b820aba6f474ec1ebe798b82b5f41362e0c43a4f --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/__init__.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum + +from cosmos_predict1.tokenizer.networks.configs import continuous_image_8x8_360p as continuous_image_8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_image_16x16_360p as continuous_image_16x16_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_video_4x8x8_360p as continuous_video_4x8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_video_8x8x8_720p as continuous_video_8x8x8_720p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_image_8x8_360p as discrete_image_8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_image_16x16_360p as discrete_image_16x16_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_video_4x8x8_360p as discrete_video_4x8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_video_8x16x16_720p as discrete_video_8x16x16_720p_dict +from cosmos_predict1.tokenizer.networks.continuous_image import ContinuousImageTokenizer +from cosmos_predict1.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer +from cosmos_predict1.tokenizer.networks.discrete_image import DiscreteImageTokenizer +from cosmos_predict1.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer + + +class TokenizerConfigs(Enum): + """Continuous Image (CI) Tokenizer Configs""" + + # Cosmos-Tokenize1-CI8x8-360p + CI8x8_360p = continuous_image_8x8_360p_dict + + # Cosmos-Tokenize1-CI16x16-360p + CI16x16_360p = continuous_image_16x16_360p_dict + + """Discrete Image (DI) Tokenizer Configs""" + # Cosmos-Tokenize1-DI8x8-360p + DI8x8_360p = discrete_image_8x8_360p_dict + + # Cosmos-Tokenize1-DI16x16-360p + DI16x16_360p = discrete_image_16x16_360p_dict + + """Causal Continuous Video (CV) Tokenizer Configs""" + # Cosmos-Tokenize1-CV8x8x8-720p + CV8x8x8_720p = continuous_video_8x8x8_720p_dict + + # Cosmos-Tokenize1-CV4x8x8-360p + CV4x8x8_360p = continuous_video_4x8x8_360p_dict + + """Causal Discrete Video (DV) Tokenizer Configs""" + # Cosmos-Tokenize1-DV8x16x16-720p + DV8x16x16_720p = discrete_video_8x16x16_720p_dict + + # Cosmos-Tokenize1-DV4x8x8-360p + DV4x8x8_360p = discrete_video_4x8x8_360p_dict + + +class TokenizerModels(Enum): + CI = ContinuousImageTokenizer + DI = DiscreteImageTokenizer + CV = CausalContinuousVideoTokenizer + DV = CausalDiscreteVideoTokenizer diff --git a/cosmos_predict1/tokenizer/networks/configs.py b/cosmos_predict1/tokenizer/networks/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..540e04cefcfd08d1325c168830d5990f44eb760b --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/configs.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
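(Referring back to networks/__init__.py above.) The two enums are meant to be used together: pick a config dict from `TokenizerConfigs` and pass it to the matching class in `TokenizerModels`. A sketch, assuming the `cosmos_predict1` package is importable and that each config dict carries every kwarg its constructor expects:

```python
from cosmos_predict1.tokenizer.networks import TokenizerConfigs, TokenizerModels

config = dict(TokenizerConfigs.CV8x8x8_720p.value)   # plain copy of the config dict
tokenizer = TokenizerModels.CV.value(**config)       # CausalContinuousVideoTokenizer
print(type(tokenizer).__name__, config["spatial_compression"], config["temporal_compression"])
```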
+ +"""The default image and video tokenizer configs.""" + +from cosmos_predict1.tokenizer.modules import ( + ContinuousFormulation, + Decoder3DType, + DecoderType, + DiscreteQuantizer, + Encoder3DType, + EncoderType, +) + +continuous_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The output latent dimension (channels). + latent_channels=16, + # The encoder output channels just before sampling. + # Which is also the decoder's input channels. + z_channels=16, + # A factor over the z_channels, to get the total channels the encoder should output. + # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. + z_factor=1, + name="CI", + # What formulation to use, either "AE" or "VAE". + # Chose VAE here, since the pre-trained ckpt were of a VAE formulation. + formulation=ContinuousFormulation.AE.name, + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) +continuous_image_8x8_360p = dict(continuous_image) +continuous_image_8x8_360p["patch_size"] = 2 +continuous_image_8x8_360p["spatial_compression"] = 8 + +continuous_image_16x16_360p = dict(continuous_image) +continuous_image_16x16_360p["patch_size"] = 2 +continuous_image_16x16_360p["spatial_compression"] = 16 + + +discrete_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before sampling. + z_channels=256, + # A factor over the z_channels, to get the total channels the encoder should output. + # for discrete tokenization, often we directly use the vector, so z_factor=1. + z_factor=1, + # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. + quantizer=DiscreteQuantizer.FSQ.name, + # The embedding dimension post-quantization, which is also the input channels of the decoder. + # Which is also the output + embedding_dim=6, + # The number of levels to use for fine-scalar quantization. + levels=[8, 8, 8, 5, 5, 5], + # The number of quantizers to use for residual fine-scalar quantization. 
+ num_quantizers=4, + name="DI", + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) +discrete_image_8x8_360p = dict(discrete_image) +discrete_image_8x8_360p["patch_size"] = 2 +discrete_image_8x8_360p["spatial_compression"] = 8 + +discrete_image_16x16_360p = dict(discrete_image) +discrete_image_16x16_360p["patch_size"] = 2 +discrete_image_16x16_360p["spatial_compression"] = 16 + +continuous_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + latent_channels=16, + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=8, + temporal_compression=8, + formulation=ContinuousFormulation.AE.name, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CV", +) + +continuous_video_8x8x8_720p = dict(continuous_video) +continuous_video_8x8x8_720p["temporal_compression"] = 8 +continuous_video_8x8x8_720p["spatial_compression"] = 8 + +continuous_video_4x8x8_360p = dict(continuous_video) +continuous_video_4x8x8_360p["temporal_compression"] = 4 +continuous_video_4x8x8_360p["spatial_compression"] = 8 +continuous_video_4x8x8_360p["patch_size"] = 2 + + +discrete_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + quantizer=DiscreteQuantizer.FSQ.name, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="DV", +) + +discrete_video_8x16x16_720p = dict(discrete_video) +discrete_video_8x16x16_720p["temporal_compression"] = 8 +discrete_video_8x16x16_720p["spatial_compression"] = 16 + +discrete_video_4x8x8_360p = dict(discrete_video) +discrete_video_4x8x8_360p["z_channels"] = 256 +discrete_video_4x8x8_360p["temporal_compression"] = 4 +discrete_video_4x8x8_360p["spatial_compression"] = 8 +discrete_video_4x8x8_360p["patch_size"] = 2 diff --git a/cosmos_predict1/tokenizer/networks/continuous_image.py b/cosmos_predict1/tokenizer/networks/continuous_image.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e288d2f39bf69a0895f97b63a3ef3e04fbcdb6 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/continuous_image.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
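(Referring back to configs.py above.) The numbers in the config names follow directly from the dict entries: `spatial_compression` and `temporal_compression` set the downsampling factors, and for the discrete configs the FSQ `levels` determine the implied vocabulary size. A small worked example; the per-frame grid arithmetic below is just the generic relationship implied by `spatial_compression`, ignoring causal temporal handling:

```python
import math

# FSQ levels shared by the discrete configs above; the vocabulary size is their product.
levels = [8, 8, 8, 5, 5, 5]
print(math.prod(levels))        # 64000 distinct tokens per latent position

# Spatial bookkeeping: an H x W frame maps to roughly an (H // s) x (W // s) token grid.
H, W, s = 720, 1280, 16         # e.g. the DV8x16x16 config uses spatial_compression=16
print(H // s, W // s)           # 45 x 80 tokens per frame
```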
+ +"""The continuous image tokenizer with VAE or AE formulation for 2D data.""" + +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import ContinuousFormulation, DecoderType, EncoderType + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class ContinuousImageTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "ContinuousImageTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: + latent, posteriors = self.encode(input) + dec = self.decode(latent) + if self.training: + return dict(reconstructions=dec, posteriors=posteriors, latent=latent) + return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) diff --git a/cosmos_predict1/tokenizer/networks/continuous_video.py b/cosmos_predict1/tokenizer/networks/continuous_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c054427d9c4a20f0f0d5e606ddb62e68cadfbeb3 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/continuous_video.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" +from collections import OrderedDict, namedtuple + +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import ContinuousFormulation, Decoder3DType, Encoder3DType +from cosmos_predict1.tokenizer.modules.layers3d import CausalConv3d + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class CausalContinuousVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d( + z_factor * z_channels, + z_factor * latent_channels, + kernel_size=1, + padding=0, + ) + self.post_quant_conv = CausalConv3d(latent_channels, z_channels, kernel_size=1, padding=0) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + return self.decoder(z) + + def forward(self, input): + latent, posteriors = self.encode(input) + reconstructions = self.decode(latent) + if self.training: + return dict( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) + return NetworkEval( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) diff --git a/cosmos_predict1/tokenizer/networks/discrete_image.py b/cosmos_predict1/tokenizer/networks/discrete_image.py new file mode 100644 index 0000000000000000000000000000000000000000..02b160b43028912ea7b17cd38a0e98375aa839a7 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/discrete_image.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ.""" +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType +from cosmos_predict1.tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class DiscreteImageTokenizer(nn.Module): + def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "DiscreteImageTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1) + self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}.name." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
+ self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(DiscreteImageTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/cosmos_predict1/tokenizer/networks/discrete_video.py b/cosmos_predict1/tokenizer/networks/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..db1aea39103f7e4bc92db10858890dd3b54b1e4b --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/discrete_video.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ. 
""" +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import Decoder3DType, DiscreteQuantizer, Encoder3DType +from cosmos_predict1.tokenizer.modules.layers3d import CausalConv3d +from cosmos_predict1.tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
+ self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb b/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..75b3dcac46e5247fb646badad7e8a305b8bd50f2 --- /dev/null +++ b/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "n3ryhkSfIEfl" + }, + "source": [ + "# Image Tokenization Using [NVIDIA Cosmos Tokenizer](https://github.com/NVIDIA-Cosmos/cosmos-predict1/blob/main/cosmos1/models/tokenizer) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer/notebook/Image_Tokenization.ipynb)\n", + "\n", + "The Jupyter Notebook example utilizes the **Cosmos-Tokenizer** pretrained models, which include Continuous Image (CI) tokenizers that transform images into continuous latents and Discrete Image (DI) tokenizers that transform images into discrete tokens. Both CI and DI tokenizers are available with compression rates of 8x8 and 16x16. For instance, **CI16x16** effectively downsizes both height and width by a factor of 16.\n", + "\n", + "Within the notebook, the `ImageTokenizer` class from the `cosmos_tokenizer.image_lib` module is employed to manage the encoder and decoder components of this model. The encoder compresses the input image into a condensed latent representation or discrete integers, while the decoder reconstructs the image from this latent representation or discrete integers.\n", + "\n", + "This instance of the Cosmos Tokenizer demonstrates its autoencoding capability: compressing an image into a smaller latent space and subsequently reconstructing it to its original form. 
This showcases the efficiency of image tokenization for tasks involving significant spatial compression during image reconstruction, a highly desirable feature for generative modeling.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5BkjyLTPLM6e" + }, + "source": [ + "This tutorial follows a simple, step-by-step approach, making it easy to understand and adapt.\n", + "\n", + "## Step 1: Clone the Cosmos Tokenizer Repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TEV88M9YG973" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/NVIDIA-Cosmos/cosmos-predict1.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AxOMEJpFL9QL" + }, + "source": [ + "## Step 2: Install **Cosmos-Tokenizer**\n", + "Before proceeding, ensure you have the **Cosmos Tokenizer** installed. If you cloned the repository in Step 1, use the following command to install it in editable mode:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XuwUR6HrIxD8" + }, + "outputs": [], + "source": [ + "# Step 2: # Install Cosmos and its Python dependencies.\n", + "import os\n", + "if os.path.exists(\"cosmos-predict1\"):\n", + " os.chdir(\"cosmos-predict1\")\n", + " %pip install -r requirements.txt\n", + "else:\n", + " print('cosmos-predict1 is already installed.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "id29RPiyMOtB" + }, + "source": [ + "## Step 3: Set Up Hugging Face API Token and Download Pretrained Models\n", + "\n", + "In this step, you'll configure the Hugging Face API token and download the pretrained model weights required for the **Cosmos Tokenizer**.\n", + "\n", + "1. **Ensure You Have a Hugging Face Account** \n", + " If you do not already have a Hugging Face account, follow these steps to create one and generate an API token:\n", + " - Go to the [Hugging Face website](https://huggingface.co/) and sign up for a free account.\n", + " - After logging in, navigate to your [Settings → Access Tokens](https://huggingface.co/settings/tokens).\n", + " - Click on \"New Token\" to generate an API token with the required permissions.\n", + "\n", + "2. **Set the Hugging Face Token** \n", + " Check if the Hugging Face token is already set in the environment variables. If not, you will be prompted to enter it manually. 
The token is essential to authenticate and access the Hugging Face models.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "joxcyOlnM7HQ" + }, + "outputs": [], + "source": [ + "# Check if the token is already set\n", + "if \"HUGGINGFACE_TOKEN\" not in os.environ:\n", + " os.environ[\"HUGGINGFACE_TOKEN\"] = input(\"Please enter your Hugging Face API token: \")\n", + "!git config --global credential.helper store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Lq7MAQ9pGPH9" + }, + "outputs": [], + "source": [ + "from huggingface_hub import login, snapshot_download\n", + "import os\n", + "HUGGINGFACE_TOKEN = os.environ.get(\"HUGGINGFACE_TOKEN\")\n", + "login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)\n", + "model_names = [\n", + " \"Cosmos-0.1-Tokenizer-CI8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CI16x16\",\n", + " \"Cosmos-0.1-Tokenizer-DI8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DI16x16\",\n", + "]\n", + "for model_name in model_names:\n", + " hf_repo = \"nvidia/\" + model_name\n", + " local_dir = \"checkpoints/\" + model_name\n", + " os.makedirs(local_dir, exist_ok=True)\n", + " print(f\"downloading {model_name}...\")\n", + " snapshot_download(repo_id=hf_repo, local_dir=local_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ltZ-v2vzNv74" + }, + "source": [ + "## Step 4: Use Cosmos Tokenizer for Image Reconstruction\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 839 + }, + "id": "gZFPrGCBGwtC", + "outputId": "0df7efc4-7a40-4011-81a6-3c541ba1601f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input image read from:\t /content/Cosmos-Tokenizer/test_data/image.png\n", + "Reconstruction saved:\t /content/Cosmos-Tokenizer/test_data/image_CI8x8.png\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
Input Image
\n", + "
\n", + "
Reconstructed Image
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title In this step, load the required checkpoints, and perform image reconstruction. {\"run\":\"auto\"}\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import importlib\n", + "from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer\n", + "import mediapy as media\n", + "\n", + "\n", + "# 1) Specify the model name, and the paths to the encoder/decoder checkpoints.\n", + "model_name = 'Cosmos-0.1-Tokenizer-CI8x8' # @param [\"Cosmos-0.1-Tokenizer-CI16x16\", \"Cosmos-0.1-Tokenizer-CI8x8\", \"Cosmos-0.1-Tokenizer-DI8x8\", \"Cosmos-0.1-Tokenizer-DI16x16\"]\n", + "\n", + "encoder_ckpt = f\"checkpoints/{model_name}/encoder.jit\"\n", + "decoder_ckpt = f\"checkpoints/{model_name}/decoder.jit\"\n", + "\n", + "# 2) Load or provide the image filename you want to tokenize & reconstruct.\n", + "input_filepath = \"cosmos_predict1/tokenizer/test_data/image.png\"\n", + "\n", + "# 3) Read the image from disk (shape = H x W x 3 in BGR). Then convert to RGB.\n", + "input_image = media.read_image(input_filepath)[..., :3]\n", + "assert input_image.ndim == 3 and input_image.shape[2] == 3, \"Image must have shape H x W x 3\"\n", + "\n", + "# 4) Expand dimensions to B x H x W x C, since the ImageTokenizer expects a batch dimension\n", + "# in the input. (Batch size = 1 in this example.)\n", + "batched_input_image = np.expand_dims(input_image, axis=0)\n", + "\n", + "# 5) Create the ImageTokenizer instance with the encoder & decoder.\n", + "# - device=\"cuda\" uses the GPU\n", + "# - dtype=\"bfloat16\" expects Ampere or newer GPU (A100, RTX 30xx, etc.)\n", + "tokenizer = ImageTokenizer(\n", + " checkpoint_enc=encoder_ckpt,\n", + " checkpoint_dec=decoder_ckpt,\n", + " device=\"cuda\",\n", + " dtype=\"bfloat16\",\n", + ")\n", + "\n", + "# 6) Use the tokenizer to autoencode (encode & decode) the image.\n", + "# The output is a NumPy array with shape = B x H x W x C, range [0..255].\n", + "batched_output_image = tokenizer(batched_input_image)\n", + "\n", + "# 7) Extract the single image from the batch (index 0), convert to uint8.\n", + "output_image = batched_output_image[0]\n", + "\n", + "# 9) Save the reconstructed image to disk.\n", + "input_dir, input_filename = os.path.split(input_filepath)\n", + "filename, ext = os.path.splitext(input_filename)\n", + "output_filepath = f\"{input_dir}/{filename}_{model_name.split('-')[-1]}{ext}\"\n", + "media.write_image(output_filepath, output_image)\n", + "print(\"Input image read from:\\t\", f\"{os.getcwd()}/{input_filepath}\")\n", + "print(\"Reconstruction saved:\\t\", f\"{os.getcwd()}/{output_filepath}\")\n", + "\n", + "# 10) Visualization of the input image (left) and the reconstruction (right).\n", + "media.show_images([input_image, output_image], [\"Input Image\", \"Reconstructed Image\"])" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb b/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..31016a779c7069c352e9691a9965cb2f3f5051a5 --- /dev/null +++ b/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + 
{ + "cell_type": "markdown", + "metadata": { + "id": "n3ryhkSfIEfl" + }, + "source": [ + "# Video Tokenization Using [NVIDIA Cosmos Tokenizer](https://github.com/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer/notebook/Video_Tokenization.ipynb)\n", + "\n", + "The Jupyter Notebook example utilizes the **Cosmos-Tokenizer** pretrained models, which include Continuous Video (CV) tokenizers that transform videos into continuous spatio-temporal latents and Discrete Video (DI) tokenizers that transform videos into discrete tokens. Both CV and DV tokenizers are available with compression rates of (`TxHxW` format) 4x8x8 and 8x8x8, and 8x16x16. For instance, **CV4x8x8** effectively downsizes the number of frames by a factor of 4 and both height and width by a factor of 8.\n", + "\n", + "Within the notebook, the `VideoTokenizer` class from the `cosmos_tokenizer.video_lib` module is employed to manage the encoder and decoder components of this model. The encoder compresses the input video into a condensed latent representation or discrete integers, while the decoder reconstructs the video from this latent representation or discrete integers.\n", + "\n", + "This instance of the Cosmos Tokenizer demonstrates its autoencoding capability: compressing a video into a smaller latent space and subsequently reconstructing it to its original form. This showcases the efficiency of video tokenization for tasks involving significant spatial compression during video reconstruction, a highly desirable feature for generative modeling.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5BkjyLTPLM6e" + }, + "source": [ + "This tutorial follows a simple, step-by-step approach, making it easy to understand and adapt.\n", + "\n", + "## Step 1: Clone the Cosmos Tokenizer Repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TEV88M9YG973" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/NVIDIA-Cosmos/cosmos-predict1.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AxOMEJpFL9QL" + }, + "source": [ + "## Step 2: Install **Cosmos-Tokenizer**\n", + "Before proceeding, ensure you have the **Cosmos Tokenizer** installed. If you cloned the repository in Step 1, use the following command to install it in editable mode:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XuwUR6HrIxD8" + }, + "outputs": [], + "source": [ + "# Step 2: # Install Cosmos-Tokenizer and its Python dependencies.\n", + "import os\n", + "if os.path.exists(\"cosmos-predict1\"):\n", + " os.chdir(\"cosmos-predict1\")\n", + " !apt-get update\n", + " !apt-get install -y git-lfs\n", + " !git lfs pull\n", + " %pip install -r requirements.txt\n", + "else:\n", + " print('cosmos-predict1 is already installed.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "id29RPiyMOtB" + }, + "source": [ + "## Step 3: Set Up Hugging Face API Token and Download Pretrained Models\n", + "\n", + "In this step, you'll configure the Hugging Face API token and download the pretrained model weights required for the **Cosmos Tokenizer**.\n", + "\n", + "1. 
**Ensure You Have a Hugging Face Account** \n", + " If you do not already have a Hugging Face account, follow these steps to create one and generate an API token:\n", + " - Go to the [Hugging Face website](https://huggingface.co/) and sign up for a free account.\n", + " - After logging in, navigate to your [Settings → Access Tokens](https://huggingface.co/settings/tokens).\n", + " - Click on \"New Token\" to generate an API token with the required permissions.\n", + "\n", + "2. **Set the Hugging Face Token** \n", + " Check if the Hugging Face token is already set in the environment variables. If not, you will be prompted to enter it manually. The token is essential to authenticate and access the Hugging Face models.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "joxcyOlnM7HQ" + }, + "outputs": [], + "source": [ + "# Check if the token is already set\n", + "if \"HUGGINGFACE_TOKEN\" not in os.environ:\n", + " os.environ[\"HUGGINGFACE_TOKEN\"] = input(\"Please enter your Hugging Face API token: \")\n", + "!git config --global credential.helper store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Lq7MAQ9pGPH9" + }, + "outputs": [], + "source": [ + "from huggingface_hub import login, snapshot_download\n", + "import os\n", + "HUGGINGFACE_TOKEN = os.environ.get(\"HUGGINGFACE_TOKEN\")\n", + "login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)\n", + "model_names = [\n", + " \"Cosmos-0.1-Tokenizer-CV4x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CV8x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CV8x16x16\",\n", + " \"Cosmos-0.1-Tokenizer-DV4x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DV8x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DV8x16x16\",\n", + " \"Cosmos-Tokenize1-CV8x8x8-720p\",\n", + " \"Cosmos-Tokenize1-DV8x16x16-720p\",\n", + "]\n", + "for model_name in model_names:\n", + " hf_repo = \"nvidia/\" + model_name\n", + " local_dir = \"checkpoints/\" + model_name\n", + " os.makedirs(local_dir, exist_ok=True)\n", + " print(f\"downloading {model_name}...\")\n", + " snapshot_download(repo_id=hf_repo, local_dir=local_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ltZ-v2vzNv74" + }, + "source": [ + "## Step 4: Use Cosmos Tokenizer for Video Reconstruction\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 594 + }, + "id": "gZFPrGCBGwtC", + "outputId": "ad18dc16-c1f2-410c-937b-787c677ec27e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00, 6.45s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input video read from:\t /home/freda/Cosmos/cosmos1/models/tokenizer/test_data/video.mp4\n", + "Reconstruction saved:\t /home/freda/Cosmos/cosmos1/models/tokenizer/test_data/video_CV8x8x8.mp4\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
Input Video
\n", + "
\n", + "
Reconstructed Video
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title In this step, load the required checkpoints, and perform video reconstruction. {\"run\":\"auto\"}\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import importlib\n", + "from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer\n", + "import mediapy as media\n", + "\n", + "\n", + "# 1) Specify the model name, and the paths to the encoder/decoder checkpoints.\n", + "model_name = 'Cosmos-Tokenize1-CV8x8x8-720p' # @param [\"Cosmos-0.1-Tokenizer-CV4x8x8\", \"Cosmos-0.1-Tokenizer-CV8x8x8\", \"Cosmos-0.1-Tokenizer-CV8x16x16\", \"Cosmos-0.1-Tokenizer-DV4x8x8\", \"Cosmos-0.1-Tokenizer-DV8x8x8\", \"Cosmos-0.1-Tokenizer-DV8x16x16\", \"Cosmos-Tokenize1-CV8x8x8-720p\", \"Cosmos-Tokenize1-DV8x16x16-720p\"]\n", + "temporal_window = 49 # @param {type:\"slider\", min:1, max:121, step:8}\n", + "\n", + "encoder_ckpt = f\"checkpoints/{model_name}/encoder.jit\"\n", + "decoder_ckpt = f\"checkpoints/{model_name}/decoder.jit\"\n", + "\n", + "# 2) Load or provide the video filename you want to tokenize & reconstruct.\n", + "input_filepath = \"cosmos_predict1/tokenizer/test_data/video.mp4\"\n", + "\n", + "# 3) Read the video from disk (shape = T x H x W x 3 in BGR).\n", + "input_video = media.read_video(input_filepath)[..., :3]\n", + "assert input_video.ndim == 4 and input_video.shape[-1] == 3, \"Frames must have shape T x H x W x 3\"\n", + "\n", + "# 4) Expand dimensions to B x Tx H x W x C, since the CausalVideoTokenizer expects a batch dimension\n", + "# in the input. (Batch size = 1 in this example.)\n", + "batched_input_video = np.expand_dims(input_video, axis=0)\n", + "\n", + "# 5) Create the CausalVideoTokenizer instance with the encoder & decoder.\n", + "# - device=\"cuda\" uses the GPU\n", + "# - dtype=\"bfloat16\" expects Ampere or newer GPU (A100, RTX 30xx, etc.)\n", + "tokenizer = CausalVideoTokenizer(\n", + " checkpoint_enc=encoder_ckpt,\n", + " checkpoint_dec=decoder_ckpt,\n", + " device=\"cuda\",\n", + " dtype=\"bfloat16\",\n", + ")\n", + "\n", + "# 6) Use the tokenizer to autoencode (encode & decode) the video.\n", + "# The output is a NumPy array with shape = B x T x H x W x C, range [0..255].\n", + "batched_output_video = tokenizer(batched_input_video,\n", + " temporal_window=temporal_window)\n", + "\n", + "# 7) Extract the single video from the batch (index 0).\n", + "output_video = batched_output_video[0]\n", + "\n", + "# 9) Save the reconstructed video to disk.\n", + "input_dir, input_filename = os.path.split(input_filepath)\n", + "filename, ext = os.path.splitext(input_filename)\n", + "output_filepath = f\"{input_dir}/{filename}_{model_name.split('-')[-1]}{ext}\"\n", + "media.write_video(output_filepath, output_video)\n", + "print(\"Input video read from:\\t\", f\"{os.getcwd()}/{input_filepath}\")\n", + "print(\"Reconstruction saved:\\t\", f\"{os.getcwd()}/{output_filepath}\")\n", + "\n", + "# 10) Visualization of the input video (left) and the reconstruction (right).\n", + "media.show_videos([input_video, output_video], [\"Input Video\", \"Reconstructed Video\"], height=480)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/cosmos_predict1/tokenizer/test_data/image.png 
b/cosmos_predict1/tokenizer/test_data/image.png new file mode 100644 index 0000000000000000000000000000000000000000..370b83e4fd1c42547cbe34190ff726994d2c34a6 --- /dev/null +++ b/cosmos_predict1/tokenizer/test_data/image.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f2261a585eea38a0c9ec16f2ea81a2295b49c5ad6a3e39fc7cfdd1aa39f53b +size 1786433 diff --git a/cosmos_predict1/tokenizer/test_data/video.mp4 b/cosmos_predict1/tokenizer/test_data/video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..13dfd018f2ba6afb264c862cfc54a83b8d9e5f6b --- /dev/null +++ b/cosmos_predict1/tokenizer/test_data/video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b1112c71ee9f14b0d1d2b60a7b8d76bdd133d8fc788e5b41e104132f75bfb4f +size 3570241 diff --git a/cosmos_predict1/tokenizer/training/__init__.py b/cosmos_predict1/tokenizer/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/callbacks.py b/cosmos_predict1/tokenizer/training/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6e11f350a51fcff117a100a7774ca0b0b5961d --- /dev/null +++ b/cosmos_predict1/tokenizer/training/callbacks.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenizer callbacks extended from base callbacks.""" + +import math +from typing import Any, Optional + +import numpy as np +import torch +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.utils import callback, distributed, log +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + +_UINT8_MAX_F = float(np.iinfo(np.uint8).max) +_VIDEO_CONSISTENCY_LOSS = "video_consistency" + + +def make_video_grid(video, nrow=None, padding=1): + r"""Make a grid of videos for visualization. + Args: + video (tensor): video of size B x C x T x H x W. + nrow (int): number of rows in the grid. + padding (int): size of paddings between videos. 
+ """ + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().detach().numpy() * _UINT8_MAX_F).astype("uint8") + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype="uint8") + + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + return video + + +def compute_weight_norm(model): + weight_norm = dict() + for layer_name, param in model.named_parameters(): + if torch.isnan(param).any(): + raise ValueError(f"[weight] {layer_name} NaN detected in gradients") + weight_norm[f"{layer_name}"] = torch.norm(param, p=2).item() + return weight_norm + + +def compute_grad_norm(model): + grad_norm = dict() + for layer_name, param in model.named_parameters(): + if param.grad is not None: + if torch.isnan(param.grad).any(): + raise ValueError(f"[grad] {layer_name} NaN detected in gradients") + grad_norm[f"{layer_name}"] = torch.norm(param.grad, p=2).item() + return grad_norm + + +class AdaptCkptStateDict(callback.Callback): + def __init__(self, config: Config, trainer: Trainer): + super().__init__(config, trainer) + + def on_save_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None: + """Adapt the state dict should the model be a compiled one.""" + if not isinstance(model.network, torch_OptimizedModule): + return + + def _uncompiled_key(k): + if k.startswith("network._orig_mod"): + return k.replace("network._orig_mod", "network") + elif k.startswith("ema.network-_orig_mod"): + return k.replace("ema.network-_orig_mod", "ema.network") + return k + + fixed_keys_state_dict = {} + + for k, v in state_dict["model"].items(): + fixed_keys_state_dict[_uncompiled_key(k)] = v + + state_dict["model"] = fixed_keys_state_dict + + def on_load_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None: + """Adapt the state dict should the model be a compiled one.""" + if not isinstance(model.network, torch_OptimizedModule): + return + + def _compiled_key(k): + if k.startswith("network."): + return k.replace("network", "network._orig_mod") + elif k.startswith("ema.network-"): + return k.replace("ema.network", "ema.network-_orig_mod") + return k + + fixed_keys_state_dict = {} + + for k, v in state_dict["model"].items(): + fixed_keys_state_dict[_compiled_key(k)] = v + + state_dict["model"] = fixed_keys_state_dict + + +class GradClipCallback(callback.GradClipCallback): + """The verbose tokenizer callback for gradient clipping.""" + + def __init__(self, grad_clip_norm: float, config: Config, trainer: Trainer, verbose: bool): + super().__init__(config, trainer, grad_clip_norm) + self.verbose = verbose + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + grad_scaler.unscale_(optimizer) + total_norm = torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) + if torch.isnan(total_norm): + raise ValueError("[gradient clipping] NaN detected in gradient norms") + if torch.isfinite(total_norm) and total_norm > self.grad_clip_norm and self.verbose: + if model_ddp.module.network.training: + log.warning( + f"[net:{iteration:07d}] 
Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients." + ) + else: + log.warning( + f"[unknown:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients." + ) + + +class ExpandLossMask(callback.Callback): + def __init__(self, kernel_size: int, config: Config, trainer: Trainer): + super().__init__(config, trainer) + self.kernel_size = kernel_size + + def on_training_step_start(self, model: Model, data: dict[str, Any], iteration: int = 0) -> None: + """Expand loss_mask with max pooling (to cover some partial human regions)""" + + if "loss_mask" not in data.keys(): + return + + assert data["loss_mask"].ndim == 4 or data["loss_mask"].ndim == 5, "ndim of loss_mask must be 4 or 5" + + kernel_size = self.kernel_size + if data["loss_mask"].ndim == 4: + data["loss_mask"] = torch.nn.functional.max_pool2d( + data["loss_mask"], kernel_size, stride=1, padding=kernel_size // 2 + ) + else: + data["loss_mask"] = torch.nn.functional.max_pool3d( + data["loss_mask"], + (1, kernel_size, kernel_size), + stride=1, + padding=(0, kernel_size // 2, kernel_size // 2), + ) + + +class TorchCompile(callback.Callback): + """ + Callback to use torch.compile() on network or modules in losses(FlowLoss and PerceptualLoss) or both. + We compile them at later iteration as it prevents NCCL timeouts when times are very unstable during first iterations + """ + + _TORCH_DYNAMO_CACHE_SIZE = 128 + + def __init__( + self, + compile_after_iterations: int = 8, + compile_network: bool = False, + compile_loss: bool = False, + compile_loss_keys: list[str] = ["flow", "perceptual"], + ): + self.initial_iteration: Optional[int] = None + self.compile_after_iterations: int = compile_after_iterations + + self.compile_network: bool = compile_network + self.compile_loss: bool = compile_loss + + self.compile_loss_keys: list[str] = compile_loss_keys + + if self.compile_network or self.compile_loss: + torch._dynamo.config.cache_size_limit = TorchCompile._TORCH_DYNAMO_CACHE_SIZE + + # Hack to make ".training" work on "torch.compile()" module. 
+ # Value of ".training" is incorrectly set on torch.compile() module, when .eval() or .train() + # is invoked, but is correctly set on original module and this hack accesses that value + # I've created issue about this: https://github.com/pytorch/pytorch/issues/132986 + torch_OptimizedModule.training = property( + lambda self: self._orig_mod.training, lambda self, value: None, lambda self: None + ) + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + if not (self.compile_network or self.compile_loss): + return + + if self.initial_iteration is None: + log.info(f"Compilation will done on iteration {iteration + self.compile_after_iterations}") + self.initial_iteration = iteration + + if self.compile_network: + if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False: + log.warning( + '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), network will not be compiled' + ) + + if iteration - self.initial_iteration == self.compile_after_iterations: + if self.compile_network: + if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False: + log.warning( + '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), skipping network compilation' + ) + else: + log.info("Compiling network") + model.network = torch.compile(model.network, dynamic=False) + + if self.compile_loss: + for key in self.compile_loss_keys: + if key not in model.loss.loss_modules: + log.warning(f"Loss module for compilation with key: {key} not found") + else: + if ( + hasattr(model.loss.loss_modules[key], "checkpoint_activations") + and getattr(model.loss.loss_modules[key], "checkpoint_activations") is True + ): + log.warning( + f"torch.compile() doesn't work with activation checkpointing, skipping compilation for loss with key: {key}" + ) + else: + log.info(f"Compiling loss with key: {key}") + model.loss.loss_modules[key].torch_compile() diff --git a/cosmos_predict1/tokenizer/training/checkpointer.py b/cosmos_predict1/tokenizer/training/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb58c0d08149a869cc0e26321142fb8b3754ca2 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/checkpointer.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
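As an illustrative aside (not from the repository): the deferred-compilation pattern used by the TorchCompile callback above can be reproduced in isolation by running the first few iterations eagerly, while shapes and timings are still unstable, and only then swapping in the compiled module. The module and loop below are made up for the sketch.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
compile_after_iterations = 8  # same default as the callback above
initial_iteration = None

for iteration in range(20):
    if initial_iteration is None:
        initial_iteration = iteration
    if iteration - initial_iteration == compile_after_iterations:
        # Mirrors "model.network = torch.compile(model.network, dynamic=False)" above.
        net = torch.compile(net, dynamic=False)
    out = net(torch.randn(2, 3, 32, 32))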
+ +from __future__ import annotations + +import os +import threading + +import torch +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.utils import callback, distributed, ema, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.model import Model + + +class TokenizerCheckpointer(Checkpointer): + """The tokenizer checkpointer, extends the shared checkpointer. + + Supports checkpoint saving/loading to local disk: + - network weights and training optimizer states. + - optionally, export a TorchScript version of the EMA model. + """ + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + super().__init__(config_checkpoint, config_job, callbacks) + self.callbacks = callbacks + self.config_jit = config_checkpoint.jit + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = -1, + **ignore_kwargs, + ) -> None: + """Saves network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer: The model optimizer. + scheduler: The optimization scheduler. + grad_scaler: The gradient scaler (for mixed precision training). + iteration: Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + model.eval() + checkpoint_file = f"iter_{iteration:09}.pt" + + if distributed.get_rank() == 0: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, self._get_ema_jit(model), checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local( + self, + state_dict: dict[str, torch.Tensor], + jit_models: dict[str, torch.ScriptModule], + checkpoint_file: str, + rank: int = 0, + ) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict: The state dict of the model/optimizer/scheduler. + ema_jit: A dict of TorchScript EMA model, representing the encoder, decoder and full model. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+        """
+        checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
+        os.makedirs(self.checkpoint_dir_local, exist_ok=True)
+        try:
+            torch.save(state_dict, checkpoint_path)
+            for key, jit_model in jit_models.items():
+                checkpoint_jit = checkpoint_path.replace(".pt", f"_{key}.jit")
+                torch.jit.save(jit_model, checkpoint_jit)
+                log.success(f"Saved checkpoint: {checkpoint_jit}")
+            if rank == 0:
+                self._write_latest_checkpoint_file(checkpoint_file)
+            log.success(f"Saved checkpoint (local): {checkpoint_path}")
+            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
+            self.callbacks.on_save_checkpoint_success(iteration=iteration)
+        except Exception as e:  # noqa: BLE001
+            log.exception(f"Checkpoint failed to save (local): {e}")
+
+    def _get_ema_jit(self, model: Model) -> dict[str, torch.ScriptModule]:
+        """Returns TorchScript versions of the EMA model (full, encoder, decoder), compiled by PyTorch JIT."""
+        if not self.config_jit.enabled:
+            return dict()
+        input_shape = tuple(self.config_jit.input_shape)
+        example_input = torch.randn(input_shape)
+        dtype = getattr(torch, self.config_jit.dtype)
+        example_input = example_input.to(self.config_jit.device).to(dtype)
+        with ema.ema_scope(model, enabled=model.config.ema.enabled):
+            _model = model.network
+            if isinstance(_model, torch_OptimizedModule):
+                _model = _model._orig_mod
+
+            # Make sure the jit model outputs consistently across consecutive calls.
+            # Check here: https://github.com/pytorch/pytorch/issues/74534
+            torch._C._jit_set_texpr_fuser_enabled(False)
+
+            ema_jit = torch.jit.trace(_model, example_input, strict=self.config_jit.strict)
+            encoder_jit = torch.jit.trace(_model.encoder_jit(), example_input, strict=self.config_jit.strict)
+            decoder_example = encoder_jit(example_input)
+            if isinstance(decoder_example, tuple):
+                decoder_example = decoder_example[0]
+            else:
+                assert isinstance(decoder_example, torch.Tensor), "decoder_example should be a tensor or tuple"
+            decoder_jit = torch.jit.trace(_model.decoder_jit(), decoder_example, strict=self.config_jit.strict)
+            return {"ema": ema_jit, "enc": encoder_jit, "dec": decoder_jit}
diff --git a/cosmos_predict1/tokenizer/training/configs/__init__.py b/cosmos_predict1/tokenizer/training/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9
--- /dev/null
+++ b/cosmos_predict1/tokenizer/training/configs/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
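A minimal sketch (illustrative, not from the repository) of how the JIT artifacts written by _save_worker_local and _get_ema_jit above could be loaded for inference. The checkpoint directory and iteration number are made up, and the dtype/device should match whatever config_jit specified; float32 is used here only for brevity.

import torch

ckpt_dir = "checkpoints/posttraining/tokenizer/example_run"  # hypothetical path
encoder = torch.jit.load(f"{ckpt_dir}/iter_000005000_enc.jit").eval()
decoder = torch.jit.load(f"{ckpt_dir}/iter_000005000_dec.jit").eval()

with torch.no_grad():
    video = torch.randn(1, 3, 17, 512, 512)  # B x C x T x H x W, as in the jit input_shape configs
    latent = encoder(video)
    latent = latent[0] if isinstance(latent, tuple) else latent
    reconstruction = decoder(latent)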
diff --git a/cosmos_predict1/tokenizer/training/configs/base/__init__.py b/cosmos_predict1/tokenizer/training/configs/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/configs/base/callback.py b/cosmos_predict1/tokenizer/training/configs/base/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..941ea69d78ef30651ed53b600cfe43c424bcdb3d --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/callback.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""callbacks config options: + +BASIC_CALLBACKS: always recommended to use +""" + +from cosmos_predict1.tokenizer.training.callbacks import ( + AdaptCkptStateDict, + ExpandLossMask, + GradClipCallback, + TorchCompile, +) +from cosmos_predict1.utils.callback import EMAModelCallback, LowPrecisionCallback, ProgressBarCallback +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L + +BASIC_CALLBACKS = dict( + low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), + grad_clip=L(GradClipCallback)(grad_clip_norm=1, verbose=False, config=PLACEHOLDER, trainer=PLACEHOLDER), + ema=L(EMAModelCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER), + progress_bar=L(ProgressBarCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER), + expand_loss_mask=L(ExpandLossMask)(kernel_size=51, config=PLACEHOLDER, trainer=PLACEHOLDER), + adapt_ckpt_state_dict=L(AdaptCkptStateDict)(config=PLACEHOLDER, trainer=PLACEHOLDER), + torch_compile=L(TorchCompile)( + compile_after_iterations=8, + compile_network=False, + compile_loss=False, + compile_loss_keys=["flow", "perceptual"], + ), +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py b/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f869223833a0027c29ee6d9fae4bf460a85e0639 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
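Since BASIC_CALLBACKS above is an ordinary dict of LazyCall nodes, a variant callback group can be derived programmatically before registration. This assumes LazyCall nodes behave like omegaconf DictConfigs (as in detectron2-style lazy configs); the variant name and values below are illustrative only.

from copy import deepcopy

VERBOSE_CALLBACKS = deepcopy(BASIC_CALLBACKS)  # hypothetical stricter variant
VERBOSE_CALLBACKS["grad_clip"].grad_clip_norm = 0.5
VERBOSE_CALLBACKS["grad_clip"].verbose = True
VERBOSE_CALLBACKS["torch_compile"].compile_network = True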
+ +"""checkpoints config options: + +CHECKPOINT_LOCAL: store at local file system + +""" +import attrs + +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config import make_freezable +from cosmos_predict1.utils.lazy_config import LazyDict + + +@make_freezable +@attrs.define(slots=False) +class ExperimentConfig: + # Enables enforcing experiment naming. + enabled: bool = True + # The project, e.g. edify_video4. + project: str = None + # The valid groups, e.g ["video"]. + groups: list[str] = None + # The approved name prefixes, e.g. ["DV1024", "DI256"]. + name_prefixes: list[str] = None + + +@make_freezable +@attrs.define(slots=False) +class TokenizerCheckpointConfig(config.CheckpointConfig): + # Experiment naming configs. + experiment: ExperimentConfig = attrs.field(factory=ExperimentConfig) + + +jit_config = config.JITConfig( + enabled=True, + input_shape=[1, 3, 1024, 1024], +) + +experiment_config = ExperimentConfig( + enabled=True, + project="cosmos_tokenizer", + groups=["debug", "video"], + name_prefixes=[ + f"{base}{size}" if base in ["CI", "DI"] else f"{base}{size}_Causal" + for base in ["CI", "DI", "CV", "DV"] + for size in [256, 320, 480, 512, 720, 1024, 1080] + ] + + [f"{base}{size}" for base in ["CV", "DV"] for size in [256, 320, 512, 720]] + + ["mock"], +) + +CHECKPOINT_LOCAL: LazyDict = attrs.asdict( + TokenizerCheckpointConfig( + save_iter=5000, + jit=jit_config, + experiment=experiment_config, + ) +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/data.py b/cosmos_predict1/tokenizer/training/configs/base/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a698489074d8dfde2bf7dafbdb68f24203ede8 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/data.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
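The name_prefixes list built above encodes the experiment-naming convention. A small illustrative check (not from the repository; the actual enforcement lives elsewhere) of how a job name could be validated against those prefixes:

prefixes = [
    f"{base}{size}" if base in ["CI", "DI"] else f"{base}{size}_Causal"
    for base in ["CI", "DI", "CV", "DV"]
    for size in [256, 320, 480, 512, 720, 1024, 1080]
] + [f"{base}{size}" for base in ["CV", "DV"] for size in [256, 320, 512, 720]] + ["mock"]


def is_approved_name(job_name: str) -> bool:
    # A name passes if it starts with any approved prefix, e.g. "CV720_Causal_my_run".
    return any(job_name.startswith(prefix) for prefix in prefixes)


assert is_approved_name("CV720_Causal_my_run")
assert not is_approved_name("random_experiment")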
+ +"""dataloader config options + +Available dataloader options: + image_loader_basic + video_loader_basic + joint_image_video_loader_basic +""" + +from torch.utils.data import DataLoader + +from cosmos_predict1.tokenizer.training.configs.base.mock_data import get_mock_video_dataloader +from cosmos_predict1.tokenizer.training.datasets.dataset_provider import dataset_entry +from cosmos_predict1.utils.lazy_config import LazyCall + +DATALOADER_OPTIONS = {} + + +def dataloader_register(key): + def decorator(func): + DATALOADER_OPTIONS[key] = func + return func + + return decorator + + +@dataloader_register("video_loader_basic") +def get_video_dataloader( + dataset_name, + is_train, + batch_size=1, + num_video_frames=25, + resolution="720", + crop_height=128, + num_workers=8, +): + if dataset_name.startswith("mock"): + return get_mock_video_dataloader( + batch_size=batch_size, + is_train=is_train, + num_video_frames=num_video_frames, + resolution=resolution, + crop_height=crop_height, + ) + return LazyCall(DataLoader)( + dataset=LazyCall(dataset_entry)( + dataset_name=dataset_name, + dataset_type="video", + is_train=is_train, + resolution=resolution, + crop_height=crop_height, + num_video_frames=num_video_frames, + ), + batch_size=batch_size, # 2 + num_workers=num_workers, # 8 + prefetch_factor=2, + shuffle=None, # do we need this? + sampler=None, + persistent_workers=False, + pin_memory=True, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/base/loss.py b/cosmos_predict1/tokenizer/training/configs/base/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fcd754a5f0913fbf28e376773bde03bb4f9806 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/loss.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss config options + +Loss weights are scheduled using a piecewise linear LR schedule. The schedule is defined by a list of boundaries and values. + +`boundaries` is a list of integers representing the iteration at which the weight value changes. +`values` is a list of floats representing the weight value at each boundary. It should have one more value than `boundaries`. + +Example: + A loss's weight will be: + values[0] when step <= boundaries[0], + values[1] when step > boundaries[0] and step <= boundaries[1], + ..., and + values[-1] when step > boundaries[-1]. 
+""" +import attrs + +from cosmos_predict1.tokenizer.training.losses import ReduceMode +from cosmos_predict1.tokenizer.training.losses.continuous import ( + ColorLoss, + FlowLoss, + KLLoss, + PerceptualLoss, + TokenizerLoss, + VideoConsistencyLoss, +) +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class KLConfig: + # each step is greater than boundaries[-1], so weight=values[-1] + boundaries: list[int] = [0] + values: list[float] = [1e-6] + + +@attrs.define(slots=False) +class PerceptualConfig: + lpips_boundaries: list[int] = [500000] + lpips_values: list[float] = [0.1, 0.073] + # Layer weights for linearly combining the multi-layer vgg-based losses. + layer_weights: list[float] = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] + # Gram loss, whether to turn on, and what weights to use. + gram_enabled: bool = True + gram_boundaries: list[int] = [500000] + gram_values: list[float] = [0.0, 0.062] + # Corr loss, whether to turn on, and what weights to use. + corr_enabled: bool = False + corr_boundaries: list[int] = [0] + corr_values: list[float] = [0.0] + # In the example training memory usage dropped from 64.03 GiB to 60.54 GiB + # with checkpointing enabled for this loss for about 3.2% slowdown. + # With checkpointing this and PerceptualLoss memory usage dropped + # from 64.03 GiB to 52.94 GiB for about 18% slowdown + # more details in MR:949 + checkpoint_activations: bool = False + + +@attrs.define(slots=False) +class ColorConfig: + # Color (RGB) basic loss and its weight schedule. + norm: str = "L1" + boundaries: list[int] = [0] + values: list[float] = [1.0] + + +@attrs.define(slots=False) +class FlowConfig: + # Flow loss and its weight schedule. + boundaries: list[int] = [250000] + values: list[float] = [0.0, 0.01] + scale: int = 2 + # Flow loss depends on RAFT, as such it requires a specific dtype. + dtype: str = "bfloat16" + # In the example training memory usage dropped from 28GB to 23GB + # with checkpointing enabled for this loss + # With checkpointing this and PerceptualLoss memory usage dropped + # from 64.03 GiB to 52.94 GiB for about 18% slowdown + # more details in MR:949 + checkpoint_activations: bool = False + enabled: bool = False + + +@attrs.define(slots=False) +class VideoConsistencyConfig: + # Add consistency loss between overlapped video frames + boundaries: list[int] = [250000] + values: list[float] = [0.0, 0.01] + enabled: bool = False + num_frames: int = 9 + step: int = 1 + + +@attrs.define(slots=False) +class VideoLoss: + # The combined loss function, and its reduction mode. 
+    color: LazyDict = L(ColorLoss)(config=ColorConfig())
+    kl: LazyDict = L(KLLoss)(config=KLConfig())
+    perceptual: LazyDict = L(PerceptualLoss)(config=PerceptualConfig())
+    flow: LazyDict = L(FlowLoss)(config=FlowConfig())
+    video_consistency: LazyDict = L(VideoConsistencyLoss)(config=VideoConsistencyConfig())
+    reduce: str = ReduceMode.MEAN.value  # model.config.loss.config.reduce={'MEAN', 'SUM', 'SUM_PER_FRAME'}
+
+
+VideoLossConfig: LazyDict = L(TokenizerLoss)(config=VideoLoss())
diff --git a/cosmos_predict1/tokenizer/training/configs/base/metric.py b/cosmos_predict1/tokenizer/training/configs/base/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..daecc1cb8beb9bbc2c18ec1acb978ee3b5fdcdab
--- /dev/null
+++ b/cosmos_predict1/tokenizer/training/configs/base/metric.py
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Metric configurations for the tokenizer model.
+
+Supports PSNR and SSIM; these are validation-only metrics.
+"""
+import attrs
+
+from cosmos_predict1.tokenizer.training.metrics import CodeUsageMetric, PSNRMetric, SSIMMetric, TokenizerMetric
+from cosmos_predict1.utils.lazy_config import LazyCall as L
+from cosmos_predict1.utils.lazy_config import LazyDict
+
+
+@attrs.define(slots=False)
+class Metric:
+    # The reconstruction metrics, computed during validation.
+    PSNR: LazyDict = L(PSNRMetric)()
+    SSIM: LazyDict = L(SSIMMetric)()
+
+
+@attrs.define(slots=False)
+class DiscreteTokenizerMetric:
+    # with code usage (perplexity PPL), for discrete tokenizers only
+    PSNR: LazyDict = L(PSNRMetric)()
+    SSIM: LazyDict = L(SSIMMetric)()
+    CodeUsage: LazyDict = L(CodeUsageMetric)(codebook_size=64000)
+
+
+MetricConfig: LazyDict = L(TokenizerMetric)(config=Metric())
+
+DiscreteTokenizerMetricConfig: LazyDict = L(TokenizerMetric)(config=DiscreteTokenizerMetric())
diff --git a/cosmos_predict1/tokenizer/training/configs/base/mock_data.py b/cosmos_predict1/tokenizer/training/configs/base/mock_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e5a8efa694cdb00873b2c38d0c224be9109db3
--- /dev/null
+++ b/cosmos_predict1/tokenizer/training/configs/base/mock_data.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
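For reference, a generic PSNR computation (illustrative, independent of the PSNRMetric implementation imported in metric.py above): reconstructions in [-1, 1] have a peak-to-peak range of 2.0.

import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 2.0) -> torch.Tensor:
    # PSNR = 10 * log10(data_range^2 / MSE), averaged over the whole batch here.
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(data_range**2 / mse)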
+ +import torch +from torch.utils.data import DataLoader + +from cosmos_predict1.tokenizer.training.datasets.mock_dataset import CombinedDictDataset, LambdaDataset +from cosmos_predict1.tokenizer.training.datasets.utils import VIDEO_KEY, VIDEO_VAL_CROP_SIZE_INFO, get_crop_size_info +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +_IMAGE_ASPECT_RATIO = "1,1" +_VIDEO_ASPECT_RATIO = "16,9" + + +def get_video_dataset( + is_train: bool, + resolution: str, + crop_height: int, + num_video_frames: int, +): + if is_train: + crop_sizes = get_crop_size_info(crop_height) + log.info( + f"[video] training num_frames={num_video_frames}, crop_height={crop_height} and crop_sizes: {crop_sizes}." + ) + else: + if crop_height is None: + crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution] + else: + crop_sizes = get_crop_size_info(crop_height) + log.info(f"[video] validation num_frames={num_video_frames}, crop_sizes: {crop_sizes}") + + h = crop_sizes[_VIDEO_ASPECT_RATIO][1] + w = crop_sizes[_VIDEO_ASPECT_RATIO][0] + + def video_fn(): + return 2 * torch.rand(3, num_video_frames, h, w) - 1 + + return CombinedDictDataset( + **{ + VIDEO_KEY: LambdaDataset(video_fn), + } + ) + + +def get_mock_video_dataloader( + batch_size: int, is_train: bool = True, num_video_frames: int = 9, resolution: str = "720", crop_height: int = 128 +) -> LazyDict: + """A function to get mock video dataloader. + + Args: + batch_size: The batch size. + num_video_frames: The number of video frames. + resolution: The resolution. Defaults to "1024". + + Returns: + LazyDict: A LazyDict object specifying the video dataloader. + """ + if resolution not in VIDEO_VAL_CROP_SIZE_INFO: + resolution = "720" + return L(DataLoader)( + dataset=L(get_video_dataset)( + is_train=is_train, + resolution=resolution, + crop_height=crop_height, + num_video_frames=num_video_frames, + ), + batch_size=batch_size, + shuffle=False, + num_workers=8, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/base/model.py b/cosmos_predict1/tokenizer/training/configs/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2f859195d4b81253bfba583e1429ed6749b33e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/model.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
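A stand-alone analogue (illustrative only) of the mock video data defined above, without the LazyDict machinery: each sample is a 3 x T x H x W tensor in [-1, 1]; the spatial size below is an arbitrary 16:9 example.

import torch
from torch.utils.data import DataLoader, Dataset


class MockVideoDataset(Dataset):
    def __init__(self, num_frames: int = 9, height: int = 704, width: int = 1280, length: int = 8):
        self.num_frames, self.height, self.width, self.length = num_frames, height, width, length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> dict:
        return {"video": 2 * torch.rand(3, self.num_frames, self.height, self.width) - 1}


loader = DataLoader(MockVideoDataset(), batch_size=1, shuffle=False)
batch = next(iter(loader))  # batch["video"].shape == (1, 3, 9, 704, 1280)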
+ +import attrs + +from cosmos_predict1.tokenizer.training.model import TokenizerModel +from cosmos_predict1.utils.config import EMAConfig +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class ModelConfig: + network: LazyDict = None + loss: LazyDict = None + metric: LazyDict = None + ema: EMAConfig = EMAConfig(enabled=True, beta=0.9999) + precision: str = "bfloat16" + torch_compile: bool = False + disc: LazyDict = None + disc_optimizer: LazyDict = None + disc_scheduler: LazyDict = None + + +DefaultModelConfig: LazyDict = L(TokenizerModel)(config=ModelConfig()) diff --git a/cosmos_predict1/tokenizer/training/configs/base/net.py b/cosmos_predict1/tokenizer/training/configs/base/net.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0ae02ff7b34a67ce0c90324e80f6d6c6c5d153 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/net.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Net config options for cosmos/tokenizer + +ContinuousImageTokenizerConfig +DiscreteImageTokenizerConfig +CausalContinuousVideoTokenizerConfig + +""" + +from cosmos_predict1.tokenizer.modules import ( + ContinuousFormulation, + Decoder3DType, + DecoderType, + DiscreteQuantizer, + Encoder3DType, + EncoderType, +) +from cosmos_predict1.tokenizer.networks.continuous_image import ContinuousImageTokenizer +from cosmos_predict1.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer +from cosmos_predict1.tokenizer.networks.discrete_image import DiscreteImageTokenizer +from cosmos_predict1.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +ContinuousImageTokenizerConfig: LazyDict = L(ContinuousImageTokenizer)( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio, default 8. + spatial_compression=8, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The output latent dimension (channels). + latent_channels=16, + # The encoder output channels just before sampling. + # Which is also the decoder's input channels. + z_channels=16, + # A factor over the z_channels, to get the total channels the encoder should output. + # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. + # Since we are using AE formulation, we only need the mean, so z_factor=1. + z_factor=1, + name="ContinuousImageTokenizer", + # What formulation to use, either "AE" or "VAE". 
+ # Chose AE here, since this has been proven to be effective. + formulation=ContinuousFormulation.AE.name, + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +DiscreteImageTokenizerConfig: LazyDict = L(DiscreteImageTokenizer)( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before sampling. + z_channels=256, + # A factor over the z_channels, to get the total channels the encoder should output. + # for discrete tokenization, often we directly use the vector, so z_factor=1. + z_factor=1, + # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. Default FSQ. + quantizer=DiscreteQuantizer.FSQ.name, + # The embedding dimension post-quantization, which is also the input channels of the decoder. + # Which is also the output + embedding_dim=6, + # The number of levels to use for fine-scalar quantization. + levels=[8, 8, 8, 5, 5, 5], + persistent_quantizer=False, + # The number of quantizers to use for residual fine-scalar quantization. + num_quantizers=4, + name="DiscreteImageTokenizer", + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +CausalContinuousFactorizedVideoTokenizerConfig: LazyDict = L(CausalContinuousVideoTokenizer)( + # The new causal continuous tokenizer, that is at least 2x more efficient in memory and runtime. + # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Adopts an AE formulation + # - Strictly causal, with flexible temporal length at inference. + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + latent_channels=16, + z_channels=16, + z_factor=1, + num_groups=1, + # Most of the CV and DV tokenizers trained before September 1, 2024, + # used temporal upsampling that was not perfectly mirrored with the + # # encoder's temporal downsampling. Moving forward, new CV/DV tokenizers + # will use legacy_mode=False, meaning they will adopt mirrored upsampling. + legacy_mode=False, + spatial_compression=8, + temporal_compression=8, + formulation=ContinuousFormulation.AE.name, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CausalContinuousFactorizedVideoTokenizer", +) + +CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)( + # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime. + # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Strictly causal, with flexible temporal length at inference. 
+ attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before quantization is changed to 256 + # from 16 (old versions). It aligns with the DI that uses 256 channels, + # making initialization from image tokenizers easier. + z_channels=256, + z_factor=1, + num_groups=1, + # Most of the CV and DV tokenizers trained before September 1, 2024, + # used temporal upsampling that was not perfectly mirrored with the + # # encoder's temporal downsampling. Moving forward, new CV/DV tokenizers + # will use legacy_mode=False, meaning they will adopt mirrored upsampling. + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + quantizer=DiscreteQuantizer.FSQ.name, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + persistent_quantizer=False, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CausalDiscreteFactorizedVideoTokenizer", +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/optim.py b/cosmos_predict1/tokenizer/training/configs/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb3b7eef0c7d9c973806fb3cfc935a409c27fd9 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/optim.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
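Rough latent-shape arithmetic implied by the compression settings above, assuming the causal convention of keeping the first frame and compressing the remaining (T - 1) frames by the temporal factor (an assumption stated here, not spelled out in the configs):

def latent_size(frames: int, height: int, width: int, temporal: int = 8, spatial: int = 8) -> tuple:
    # e.g. a 49 x 720 x 1280 clip with 8x8x8 compression -> 7 latent frames of 90 x 160.
    assert (frames - 1) % temporal == 0 and height % spatial == 0 and width % spatial == 0
    return 1 + (frames - 1) // temporal, height // spatial, width // spatial


print(latent_size(49, 720, 1280))                          # CV8x8x8 -> (7, 90, 160)
print(latent_size(49, 720, 1280, temporal=8, spatial=16))  # DV8x16x16 -> (7, 45, 80)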
+ +"""optimizer config options: + +fused_adam - FusedAdamConfig +adamw - AdamWConfig +""" + +import torch + +from cosmos_predict1.utils import fused_adam +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.scheduler import WarmupCosineLR, WarmupLambdaLR + +FusedAdamConfig: LazyDict = L(fused_adam.FusedAdam)( + capturable=True, + master_weights=True, + adam_w_mode=True, + params=PLACEHOLDER, + lr=1e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +AdamWConfig: LazyDict = L(torch.optim.AdamW)( + params=PLACEHOLDER, + lr=1e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +WarmupLRConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) + +FusedAdamDiscConfig: LazyDict = L(fused_adam.FusedAdam)( + capturable=True, + master_weights=True, + adam_w_mode=True, + params=PLACEHOLDER, + lr=4e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +WarmupLRDiscConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) + +WarmupCosineLRConfig: LazyDict = L(WarmupCosineLR)( + optimizer=PLACEHOLDER, warmup_iters=5000, lr_decay_iters=1000000, min_lr=1e-8 +) diff --git a/cosmos_predict1/tokenizer/training/configs/config.py b/cosmos_predict1/tokenizer/training/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..afb94f8813da7d78f48a79a98ef362eed681c5ef --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/config.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
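The WarmupLambdaLR imported above is not shown in this diff, so as a generic illustration of the warmup=5000 idea, here is a linear warmup built on PyTorch's LambdaLR with the same AdamW hyperparameters as AdamWConfig; the ramp shape is an assumption, not the project's exact schedule.

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
optimizer = torch.optim.AdamW(params, lr=1e-4, betas=(0.5, 0.999), eps=1e-8, weight_decay=0.01)

warmup = 5000
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup)
)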
+ +"""Default config for cosmos/tokenizer project.""" + +from typing import Any, List + +import attrs + +from cosmos_predict1.tokenizer.training.configs.base.model import DefaultModelConfig +from cosmos_predict1.tokenizer.training.configs.registry import register_configs +from cosmos_predict1.tokenizer.training.trainer import TokenizerTrainer +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": "mock_video720"}, + {"data_val": "mock_video720"}, + {"optimizer": "fused_adam"}, + {"scheduler": "warmup"}, + {"network": "continuous_factorized_video"}, + {"loss": "video"}, + {"metric": "reconstruction"}, + {"checkpoint": "local"}, + {"callbacks": "basic"}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig, + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + checkpoint=None, + ) + c.job.project = "posttraining" + c.job.group = "debug" + c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = TokenizerTrainer + c.trainer.run_validation = True + + c.trainer.seed = 1234 + c.trainer.max_iter = 10_000_000 + c.trainer.validation_iter = 5000 + c.trainer.max_val_iter = 1 + c.trainer.logging_iter = 100 + + c.trainer.callbacks = None + c.trainer.ddp.static_graph = True + c.trainer.ddp.find_unused_parameters = False + register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.tokenizer.training.configs.experiments") + + return c diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/__init__.py b/cosmos_predict1/tokenizer/training/configs/experiments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/basic.py b/cosmos_predict1/tokenizer/training/configs/experiments/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..9201f00af0de69e570cbcb2588f468125e2b44ba --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/basic.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
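The defaults list in make_config() above maps directly to Hydra-style command-line overrides for the groups registered in registry.py. The commands below are only a sketch: the entry-point module name is an assumption, while the group values (experiment, network, data_train, data_val, metric) are the ones registered in this diff, including the "video_basic" experiment defined just below.

# python -m cosmos_predict1.tokenizer.training.train experiment=video_basic
# python -m cosmos_predict1.tokenizer.training.train \
#     network=discrete_factorized_video data_train=hdvila_video720 \
#     data_val=hdvila_video720 metric=code_usage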
+ +"""Config settings for cosmos/tokenizer (basic image and video setting)""" + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.utils.lazy_config import LazyDict + +CAUSAL_VIDEO_BASIC: LazyDict = LazyDict( + dict( + defaults=[ + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "mock_video720"}, + {"override /data_val": "mock_video720"}, + {"override /loss": "video"}, + {"override /optimizer": "fused_adam"}, + {"override /callbacks": ["basic"]}, + "_self_", + ], + model=dict( + config=dict( + loss=dict( + config=dict( + perceptual=dict( + config=dict( + lpips_boundaries=[0], + lpips_values=[0.1], + gram_enabled=False, + gram_boundaries=[0], + ) + ), + video_consistency=dict( + config=dict( + enabled=False, + boundaries=[0], + values=[1.0], + num_frames=32, + step=8, + ) + ), + flow=dict( + config=dict( + enabled=False, + boundaries=[1_000_000], + values=[0.0, 0.01], + scale=2, + dtype="bfloat16", + checkpoint_activations=False, + ) + ), + ) + ) + ) + ), + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=720, + num_video_frames=49, + ), + batch_size=1, + ), + job=dict( + project="posttraining", + group="tokenizer", + name="basic_${now:%Y-%m-%d}_${now:%H-%M-%S}", + ), + checkpoint=dict(load_path=None, jit=dict(input_shape=[1, 3, 17, 512, 512])), + ) +) + +cs = ConfigStore.instance() +cs.store(group="experiment", package="_global_", name="video_basic", node=CAUSAL_VIDEO_BASIC) diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py b/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py new file mode 100644 index 0000000000000000000000000000000000000000..42c1d4f27c2cb7806d0482287f0909d49342ca9b --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.tokenizer.training.configs.experiments.utils import create_debug_job_with_mock_data +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyDict + +# Post-training config for Cosmos-Tokenize1-CV8x8x8-720p-HDVILA +Cosmos_Tokenize1_CV8x8x8_720p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "hdvila_video720"}, + {"override /data_val": "hdvila_video720"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=121, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=121, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + channels_mult=[2, 4, 4], + patch_size=4, + legacy_mode=False, + temporal_compression=8, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-CV8x8x8-720p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-DV8x16x16-720p-HDVILA +Cosmos_Tokenize1_DV8x16x16_720p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "discrete_factorized_video"}, + {"override /data_train": "hdvila_video720"}, + {"override /data_val": "hdvila_video720"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + persistent_quantizer=False, + z_channels=16, + channels_mult=[2, 4, 4], + patch_size=4, + legacy_mode=False, + temporal_compression=8, + spatial_compression=16, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-DV8x16x16-720p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-CV4x8x8-360p-HDVILA +Cosmos_Tokenize1_CV4x8x8_360p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "hdvila_video360"}, + {"override /data_val": "hdvila_video360"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + channels_mult=[2, 4, 4], + patch_size=2, + legacy_mode=False, + temporal_compression=4, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-CV4x8x8-360p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-CV4x8x8-360p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-DV4x8x8-360p-HDVILA +Cosmos_Tokenize1_DV4x8x8_360p_HDVILA: LazyDict = LazyDict( + dict( + 
defaults=[ + "/experiment/video_basic", + {"override /network": "discrete_factorized_video"}, + {"override /data_train": "hdvila_video360"}, + {"override /data_val": "hdvila_video360"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + persistent_quantizer=False, + z_channels=256, + channels_mult=[2, 4, 4], + patch_size=2, + legacy_mode=False, + temporal_compression=4, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-DV4x8x8-360p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-DV4x8x8-360p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +cs = ConfigStore.instance() + +for _item in [ + Cosmos_Tokenize1_CV8x8x8_720p_HDVILA, + Cosmos_Tokenize1_DV8x16x16_720p_HDVILA, + Cosmos_Tokenize1_CV4x8x8_360p_HDVILA, + Cosmos_Tokenize1_DV4x8x8_360p_HDVILA, +]: + experiment_name = [name for name, value in globals().items() if value is _item][0] + + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) + + mock_experiment = f"mock_{experiment_name}" + log.info(f"Registering mock experiment: {mock_experiment}") + _debug_item = create_debug_job_with_mock_data(_item["job"]["name"]) + cs.store( + group="experiment", + package="_global_", + name=mock_experiment, + node=_debug_item, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/utils.py b/cosmos_predict1/tokenizer/training/configs/experiments/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d68db0312c7abe891b99d55683a776ba6a1e43e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/utils.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
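A usage sketch for the helper defined just below, matching how the registration loop in cosmos_tokenize1.py above calls it: a released experiment is wrapped into a two-iteration, mock-data smoke test.

debug_node = create_debug_job_with_mock_data("Cosmos-Tokenize1-CV8x8x8-720p-HDVILA")
# The wrapped config inherits from
#   /experiment/Cosmos_Tokenize1_CV8x8x8_720p_HDVILA
# but swaps in mock_video360 data and runs for only two iterations.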
+ +"""registry for commandline override options for config.""" +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_debug_job_with_mock_data(full_experiment_name): + job_dict = dict( + defaults=[ + f"/experiment/{full_experiment_name.replace('-', '_')}", + {"override /data_train": "mock_video360"}, + {"override /data_val": "mock_video360"}, + "_self_", + ], + job=dict(group="debug", name=f"mock_{full_experiment_name}" + "_${now:%Y-%m-%d}_${now:%H-%M-%S}"), + trainer=dict( + max_iter=2, + logging_iter=1, + max_val_iter=1, + validation_iter=2, + ), + checkpoint=dict( + strict_resume=False, + load_training_state=False, + save_iter=2, + ), + ) + return LazyDict(job_dict) diff --git a/cosmos_predict1/tokenizer/training/configs/registry.py b/cosmos_predict1/tokenizer/training/configs/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cda4109b0b6edbc0cbfa29f9af676f359dc90f1f --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/registry.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""registry for commandline override options for config.""" +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.tokenizer.training.configs.base.callback import BASIC_CALLBACKS +from cosmos_predict1.tokenizer.training.configs.base.checkpoint import CHECKPOINT_LOCAL +from cosmos_predict1.tokenizer.training.configs.base.data import DATALOADER_OPTIONS +from cosmos_predict1.tokenizer.training.configs.base.loss import VideoLossConfig +from cosmos_predict1.tokenizer.training.configs.base.metric import DiscreteTokenizerMetricConfig, MetricConfig +from cosmos_predict1.tokenizer.training.configs.base.net import ( + CausalContinuousFactorizedVideoTokenizerConfig, + CausalDiscreteFactorizedVideoTokenizerConfig, + ContinuousImageTokenizerConfig, + DiscreteImageTokenizerConfig, +) +from cosmos_predict1.tokenizer.training.configs.base.optim import ( + AdamWConfig, + FusedAdamConfig, + WarmupCosineLRConfig, + WarmupLRConfig, +) + + +def register_training_data(cs): + for data_source in ["mock", "hdvila"]: + for resolution in ["1080", "720", "480", "360", "256"]: + cs.store( + group="data_train", + package="dataloader_train", + name=f"{data_source}_video{resolution}", # `davis_video720` + node=DATALOADER_OPTIONS["video_loader_basic"]( + dataset_name=f"{data_source}_video", + is_train=True, + resolution=resolution, + ), + ) + + +def register_val_data(cs): + for data_source in ["mock", "hdvila"]: + for resolution in ["1080", "720", "480", "360", "256"]: + cs.store( + group="data_val", + package="dataloader_val", + name=f"{data_source}_video{resolution}", # `davis_video720` + node=DATALOADER_OPTIONS["video_loader_basic"]( + dataset_name=f"{data_source}_video", + is_train=False, + resolution=resolution, + ), + ) + + +def register_net(cs): + cs.store( + group="network", package="model.config.network", 
name="continuous_image", node=ContinuousImageTokenizerConfig + ) + cs.store(group="network", package="model.config.network", name="discrete_image", node=DiscreteImageTokenizerConfig) + + cs.store( + group="network", + package="model.config.network", + name="continuous_factorized_video", + node=CausalContinuousFactorizedVideoTokenizerConfig, + ) + cs.store( + group="network", + package="model.config.network", + name="discrete_factorized_video", + node=CausalDiscreteFactorizedVideoTokenizerConfig, + ) + + +def register_optim(cs): + cs.store(group="optimizer", package="optimizer", name="fused_adam", node=FusedAdamConfig) + cs.store(group="optimizer", package="optimizer", name="adamw", node=AdamWConfig) + + +def register_scheduler(cs): + cs.store(group="scheduler", package="scheduler", name="warmup", node=WarmupLRConfig) + cs.store( + group="scheduler", + package="scheduler", + name="warmup_cosine", + node=WarmupCosineLRConfig, + ) + + +def register_loss(cs): + cs.store(group="loss", package="model.config.loss", name="video", node=VideoLossConfig) + + +def register_metric(cs): + cs.store(group="metric", package="model.config.metric", name="reconstruction", node=MetricConfig) + cs.store(group="metric", package="model.config.metric", name="code_usage", node=DiscreteTokenizerMetricConfig) + + +def register_checkpoint(cs): + cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) + + +def register_callback(cs): + cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) + + +def register_configs(): + cs = ConfigStore.instance() + + register_training_data(cs) + register_val_data(cs) + + register_net(cs) + + register_optim(cs) + register_scheduler(cs) + register_loss(cs) + register_metric(cs) + register_checkpoint(cs) + + register_callback(cs) diff --git a/cosmos_predict1/tokenizer/training/datasets/__init__.py b/cosmos_predict1/tokenizer/training/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py b/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..02166bf140754a4b05a2b883cffea9d5d98e9872 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
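The augmentor dicts built in the next file are consumed roughly as below. This is only a sketch: the actual wiring lives in the dataset code, and the LazyCall nodes must first be instantiated into augmentor objects.

def apply_augmentations(data_dict: dict, augmentors: dict) -> dict:
    # Each augmentor takes the sample dict and returns it (possibly modified),
    # matching the Augmentor.__call__ contract defined later in this diff.
    for augmentor in augmentors.values():
        data_dict = augmentor(data_dict)
    return data_dict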
+ +"""Augmentations for tokenizer training (image and video)""" + + +from cosmos_predict1.tokenizer.training.datasets.augmentors import ( + CenterCrop, + CropResizeAugmentor, + HorizontalFlip, + Normalize, + RandomReverse, + ReflectionPadding, + ResizeSmallestSideAspectPreserving, + UnsqueezeImage, +) +from cosmos_predict1.tokenizer.training.datasets.utils import ( + VIDEO_KEY, + VIDEO_RES_SIZE_INFO, + VIDEO_VAL_CROP_SIZE_INFO, + get_crop_size_info, +) +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall, LazyDict + +_PROB_OF_CROP_ONLY: float = 0.1 + + +def video_train_augmentations( + input_keys: list[str], + resolution: str = "1080", + crop_height: int = 256, +) -> dict[str, LazyDict]: + [_video_key] = input_keys + crop_sizes = get_crop_size_info(crop_height) + log.info(f"[video] training crop_height={crop_height} and crop_sizes: {crop_sizes}.") + augmentations = { + "crop_resize": LazyCall(CropResizeAugmentor)( + input_keys=[_video_key], + output_keys=[VIDEO_KEY], + crop_args={"size": crop_sizes}, + resize_args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + args={"prob": _PROB_OF_CROP_ONLY}, + ), + "random_reverse": LazyCall(RandomReverse)( + input_keys=[VIDEO_KEY], + args={"prob": 0.5}, + ), + "reflection_padding": LazyCall(ReflectionPadding)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "horizontal_flip": LazyCall(HorizontalFlip)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "normalize": LazyCall(Normalize)( + input_keys=[VIDEO_KEY], + args={"mean": 0.5, "std": 0.5}, + ), + "unsqueeze_padding": LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]), + } + + return augmentations + + +def video_val_augmentations( + input_keys: list[str], resolution: str = "1080", crop_height: int = None +) -> dict[str, LazyDict]: + [_video_key] = input_keys + if crop_height is None: + crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution] + else: + crop_sizes = get_crop_size_info(crop_height) + + log.info(f"[video] validation crop_sizes: {crop_sizes}.") + augmenations = { + "resize_smallest_side_aspect_ratio_preserving": LazyCall(ResizeSmallestSideAspectPreserving)( + input_keys=[VIDEO_KEY], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "center_crop": LazyCall(CenterCrop)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "reflection_padding": LazyCall(ReflectionPadding)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "normalize": LazyCall(Normalize)( + input_keys=[VIDEO_KEY], + args={"mean": 0.5, "std": 0.5}, + ), + "unsqueeze_padding": LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]), + } + return augmenations diff --git a/cosmos_predict1/tokenizer/training/datasets/augmentors.py b/cosmos_predict1/tokenizer/training/datasets/augmentors.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ed4bea68655b56175324746add2ea3a8fd6e2b --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/augmentors.py @@ -0,0 +1,540 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Additional augmentors for image and video training loops.""" + +from typing import Any, Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.training.datasets.utils import obtain_augmentation_size, obtain_image_size +from cosmos_predict1.utils import log + + +class Augmentor: + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + r"""Base augmentor class + + Args: + input_keys (list): List of input keys + output_keys (list): List of output keys + args (dict): Arguments associated with the augmentation + """ + self.input_keys = input_keys + self.output_keys = output_keys + self.args = args + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise ValueError("Augmentor not implemented") + + +class LossMask(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + mask_config = self.args["masking"] + + input_key = self.input_keys[0] + default_mask = torch.ones_like(data_dict[input_key]) + loss_mask = mask_config["nonhuman_mask"] * default_mask + for curr_key in mask_config: + if curr_key not in self.input_keys: + continue + curr_mask = data_dict[curr_key] + curr_weight = mask_config[curr_key] + curr_loss_mask = curr_mask * curr_weight + (1 - curr_mask) * loss_mask + loss_mask = torch.max(curr_loss_mask, loss_mask) + _ = data_dict.pop(curr_key) + data_dict["loss_mask"] = loss_mask + return data_dict + + +class UnsqueezeImage(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + for key in self.input_keys: + data_dict[key] = data_dict[key].unsqueeze(1) + + return data_dict + + +class RandomReverse(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random temporal reversing of frames. + + Args: + data_dict (dict): Input data dict, CxTxHxW + Returns: + data_dict (dict): Output dict where videos are randomly reversed. 
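+
+        Note: a single coin flip with probability `prob` (default 0.5) is drawn per
+        sample and applied to every key in `input_keys`, so all listed tensors are
+        reversed together along the time dimension (dim=1 for CxTxHxW inputs).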
+ """ + assert self.args is not None + p = self.args.get("prob", 0.5) + coin_flip = torch.rand(1).item() <= p + for key in self.input_keys: + if coin_flip: + data_dict[key] = torch.flip(data_dict[key], dims=[1]) + + return data_dict + + +class RenameInputKeys(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Rename the input keys from the data dict. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict with keys renamed. + """ + assert len(self.input_keys) == len(self.output_keys) + for input_key, output_key in zip(self.input_keys, self.output_keys): + if input_key in data_dict: + data_dict[output_key] = data_dict.pop(input_key) + return data_dict + + +class CropResizeAugmentor(Augmentor): + def __init__( + self, + input_keys: list, + output_keys: Optional[list] = None, + crop_args: Optional[dict] = None, + resize_args: Optional[dict] = None, + args: Optional[dict] = None, + ) -> None: + super().__init__(input_keys, output_keys, args) + self.crop_args = crop_args + self.resize_args = resize_args + self.crop_op = RandomCrop(input_keys, output_keys, crop_args) + self.resize_op = ResizeSmallestSideAspectPreserving(input_keys, output_keys, resize_args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random temporal reversing of frames. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where videso are randomly reversed. + """ + assert self.args is not None + p = self.args.get("prob", 0.1) + + if p > 0.0: + crop_img_size = obtain_augmentation_size(data_dict, self.crop_args) + crop_width, crop_height = crop_img_size + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + if orig_w < crop_width or orig_h < crop_height: + log.warning( + f"Data size ({orig_w}, {orig_h}) is smaller than crop size ({crop_width}, {crop_height}), skip the crop augmentation." + ) + coin_flip = torch.rand(1).item() <= p + if coin_flip and crop_width <= orig_w and crop_height <= orig_h: + data_dict = self.crop_op(data_dict) + return data_dict + + data_dict = self.resize_op(data_dict) + data_dict = self.crop_op(data_dict) + + return data_dict + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert (self.args is not None) and ("size" in self.args), "Please specify size in args" + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + + # We also add the aug params we use. 
This will be useful for other transforms + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) + return data_dict + + +class RandomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + # Obtaining random crop coords + try: + crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item()) + crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item()) + except Exception as e: + logging.warning( + f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}" + ) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + + # We also add the aug params we use. This will be useful for other transforms + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + + # We must perform same random cropping for all input keys + for key in self.input_keys: + data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width) + return data_dict + + +class HorizontalFlip(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + flip_enabled = getattr(self.args, "enabled", True) + if flip_enabled: + p = getattr(self.args, "prob", 0.5) + coin_flip = torch.rand(1).item() > p + for key in self.input_keys: + if coin_flip: + data_dict[key] = transforms_F.hflip(data_dict[key]) + + return data_dict + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. 
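+
+        Note: tensor inputs are first scaled to [0, 1] (division by 255) and then
+        normalized with the configured mean and std; with mean=0.5 and std=0.5, as
+        used in the augmentation providers above, values end up in [-1, 1].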
+ """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # Calculate padding vals + padding_left = int((target_w - orig_w) / 2) + padding_right = target_w - orig_w - padding_left + padding_top = int((target_h - orig_h) / 2) + padding_bottom = target_h - orig_h - padding_top + padding_vals = [padding_left, padding_top, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + # Return padding_mask when padding is performed. + # Padding mask denotes which pixels are padded. 
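+        # Convention: 1 marks padded pixels, 0 marks pixels coming from the original frame.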
+ padding_mask = torch.ones((1, target_h, target_w)) + padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 + data_dict["padding_mask"] = padding_mask + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. 
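+        For example, a 1280x720 (WxH) input with a (960, 720) target gives ratios
+        1280/960 ≈ 1.33 and 720/720 = 1.0; the height ratio is smaller, so the
+        scaling factor is 1.0 and the frame stays at 1280x720.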
+ + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] <= img_h and target_size[1] <= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py b/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..9904bd1eba5d77d3baae9b5ff7b5b6acabba5dae --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of dataset settings and augmentations for tokenization + +Run this command to interactively debug: +python3 -m cosmos_predict1.tokenizer.training.datasets.dataset_provider + +""" + +from cosmos_predict1.tokenizer.training.datasets.augmentation_provider import ( + video_train_augmentations, + video_val_augmentations, +) +from cosmos_predict1.tokenizer.training.datasets.utils import categorize_aspect_and_store +from cosmos_predict1.tokenizer.training.datasets.video_dataset import Dataset +from cosmos_predict1.utils.lazy_config import instantiate + +_VIDEO_PATTERN_DICT = { + "hdvila_video": "datasets/hdvila/videos/*.mp4", +} + + +def apply_augmentations(data_dict, augmentations_dict): + """ + Loop over each LazyCall object and apply it to data_dict in place. + """ + for aug_name, lazy_aug in augmentations_dict.items(): + aug_instance = instantiate(lazy_aug) + data_dict = aug_instance(data_dict) + return data_dict + + +class AugmentDataset(Dataset): + def __init__(self, base_dataset, augmentations_dict): + """ + base_dataset: the video dataset instance + augmentations_dict: the dictionary returned by + video_train_augmentations() or video_val_augmentations() + """ + self.base_dataset = base_dataset + + # Pre-instantiate every augmentation ONCE: + self.augmentations = [] + for aug_name, lazy_aug in augmentations_dict.items(): + aug_instance = instantiate(lazy_aug) # build the actual augmentation + self.augmentations.append((aug_name, aug_instance)) + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, index): + # Get the raw sample from the base dataset + data = self.base_dataset[index] + data = categorize_aspect_and_store(data) + + # Apply each pre-instantiated augmentation + for aug_name, aug_instance in self.augmentations: + data = aug_instance(data) + + return data + + +def dataset_entry( + dataset_name: str, + dataset_type: str, + is_train: bool = True, + resolution="720", + crop_height=256, + num_video_frames=25, +) -> AugmentDataset: + if dataset_type != "video": + raise ValueError(f"Dataset type {dataset_type} is not supported") + + # Instantiate the video dataset + base_dataset = Dataset( + video_pattern=_VIDEO_PATTERN_DICT[dataset_name.lower()], + num_video_frames=num_video_frames, + ) + + # Pick the training or validation augmentations + if is_train: + aug_dict = video_train_augmentations( + input_keys=["video"], # adjust if necessary + resolution=resolution, + crop_height=crop_height, + ) + else: + aug_dict = video_val_augmentations( + input_keys=["video"], + resolution=resolution, + crop_height=crop_height, + ) + + # Wrap the dataset with the augmentations + return AugmentDataset(base_dataset, aug_dict) + + +if __name__ == "__main__": + # Example usage / quick test + dataset = dataset_entry( + dataset_name="davis_video", + dataset_type="video", + is_train=False, + resolution="720", + crop_height=256, + num_video_frames=25, + ) + + # 2) Print out some 
basic info: + print(f"Total samples in dataset: {len(dataset)}") + + # 3) Grab one sample (or a few) to check shapes, keys, etc. + if len(dataset) > 0: + sample_idx = 0 + sample = dataset[sample_idx] + print(f"Sample index {sample_idx} keys: {list(sample.keys())}") + if "video" in sample: + print("Video shape:", sample["video"].shape) + if "video_name" in sample: + print("Video metadata:", sample["video_name"]) + print("---\nSample loaded successfully.\n") + else: + print("Dataset has no samples!") diff --git a/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py b/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1e1d5e67b647e1a628c4a2f891b5d95ad07980 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Copied from jam_data. +""" + +import inspect +from typing import Any, Callable, Dict + +import torch +from torch.utils.data import Dataset + +MAX_LENGTH = 1 << 15 + + +class LambdaDataset(torch.utils.data.Dataset): + """ + A dataset that generates items by applying a function. This allows for creating + dynamic datasets where the items are the result of function calls. The function can optionally + accept an index argument. + + Attributes: + length (int): The total number of items in the dataset. + fn (Callable): The function to generate dataset items. + is_index_in_params (bool): Flag to determine whether 'index' should be passed + to the function `fn`. + """ + + def __init__(self, fn: Callable, length: int = MAX_LENGTH) -> None: + """ + Initializes the LambdaDataset with a function and the total length. + + Args: + fn (Callable): A function that returns a dataset item. It can optionally accept an + index argument to generate data items based on their index. + length (int): The total number of items in the dataset, defaults to MAX_LENGTH. + """ + self.length = length + self.fn = fn + + try: + # Attempt to inspect the function signature to determine if it accepts an 'index' parameter. + signature = inspect.signature(fn) + self.is_index_in_params = "index" in signature.parameters + except ValueError: + # If the function signature is not inspectable, assume 'index' is not a parameter. + self.is_index_in_params = False + + def __len__(self) -> int: + """ + Returns the total length of the dataset. + + Returns: + int: The number of items in the dataset. + """ + return self.length + + def __getitem__(self, index: int) -> Any: + """ + Retrieves an item at a specific index from the dataset by calling the function `fn`. + Passes the index to `fn` if `fn` is designed to accept an index. + + Args: + index (int): The index of the item to retrieve. + + Returns: + Any: The item returned by the function `fn`. 
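+
+        Example (illustrative):
+            >>> ds = LambdaDataset(lambda index: {"value": index}, length=4)
+            >>> ds[2]
+            {'value': 2}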
+ """ + if self.is_index_in_params: + return self.fn(index) # Call fn with index if it accepts an index parameter. + return self.fn() # Call fn without any parameters if it does not accept the index. + + +class RepeatDataset(torch.utils.data.Dataset): + """ + A dataset wrapper that allows repeating access to items from an underlying dataset. + + This dataset can be used to create an artificial extension of the underlying dataset + to a specified `length`. Each item from the original dataset can be accessed + repeatedly up to `num_item` times before it loops back. + + Attributes: + length (int): The total length of the dataset to be exposed. + dataset (Dataset): The original dataset. + num_item (int): Number of times each item is repeated. + cache_item (dict): Cache to store accessed items to avoid recomputation. + """ + + def __init__(self, dataset: Dataset, length: int = MAX_LENGTH, num_item: int = 1) -> None: + """ + Initializes the RepeatDataset with a dataset, length, and number of repeats per item. + + Args: + dataset (Dataset): The dataset to repeat. + length (int): The total length of the dataset to be exposed. Defaults to MAX_LENGTH. + num_item (int): The number of times to repeat each item. Defaults to 1. + """ + self.length = length + self.dataset = dataset + self.num_item = num_item + self.cache_item = {} + + def __len__(self) -> int: + return self.length + + def __getitem__(self, index: int) -> Any: + index = index % self.num_item + if index not in self.cache_item: + self.cache_item[index] = self.dataset[index] + return self.cache_item[index] + + +class CombinedDictDataset(torch.utils.data.Dataset): + """ + A dataset that wraps multiple PyTorch datasets and returns a dictionary of data items from each dataset for a given index. + This dataset ensures that all constituent datasets have the same length by setting the length to the minimum length of the datasets provided. + + Parameters: + ----------- + **datasets : Dict[str, Dataset] + A dictionary where keys are string identifiers for the datasets and values are the datasets instances themselves. + + Attributes: + ----------- + datasets : Dict[str, Dataset] + Stores the input datasets. + max_length : int + The minimum length among all provided datasets, determining the length of this combined dataset. + + Examples: + --------- + >>> dataset1 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32)) + >>> dataset2 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32)) + >>> combined_dataset = CombinedDictDataset(dataset1=dataset1, dataset2=dataset2) + >>> print(len(combined_dataset)) + 100 + >>> data = combined_dataset[50] + >>> print(data.keys()) + dict_keys(['dataset1', 'dataset2']) + """ + + def __init__(self, **datasets: Dict[str, Dataset]) -> None: + """ + Initializes the CombinedDictDataset with multiple datasets. + + Args: + **datasets (Dict[str, Dataset]): Key-value pairs where keys are dataset names and values + are dataset instances. Each key-value pair adds a dataset + under the specified key. + """ + self.datasets = datasets + self.max_length = min([len(dataset) for dataset in datasets.values()]) + + def __len__(self) -> int: + return self.max_length + + def __getitem__(self, index: int) -> Dict[str, Any]: + """ + Retrieves an item from each dataset at the specified index, combines them into a dictionary, + and returns the dictionary. Each key in the dictionary corresponds to one of the dataset names provided + during initialization, and its value is the item from that dataset at the given index. 
+ + Args: + index (int): The index of the items to retrieve across all datasets. + + Returns: + Dict[str, Any]: A dictionary containing data items from all datasets for the given index. + Each key corresponds to a dataset name, and its value is the data item from that dataset. + """ + data = {} + for key, dataset in self.datasets.items(): + data[key] = dataset[index] + return data diff --git a/cosmos_predict1/tokenizer/training/datasets/utils.py b/cosmos_predict1/tokenizer/training/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca7c109e778d0329fa9e76b1e8b74b39b4a6972 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/utils.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for datasets creation.""" + +IMAGE_KEY = "images" +VIDEO_KEY = "video" +RECON_KEY = "reconstructions" +LATENT_KEY = "latent" +INPUT_KEY = "INPUT" +MASK_KEY = "loss_mask" + +_SPATIAL_ALIGN = 16 + + +import math +from typing import Union + +import torch +from PIL import Image + +# This is your "for short_side=720" map: +_ASPECT_SIZE_DICT = { + "1,1": (720, 720), + "4,3": (960, 720), + "3,4": (720, 960), + "16,9": (1280, 720), + "9,16": (720, 1280), +} + + +VIDEO_RES_SIZE_INFO: dict[str, tuple[int, int]] = { + "1080": { # 1080p doesn't have 1:1 + "4,3": (1440, 1072), + "3,4": (1072, 1440), + "16,9": (1920, 1072), + "9,16": (1072, 1920), + }, + "720": {"1,1": (720, 720), "4,3": (960, 720), "3,4": (720, 960), "16,9": (1280, 720), "9,16": (720, 1280)}, + "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (854, 480), "9,16": (480, 854)}, + "512": {"1,1": (512, 512), "4,3": (672, 512), "3,4": (512, 672), "16,9": (896, 512), "9,16": (512, 896)}, + "360": {"1,1": (320, 320), "4,3": (416, 320), "3,4": (320, 416), "16,9": (544, 320), "9,16": (320, 544)}, + "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)}, + "128": { # Note that we set res lower than 256 to the same resolution as 256 + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (448, 256), + "9,16": (256, 448), + }, +} + +VIDEO_VAL_CROP_SIZE_INFO: dict[str, tuple[int, int]] = { + "1080": { # 1080p doesn't have 1:1 + "4,3": (1424, 1072), + "3,4": (1072, 1424), + "16,9": (1904, 1072), + "9,16": (1072, 1904), + "16,10": (1715, 1072), + }, + "720": {"1,1": (704, 704), "4,3": (944, 704), "3,4": (704, 944), "16,9": (1264, 704), "9,16": (704, 1264)}, + "480": {"1,1": (464, 464), "4,3": (624, 464), "3,4": (464, 624), "16,9": (848, 464), "9,16": (464, 848)}, + "360": {"1,1": (320, 320), "4,3": (416, 320), "3,4": (320, 416), "16,9": (544, 320), "9,16": (320, 544)}, + "512": {"1,1": (512, 512), "4,3": (672, 512), "3,4": (512, 672), "16,9": (896, 512), "9,16": (512, 896)}, + "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 
320)}, + "128": { # Note that we set res lower than 256 to the same resolution as 256 + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (320, 192), + "9,16": (192, 320), + "16,10": (410, 256), + }, +} + + +def _pick_closest_aspect_ratio(height, width): + """ + Given a video's height and width, return the closest aspect ratio key + from aspect_dict. + """ + if height == 0: + return "1,1" # fallback if something weird, to avoid div by zero + + actual_ratio = width / height + + best_key = None + min_diff = math.inf + + for ratio_key, (w_target, h_target) in _ASPECT_SIZE_DICT.items(): + # for "16,9" -> (1280, 720), ratio is 1280/720 = 1.7777... + ratio = w_target / h_target + diff = abs(actual_ratio - ratio) + if diff < min_diff: + min_diff = diff + best_key = ratio_key + + return best_key + + +def categorize_aspect_and_store(data_sample): + """ + data_sample: a dict with 'video' shaped [C,T,H,W]. + We will determine the aspect ratio, pick the closest "1,1", "4,3", etc., + and store a new dict entry. + """ + # Suppose 'video' is [C, T, H, W]. + video_tensor = data_sample["video"] + H = video_tensor.shape[-2] + W = video_tensor.shape[-1] + data_sample["aspect_ratio"] = _pick_closest_aspect_ratio(H, W) + return data_sample + + +def get_crop_size_info(crop_sz: int = 128): + aspect_ratios = [(1, 1), (4, 3), (3, 4), (16, 9), (9, 16)] + crop_sizes = dict() + for aspect_ratio in aspect_ratios: + if aspect_ratio[0] < aspect_ratio[1]: + crop_h = crop_sz // _SPATIAL_ALIGN * _SPATIAL_ALIGN + crop_w = int(crop_h * aspect_ratio[0] / aspect_ratio[1] + 0.5) + crop_w = crop_w // _SPATIAL_ALIGN * _SPATIAL_ALIGN + else: + crop_w = crop_sz // _SPATIAL_ALIGN * _SPATIAL_ALIGN + crop_h = int(crop_w * aspect_ratio[1] / aspect_ratio[0] + 0.5) + crop_h = crop_h // _SPATIAL_ALIGN * _SPATIAL_ALIGN + key = f"{aspect_ratio[0]},{aspect_ratio[1]}" + crop_sizes.update({key: (crop_w, crop_h)}) + return crop_sizes + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. 
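+    For example, if augmentor_cfg["size"] is {"16,9": (1280, 720), "9,16": (720, 1280)}
+    and the sample's aspect_ratio is "16,9", the returned size is (1280, 720).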
+ + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + return aug_size diff --git a/cosmos_predict1/tokenizer/training/datasets/video_dataset.py b/cosmos_predict1/tokenizer/training/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..169a9bc92fd598f9c586dbcf62f1fc8635dedf37 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/video_dataset.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/tokenizer/training/datasets/video_dataset.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from glob import glob + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + video_pattern, + sequence_interval=1, + start_frame_interval=1, + num_video_frames=25, + ): + """Dataset class for loading image-text-to-video generation data. 
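+
+        Each sample is a window of `num_video_frames` frames: windows start every
+        `start_frame_interval` frames, frames within a window are `sequence_interval`
+        apart, and windows that would run past the end of the video are dropped.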
+ + Args: + video_pattern (str): path/to/videos/*.mp4 + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.video_directory_or_pattern = video_pattern + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_video_frames + + self.video_paths = sorted(glob(str(video_pattern))) + print(f"{len(self.video_paths)} videos in total") + + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose( + [ + ToTensorVideo(), + ] + ) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.video_directory_or_pattern}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + return frame_data + + def _get_frames(self, video_path, frame_ids): + frames = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, data) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "start_frame_id": str(frame_ids[0]), + } + data["fps"] = 24 + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) # .cuda() # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) # .cuda() + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. 
Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + video_directory_or_pattern="assets/example_training_data/videos/*.mp4", + sequence_interval=1, + num_frames=57, + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print((f"{idx=} " f"{data['video'].sum()=}\n" f"{data['video'].shape=}\n" f"{data['video_name']=}\n" "---")) diff --git a/cosmos_predict1/tokenizer/training/jit_cli.py b/cosmos_predict1/tokenizer/training/jit_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..6520a2afbb7ff2c4b024766cab652b4475f5bab4 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/jit_cli.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to export an pre-trained tokenizer checkpoint into a torch.ScriptModule. + +Usage: +python3 -m cosmos_predict1.tokenizer.training.jit_cli \ + --ckpt_path=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/iter_001000000.pt \ + --output_dir=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/exported \ + --strict_resume \ + --config=cosmos_predict1/tokenizer/training/configs/config.py -- \ + experiment=CV720_Causal_AE49_4x8x8_cosmos + + + will output: + /iter_001000000_ema.jit + /iter_001000000_enc.jit + /iter_001000000_dec.jit + + if --reg is specified, it will export the regular model: + /iter_001000000_reg.jit + /iter_001000000_enc.jit + /iter_001000000_dec.jit +""" + +import argparse +import importlib +import os + +import torch +from loguru import logger as logging +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer +from cosmos_predict1.utils import callback, ema +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.model import Model + +parser = argparse.ArgumentParser(description="Export a pre-trained model into a torch.jit.ScriptModule.") +parser.add_argument( + "--config", type=str, default="cosmos_predict1/tokenizer/training/configs/config.py", help="Path to the config file" +) +parser.add_argument("--ckpt_path", type=str, default=None, help="The full ckpt path.") +parser.add_argument("--credentials", type=str, default="credentials/pdx_vfm_base.secret", help="The credentials file.") +parser.add_argument("--strict_resume", action="store_true", help="Enable strictly loading into every network weight.") +parser.add_argument("--reg", action="store_true", help="Enable regular model export.") +parser.add_argument("--output_dir", 
type=str, default=None, help="Optional output directory.") + +parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, +) + +logging.info("Initialize args, cfg from command line arguments ...") +args = parser.parse_args() +config_module = get_config_module(args.config) +config: Config = importlib.import_module(config_module).make_config() +config = override(config, args.opts) + + +def _compile_jit_models(model: Model) -> dict[str, torch.ScriptModule]: + """Returns a TorchScript version of REG or EMA models compiled by PyTorch JIT.""" + assert hasattr(config, "checkpoint") and hasattr(config.checkpoint, "jit") + config_jit = config.checkpoint.jit + input_shape = tuple(config_jit.input_shape) + example_input = torch.randn(input_shape) + dtype = getattr(torch, config_jit.dtype) + example_input = example_input.to(config_jit.device).to(dtype) + + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + with ema.ema_scope(model, enabled=model.config.ema.enabled and not args.reg): + _model = model.network.eval() + if isinstance(_model, torch_OptimizedModule): + _model = _model._orig_mod + model_jit = torch.jit.trace(_model, example_input, strict=config_jit.strict) + encoder_jit = torch.jit.trace(_model.encoder_jit(), example_input, strict=config_jit.strict) + decoder_example = encoder_jit(example_input)[0] + decoder_jit = torch.jit.trace(_model.decoder_jit(), decoder_example, strict=config_jit.strict) + if args.reg: + return {"reg": model_jit, "enc": encoder_jit, "dec": decoder_jit} + return {"ema": model_jit, "enc": encoder_jit, "dec": decoder_jit} + + +def _run_export() -> None: + """Exports a torch.nn.Module into a torch.jit.ScriptModule.""" + # Check that the config is valid. + config.validate() + config.checkpoint.load_path = args.ckpt_path + config.checkpoint.strict_resume = args.strict_resume + config.checkpoint.load_training_state = False + config.job.name = os.path.basename(args.output_dir) if args.output_dir else os.path.basename(args.ckpt_path) + + # Freeze the config. + config.freeze() # type: ignore + callbacks = callback.CallBackGroup(config=config, trainer=None) + checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=callbacks) + + # Create the model. + logging.info(f"Instantiate model={config.model.config.network.name} ...") + model = instantiate(config.model) + model = model.to("cuda", memory_format=config.trainer.memory_format) # type: ignore + model.on_train_start(config.trainer.memory_format) + + logging.info(f"loading weights from {config.checkpoint.load_path}...") + _ = checkpointer.load(model) + model.eval() + ckpt_name = config.checkpoint.load_path.split("/")[-1][:-3] + + # Drive the output directory. 
+ tmp_output_dir = os.path.dirname(config.checkpoint.load_path) + output_dir = args.output_dir or tmp_output_dir + os.makedirs(output_dir, exist_ok=True) + + logging.info("Performing JIT compilation ...") + jit_models = _compile_jit_models(model) + for name, jit_model in jit_models.items(): + logging.info(f"Outputing torch.jit: {output_dir}/{ckpt_name}_{name}.jit") + torch.jit.save(jit_model, f"{output_dir}/{ckpt_name}_{name}.jit") + + +@logging.catch(reraise=True) +def main() -> None: + _run_export() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/training/losses/__init__.py b/cosmos_predict1/tokenizer/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f946d23712a17a93be7fc9668985dc7cf070b01c --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/__init__.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The loss reduction modes.""" + +from enum import Enum + +import torch + + +def _mean(recon: torch.Tensor) -> torch.Tensor: + return torch.mean(recon) + + +def _sum_per_frame(recon: torch.Tensor) -> torch.Tensor: + batch_size = recon.shape[0] * recon.shape[2] if recon.ndim == 5 else recon.shape[0] + return torch.sum(recon) / batch_size + + +def _sum(recon: torch.Tensor) -> torch.Tensor: + return torch.sum(recon) / recon.shape[0] + + +class ReduceMode(Enum): + MEAN = "MEAN" + SUM_PER_FRAME = "SUM_PER_FRAME" + SUM = "SUM" + + @property + def function(self): + if self == ReduceMode.MEAN: + return _mean + elif self == ReduceMode.SUM_PER_FRAME: + return _sum_per_frame + elif self == ReduceMode.SUM: + return _sum + else: + raise ValueError("Invalid ReduceMode") diff --git a/cosmos_predict1/tokenizer/training/losses/continuous.py b/cosmos_predict1/tokenizer/training/losses/continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..86178374dd737d61f5f0783bf5e781cc261a1734 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/continuous.py @@ -0,0 +1,479 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The combined loss functions for continuous-space tokenizers training.""" +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torchvision.models.optical_flow as optical_flow + +from cosmos_predict1.tokenizer.modules.utils import batch2time, time2batch +from cosmos_predict1.tokenizer.training.datasets.utils import INPUT_KEY, LATENT_KEY, MASK_KEY, RECON_KEY +from cosmos_predict1.tokenizer.training.losses import ReduceMode +from cosmos_predict1.tokenizer.training.losses.lpips import LPIPS +from cosmos_predict1.utils.lazy_config import instantiate + +_VALID_LOSS_NAMES = ["color", "perceptual", "flow", "kl", "video_consistency"] +VIDEO_CONSISTENCY_LOSS = "video_consistency" +RECON_CONSISTENCY_KEY = f"{RECON_KEY}_consistency" + + +class TokenizerLoss(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + _reduce = ReduceMode(config.reduce.upper()) if hasattr(config, "reduce") else None + self.reduce = _reduce.function + self.loss_modules = nn.ModuleDict() + for key in _VALID_LOSS_NAMES: + self.loss_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NullLoss() + + def forward(self, inputs, output_batch, iteration) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + loss = dict() + total_loss = 0.0 + + inputs[MASK_KEY] = torch.ones_like(inputs[INPUT_KEY]) + # Calculates reconstruction losses (`total_loss`). + for key, module in self.loss_modules.items(): + curr_loss = module(inputs, output_batch, iteration) + loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) + total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) + + loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) + + # Computes the overall loss as sum of the reconstruction losses and the generator loss. 
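+        # `self.reduce` applies the configured ReduceMode (MEAN, SUM or SUM_PER_FRAME)
+        # to each per-element loss map; zero-dim (scalar) entries are added as-is.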
+ total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) + return dict(loss=loss), total_loss + + +class WeightScheduler(torch.nn.Module): + def __init__(self, boundaries, values): + super().__init__() + self.boundaries = list(boundaries) + self.values = list(values) + + def forward(self, iteration): + for boundary, value in zip(self.boundaries, self.values): + if iteration < boundary: + return value + return self.values[-1] + + +class NullLoss(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, inputs, output_batch, iteration) -> dict[dict, torch.Tensor]: + return dict() + + +class ColorLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + reconstructions = output_batch[RECON_KEY] + weights = inputs[MASK_KEY] + recon = weights * torch.abs(inputs[INPUT_KEY].contiguous() - reconstructions.contiguous()) + color_weighted = self.schedule(iteration) * recon + if torch.isnan(color_weighted).any(): + raise ValueError("[COLOR] NaN detected in loss") + return dict(color=color_weighted) + + +class KLLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + + def kl(self, mean, logvar): + _dims = [idx for idx in range(1, mean.ndim)] + var = torch.exp(logvar) + return 0.5 * (torch.pow(mean, 2) + var - 1.0 - logvar) + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + if "posteriors" not in output_batch: # No KL loss for discrete tokens. + return dict() + mean, logvar = output_batch["posteriors"] + if mean.ndim == 1: # No KL if the mean is a scalar. + return dict() + kl = self.kl(mean, logvar) + kl_weighted = self.schedule(iteration) * kl + if torch.isnan(kl_weighted).any(): + raise ValueError("[KL] NaN detected in loss") + return dict(kl=kl_weighted) + + +class PerceptualLoss(LPIPS): + """Relevant changes that're internal to us: + + - Remove linear projection layers, simply use the raw pre-normalized features. + - Use pyramid-layer weights: [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]. + - Accepts pixel-wise masks and modulates the features before norm calculation. + - Implements gram-matrix and correlation losses. 
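+
+    For reference, the Gram matrix of a feature map of shape (B, C, H, W) is computed
+    as x_flat @ x_flat.T / (H * W), with x_flat reshaped to (B, C, H*W); for video
+    features the time dimension is folded in as well, yielding one (C, C) matrix per sample.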
+ """ + + def __init__(self, config): + super(PerceptualLoss, self).__init__(config.checkpoint_activations) + self.net = self.net.eval() + self.gram_enabled = config.gram_enabled + self.corr_enabled = config.corr_enabled + self.layer_weights = list(config.layer_weights) + self.lpips_schedule = WeightScheduler(config.lpips_boundaries, config.lpips_values) + self.gram_schedule = WeightScheduler(config.gram_boundaries, config.gram_values) + self.corr_schedule = WeightScheduler(config.corr_boundaries, config.corr_values) + self.checkpoint_activations = config.checkpoint_activations + + def _temporal_gram_matrix(self, x, batch_size=None): + x = batch2time(x, batch_size) + c, t, h, w = x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1] + reshaped_x = torch.reshape(x, [-1, c, t * h * w]) + return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(t * h * w) + + def _gram_matrix(self, x, batch_size=None): + if batch_size is not None and x.shape[0] != batch_size: + return self._temporal_gram_matrix(x, batch_size) + c, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + reshaped_x = torch.reshape(x, [-1, c, h * w]) + return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(h * w) + + def forward(self, inputs, output_batch, iteration): + output_dict = dict() + reconstructions = output_batch[RECON_KEY] + weights = inputs[MASK_KEY] + input_images = inputs[INPUT_KEY] + + if input_images.ndim == 5: + input_images, batch_size = time2batch(input_images) + reconstructions, _ = time2batch(reconstructions) + weights, _ = time2batch(weights) + else: + batch_size = input_images.shape[0] + + in0_input, in1_input = (self.scaling_layer(input_images), self.scaling_layer(reconstructions)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + + _layer_weights = self.layer_weights + weights_map, res, diffs = {}, {}, {} + for kk in range(len(self.chns)): + weights_map[kk] = torch.nn.functional.interpolate(weights[:, :1, ...], outs0[kk].shape[-2:]) + diffs[kk] = weights_map[kk] * torch.abs(outs0[kk] - outs1[kk]) + res[kk] = _layer_weights[kk] * diffs[kk].mean([1, 2, 3], keepdim=True) + + val = res[0] + for ll in range(1, len(self.chns)): + val += res[ll] + # Scale by number of pixels to match pixel-wise losses. + val = val.expand(-1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1]) + if batch_size != input_images.shape[0]: + val = batch2time(val, batch_size) + if torch.isnan(val).any(): + raise ValueError("[LPIPS] NaN detected in loss") + output_dict["lpips"] = self.lpips_schedule(iteration) * val + + if self.gram_enabled and self.gram_schedule(iteration) > 0.0: + num_chans = len(self.chns) + grams0 = [self._gram_matrix(weights_map[kk] * outs0[kk], batch_size) for kk in range(num_chans)] + grams1 = [self._gram_matrix(weights_map[kk] * outs1[kk], batch_size) for kk in range(num_chans)] + gram_diffs = [(grams0[kk] - grams1[kk]) ** 2 for kk in range(num_chans)] + grams_res = [_layer_weights[kk] * gram_diffs[kk].mean([1, 2], keepdim=True) for kk in range(num_chans)] + gram_val = grams_res[0] + for ll in range(1, len(self.chns)): + gram_val += grams_res[ll] + + # Scale by number of total pixels to match pixel-wise losses. 
+            gram_val = gram_val.unsqueeze(1).expand(
+                -1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1]
+            )
+            if batch_size != input_images.shape[0]:
+                gram_val = batch2time(gram_val, batch_size)
+            if torch.isnan(gram_val).any():
+                raise ValueError("[GRAM] NaN detected in loss")
+            output_dict["gram"] = self.gram_schedule(iteration) * gram_val
+        return output_dict
+
+    def torch_compile(self):
+        """Invokes torch.compile() on this loss."""
+        # cuda-graphs crash after 1k iterations
+        self.net = torch.compile(self.net, dynamic=False)
+
+
+class FlowLoss(torch.nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.schedule = WeightScheduler(config.boundaries, config.values)
+        self.scale = config.scale
+        self.dtype = getattr(torch, config.dtype)
+        self.checkpoint_activations = config.checkpoint_activations
+        self.enabled = config.enabled
+
+        current_device = torch.device(torch.cuda.current_device())
+
+        # To run the model in bf16, make_coords_grid() must be able to return an arbitrary
+        # dtype passed in as an argument; the line from the original implementation that
+        # forced fp32 results is commented out below. The function is also changed to run
+        # on the GPU instead of the CPU, which results in fewer graph breaks when
+        # torch.compile() is used.
+        # This function is copied from
+        # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/_utils.py#L22
+        # commit: b06ea39d5f0adbe949d08257837bda912339e415
+        def make_coords_grid(
+            batch_size: int, h: int, w: int, device: torch.device = current_device, dtype: torch.dtype = self.dtype
+        ):
+            # Original: def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
+            device = torch.device(device)
+            coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
+            coords = torch.stack(coords[::-1], dim=0).to(dtype)
+            # Original: coords = torch.stack(coords[::-1], dim=0).float()
+            return coords[None].repeat(batch_size, 1, 1, 1)
+
+        # We also need to specify the output dtype of torch.linspace() in the index_pyramid()
+        # method of CorrBlock; otherwise it falls back to the default fp32 output dtype.
+ # Additionally I've changed that function to run on GPU instead of CPU, which results in + # less graph breaks when torch.compile() is used + # This function is copied from + # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py#L394 + # commit: b06ea39d5f0adbe949d08257837bda912339e415 + def index_pyramid( + self, centroids_coords, dtype: torch.dtype = self.dtype, device: torch.device = current_device + ): + # Original: def index_pyramid(self, centroids_coords): + """Return correlation features by indexing from the pyramid.""" + neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels + di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) + dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) + # Original: di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + # Original: dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) + delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) + + batch_size, _, h, w = centroids_coords.shape # _ = 2 + centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) + + indexed_pyramid = [] + for corr_volume in self.corr_pyramid: + sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) + indexed_corr_volume = optical_flow.raft.grid_sample( + corr_volume, sampling_coords, align_corners=True, mode="bilinear" + ).view(batch_size, h, w, -1) + indexed_pyramid.append(indexed_corr_volume) + centroids_coords = centroids_coords / 2 + + corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() + + expected_output_shape = (batch_size, self.out_channels, h, w) + if corr_features.shape != expected_output_shape: + raise ValueError( + f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}" + ) + + return corr_features + + optical_flow.raft.make_coords_grid = make_coords_grid + optical_flow.raft.CorrBlock.index_pyramid = index_pyramid + + flow_model = optical_flow.raft_large(pretrained=True, progress=False) + flow_model.requires_grad_(False) + flow_model.eval() + flow_model = flow_model.to(self.dtype) + + self.flow_model = flow_model + + def _run_model(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + """Runs flow_model in the forward mode on explicit dtype=float32. + + Args: + input1: First video frames batch, layout (T, C, H, W), bfloat16. + input2: Next video frames batch, layout (T, C, H, W), bfloat16. + + Returns: + Forward optical flow, (T, 2, H, W), bfloat16. + """ + input_dtype = input1.dtype + flow_output = self.flow_model.to(self.dtype)(input1.to(self.dtype), input2.to(self.dtype))[-1] + return flow_output.to(input_dtype) + + def _run_model_fwd(self, input_video: torch.Tensor) -> torch.Tensor: + """Runs foward flow on a batch of videos, one batch at a time. + Args: + input_video: The input batch of videos, layout (B, T, C, H, W). + + Returns: + Forward optical flow, layout (B, 2, T-1, H, W). 
+ """ + output_list = list() + for fwd_input_frames in input_video: + fwd_input_frames = fwd_input_frames.transpose(1, 0) + fwd_flow_output = self._run_model(fwd_input_frames[:-1], fwd_input_frames[1:]) + output_list.append(fwd_flow_output.transpose(1, 0)) + return torch.stack(output_list, dim=0) + + def _bidirectional_flow(self, input_video: torch.Tensor) -> torch.Tensor: + """The bidirectional optical flow on a batch of videos. + + The forward and backward flows are averaged to get the bidirectional flow. + To reduce memory pressure, the input video is scaled down by a factor of `self.scale`, + and rescaled back to match other pixel-wise losses. + + Args: + input_video: The input batch of videos, layout (B, T, C, H, W). + + Returns: + Biderectinoal flow, layout (B, 2, T-1, H, W). + """ + # scale down the input video to reduce memory pressure. + t, h, w = input_video.shape[-3:] + input_video_scaled = F.interpolate(input_video, (t, h // self.scale, w // self.scale), mode="trilinear") + + # forward flow. + if self.checkpoint_activations: + fwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) + else: + fwd_flow_output = self._run_model_fwd(input_video_scaled) + + # backward flow. + input_video_scaled = input_video_scaled.flip([2]) + if self.checkpoint_activations: + bwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) + else: + bwd_flow_output = self._run_model_fwd(input_video_scaled) + bwd_flow_output = bwd_flow_output.flip([2]) + + # bidirectional flow, concat fwd and bwd along temporal axis. + flow_input = torch.cat([fwd_flow_output, bwd_flow_output], dim=2) + return self.scale * F.interpolate(flow_input, (2 * (t - 1), h, w), mode="trilinear") + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + input_images = inputs[INPUT_KEY] + if input_images.ndim == 4 or input_images.shape[2] == 1: + return dict() + if not self.enabled or self.schedule(iteration) == 0.0: + return dict() + + # Biderectional flow (B, 2, 2*(T-1), H, W) + flow_input = self._bidirectional_flow(input_images) + flow_recon = self._bidirectional_flow(output_batch[RECON_KEY]) + + # L1 loss on the flow. (B, 1, 2*(T-1), H, W) + flow_loss = torch.abs(flow_input - flow_recon).mean(dim=1, keepdim=True) + + flow_loss_weighted = self.schedule(iteration) * flow_loss + if torch.isnan(flow_loss_weighted).any(): + raise ValueError("[FLOW] NaN detected in loss") + return dict(flow=flow_loss_weighted) + + def torch_compile(self): + """ + This method invokes torch.compile() on this loss + """ + self.flow_model = torch.compile(self.flow_model, dynamic=False) + + +class VideoConsistencyLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + self.enabled = config.enabled + self.num_frames = config.num_frames + self.step = config.step + self.num_windows = None + + def shuffle(self, inputs: torch.Tensor) -> torch.Tensor: + """ + For input video of [B, 3, T, H, W], this function will reshape the video to + the shape of [B*(T-num_frames+1)//step, 3, num_frames, H, W] using a sliding window + This function is used to compute the temporal consistency between overlapped frames + to enable temporal consistency + """ + assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. 
currently {inputs.shape}" + B, C, T, H, W = inputs.shape + assert T >= self.num_frames, f"inputs {T} should be greater than {self.num_frames}" + + # [B, C, num_windows, H, W, num_frames] + outputs = inputs.unfold(dimension=2, size=self.num_frames, step=self.step) + self.num_windows = outputs.shape[2] + outputs = einops.rearrange(outputs, "b c m h w n -> (b m) c n h w") + + return outputs + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + if not self.enabled or self.num_windows is None: + return dict() + if self.schedule(iteration) == 0.0: + return dict() + # reshape output_batch to compute loss between overlapped frames + reconstructions = output_batch[RECON_CONSISTENCY_KEY] + B, C, T, H, W = reconstructions.shape + + assert T == self.num_frames, f"reconstruction shape invalid (shape[2] should be {self.num_frames})" + assert ( + B % self.num_windows == 0 + ), f"reconstruction shape invalid (shape[0]={B} not dividable by {self.num_windows})" + + B = B // self.num_windows + videos = reconstructions.view(B, self.num_windows, C, self.num_frames, H, W) + + # Compute the L1 distance between overlapped frames for all windows at once + diff = torch.mean(torch.abs(videos[:, :-1, :, self.step :, :, :] - videos[:, 1:, :, : -self.step, :, :])) + diff_weighted = self.schedule(iteration) * diff + + if LATENT_KEY not in output_batch: + return dict(frame_consistency=diff_weighted) + + B_latent, C_latent, T_latent, H_latent, W_latent = output_batch["latent"].shape + assert B_latent % self.num_windows == 0, f"latent batches should be divisible by {self.num_windows}" + + latents = output_batch[LATENT_KEY].view( + B_latent // self.num_windows, self.num_windows, C_latent, T_latent, H_latent, W_latent + ) + temporal_rate = self.num_frames // T_latent + spatial_rate = (H // H_latent) * (W // W_latent) + step_latent = self.step // temporal_rate + latent_diff = torch.mean( + torch.abs(latents[:, :-1, :, step_latent:, :, :] - latents[:, 1:, :, :-step_latent, :, :]) + ) + latent_diff_weighted = self.schedule(iteration) * latent_diff * (C * temporal_rate * spatial_rate) / (C_latent) + return dict(frame_consistency=diff_weighted, latent_consistency=latent_diff_weighted) + + def unshuffle(self, inputs: torch.Tensor) -> torch.Tensor: + """ + For input video of [B*num_windows, 3, num_frames, H, W], this function will + undo the shuffle to a tensor of shape [B, 3, T, H, W] + """ + assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. 
currently {inputs.shape}" + B, C, T, H, W = inputs.shape + assert T == self.num_frames, f"inputs shape invalid (shape[2] should be {self.num_frames})" + assert B % self.num_windows == 0, f"inputs shape invalid (shape[0]={B} not dividable by {self.num_windows})" + + B = B // self.num_windows + videos = inputs.view(B, self.num_windows, C, self.num_frames, H, W) + + T = self.num_frames + (self.num_windows - 1) * self.step + current_device = torch.device(torch.cuda.current_device()) + outputs = torch.zeros(B, C, T, H, W).to(inputs.dtype).to(current_device) + counter = torch.zeros_like(outputs) + for i in range(self.num_windows): + outputs[:, :, i * self.step : i * self.step + self.num_frames, :, :] += videos[:, i, :, :, :, :] + counter[:, :, i * self.step : i * self.step + self.num_frames, :, :] += 1 + outputs = outputs / counter + + return outputs diff --git a/cosmos_predict1/tokenizer/training/losses/lpips.py b/cosmos_predict1/tokenizer/training/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b866177d11d6eef033529b32679a087dbfe102ab --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/lpips.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LPIPS loss. + +Adapted from: github.com/CompVis/stable-diffusion/ldm/modules/losses/contperceptual.py. 
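+
+A usage sketch (hedged, not part of the adapted module): `input` and `target` are RGB
+batches of shape (B, 3, H, W), expected roughly in the [-1, 1] range used elsewhere in
+this codebase.
+
+    lpips = LPIPS()
+    distance = lpips(input, target)  # (B, 1, 1, 1) per-sample perceptual distance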
+""" + +import hashlib +import os +from collections import namedtuple + +import requests +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from loguru import logger as logging +from torchvision import models +from tqdm import tqdm + +from cosmos_predict1.utils.distributed import is_rank0 + +_TORCH_HOME = os.getenv("TORCH_HOME", "/mnt/workspace/.cache/torch") +_URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} +_CKPT_MAP = {"vgg_lpips": "vgg.pth"} +_MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def _download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def _md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def _get_ckpt_path(name, root, check=False): + assert name in _URL_MAP + path = os.path.join(root, _CKPT_MAP[name]) + if not os.path.exists(path) or (check and not _md5_hash(path) == _MD5_MAP[name]): + logging.info("Downloading {} model from {} to {}".format(name, _URL_MAP[name], path)) + _download(_URL_MAP[name], path) + md5 = _md5_hash(path) + assert md5 == _MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + def __init__(self, checkpoint_activations: bool = False): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False, checkpoint_activations=checkpoint_activations) + + if dist.is_initialized() and not is_rank0(): + dist.barrier() + self.load_from_pretrained() + if dist.is_initialized() and is_rank0(): + dist.barrier() + + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = _get_ckpt_path(name, f"{_TORCH_HOME}/hub/checkpoints") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + logging.info("Loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = _get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [diffs[kk].mean([1, 2, 3], keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None], persistent=False) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None], persistent=False) + + def forward(self, inp): + return (inp - self.shift) / 
self.scale + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, checkpoint_activations: bool = False): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.checkpoint_activations = checkpoint_activations + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice1, X, use_reentrant=False) + else: + h = self.slice1(X) + h_relu1_2 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice2, h, use_reentrant=False) + else: + h = self.slice2(h) + h_relu2_2 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice3, h, use_reentrant=False) + else: + h = self.slice3(h) + h_relu3_3 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice4, h, use_reentrant=False) + else: + h = self.slice4(h) + h_relu4_3 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice5, h, use_reentrant=False) + else: + h = self.slice5(h) + h_relu5_3 = h + + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out diff --git a/cosmos_predict1/tokenizer/training/metrics.py b/cosmos_predict1/tokenizer/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..c541999a1de5cb44cf1337933b9827d473a84ab3 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/metrics.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
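+
+# Worked example of the PSNR formula used by PSNRMetric below (illustrative numbers,
+# not from the source): for two uint8-range images with a mean squared error of 100,
+# PSNR = 10 * log10(255**2 / 100) ≈ 28.1 dB.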
+
+"""Evaluation metrics for tokenizer training: PSNR, SSIM, and codebook usage."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from skimage.metrics import structural_similarity as ssim
+
+from cosmos_predict1.tokenizer.modules.utils import time2batch
+from cosmos_predict1.utils.lazy_config import instantiate
+
+_VALID_METRIC_NAMES = ["PSNR", "SSIM", "CodeUsage"]
+_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
+_FLOAT32_EPS = torch.finfo(torch.float32).eps
+_RECONSTRUCTION = "reconstructions"
+_QUANT_INFO = "quant_info"
+
+
+class TokenizerMetric(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
+        self.metric_modules = nn.ModuleDict()
+        for key in _VALID_METRIC_NAMES:
+            self.metric_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NULLMetric()
+
+    def forward(
+        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
+    ) -> dict[str, torch.Tensor]:
+        metric = dict()
+        for _, module in self.metric_modules.items():
+            metric.update(module(inputs, output_batch, iteration))
+        return dict(metric=metric)
+
+
+class NULLMetric(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
+    ) -> dict[str, torch.Tensor]:
+        return dict()
+
+
+class PSNRMetric(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
+    ) -> dict[str, torch.Tensor]:
+        reconstructions = output_batch[_RECONSTRUCTION]
+        if inputs.ndim == 5:
+            inputs, _ = time2batch(inputs)
+            reconstructions, _ = time2batch(reconstructions)
+
+        # Normalize to uint8 [0..255] range.
+        true_image = (inputs.to(torch.float32) + 1) / 2
+        pred_image = (reconstructions.to(torch.float32) + 1) / 2
+        true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8)
+        pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8)
+
+        # Calculate PSNR based on the mean squared error (MSE).
+        true_image = true_image.to(torch.float32)
+        pred_image = pred_image.to(torch.float32)
+        mse = torch.mean((true_image - pred_image) ** 2, dim=(1, 2, 3))
+        psnr = 10 * torch.log10(_UINT8_MAX_F**2 / (mse + _FLOAT32_EPS))
+        return dict(PSNR=torch.mean(psnr))
+
+
+class SSIMMetric(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
+    ) -> dict[str, torch.Tensor]:
+        reconstructions = output_batch[_RECONSTRUCTION]
+        if inputs.ndim == 5:
+            inputs, _ = time2batch(inputs)
+            reconstructions, _ = time2batch(reconstructions)
+
+        # Normalize to uint8 [0..255] range.
+ true_image = (inputs.to(torch.float32) + 1) / 2 + pred_image = (reconstructions.to(torch.float32) + 1) / 2 + true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + + # Move tensors to CPU and convert to numpy arrays + true_image_np = true_image.permute(0, 2, 3, 1).cpu().numpy() + pred_image_np = pred_image.permute(0, 2, 3, 1).cpu().numpy() + + # Calculate SSIM for each image in the batch and average over the batch + ssim_values = [] + for true_image_i, pred_image_i in zip(true_image_np, pred_image_np): + ssim_value = ssim(true_image_i, pred_image_i, data_range=_UINT8_MAX_F, multichannel=True, channel_axis=-1) + ssim_values.append(ssim_value) + ssim_mean = np.mean(ssim_values) + return dict(SSIM=torch.tensor(ssim_mean, dtype=torch.float32, device=inputs.device)) + + +class CodeUsageMetric(torch.nn.Module): + """ + Calculate the perplexity of codebook usage (only for discrete tokenizers) + + :param codebook_indices: Tensor of codebook indices (quant_info) + :param codebook_size: The total number of codebook entries + :return: Perplexity of the codebook usage + """ + + def __init__(self, codebook_size: int) -> None: + super().__init__() + self.codebook_size = codebook_size + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + code_indices = output_batch[_QUANT_INFO] + usage_counts = torch.bincount(code_indices.flatten().int(), minlength=self.codebook_size) + total_usage = usage_counts.sum().float() + usage_probs = usage_counts.float() / total_usage + entropy = -torch.sum(usage_probs * torch.log(usage_probs + _FLOAT32_EPS)) + perplexity = torch.exp(entropy) + return dict(CodeUsage=torch.tensor(perplexity, dtype=torch.float32, device=code_indices.device)) diff --git a/cosmos_predict1/tokenizer/training/model.py b/cosmos_predict1/tokenizer/training/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a75630d01950abcdf0f7247fd8868a3a612694c --- /dev/null +++ b/cosmos_predict1/tokenizer/training/model.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
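+
+# Usage sketch (hedged, not part of the original file): TokenizerModel.training_step expects
+# a data_batch keyed by the image or video key (plus an optional "loss_mask") and returns the
+# output dict together with the scalar total loss, e.g.
+#
+#     output_batch, total_loss = model.training_step({VIDEO_KEY: videos}, iteration=0)
+#     recon = output_batch["prediction"]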
+ +"""Implements the forward op for training, validation, and inference.""" + +from typing import Any + +import torch + +from cosmos_predict1.tokenizer.training.datasets.utils import IMAGE_KEY, INPUT_KEY, MASK_KEY, RECON_KEY, VIDEO_KEY +from cosmos_predict1.tokenizer.training.losses.continuous import RECON_CONSISTENCY_KEY, VIDEO_CONSISTENCY_LOSS +from cosmos_predict1.utils import ema +from cosmos_predict1.utils.lazy_config import LazyDict, instantiate +from cosmos_predict1.utils.model import Model + +PREDICTION = "prediction" +EMA_PREDICTION = "ema_prediction" + + +class TokenizerModel(Model): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.network = instantiate(config.network) + self.loss = instantiate(config.loss) + self.metric = instantiate(config.metric) + self.precision = getattr(torch, config.precision) + if self.config.ema.enabled: + self.ema = ema.EMAModelTracker( + self, + beta=self.config.ema.beta, + torch_compile_buffer_renaming=self.config.ema.torch_compile_buffer_renaming, + ) + self.init_input_keys() + + def init_input_keys(self): + self.image_key = IMAGE_KEY + self.video_key = VIDEO_KEY + + def get_input_key(self, data_batch: dict[str, torch.Tensor]) -> str: + if self.image_key in data_batch: + return self.image_key + elif self.video_key in data_batch: + return self.video_key + else: + raise ValueError("Input key not found in data_batch.") + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the network. + + Args: + optimizer_config: The optimizer config for the net. + scheduler_config: The scheduler config for the net. + + Returns: + optimizer (torch.optim.Optimizer): The net optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The net optimization scheduler. + """ + optimizer_config.params = self.network.parameters() + optimizer = instantiate(optimizer_config) + scheduler_config.optimizer = optimizer + scheduler = instantiate(scheduler_config) + + return optimizer, scheduler + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + if self.config.ema.enabled: + self.ema.to(dtype=torch.float32) + self.network = self.network.to(dtype=self.precision, memory_format=memory_format) + self.loss = self.loss.to(dtype=self.precision, memory_format=memory_format) + + def state_dict( + self, destination: dict[str, Any] = None, prefix: str = "", keep_vars: bool = False + ) -> dict[str, Any]: + original_state_dict = super(TokenizerModel, self).state_dict(destination, prefix, keep_vars) + + # Filter out '.loss' and 'ema.loss-' keys from the state dict. + filtered_state_dict = {k: v for k, v in original_state_dict.items() if not k.startswith("loss.")} + filtered_state_dict = {k: v for k, v in filtered_state_dict.items() if not k.startswith("ema.loss-")} + filtered_state_dict = { + k: v for k, v in filtered_state_dict.items() if not k.startswith("network.encoder.patcher") + } + filtered_state_dict = { + k: v for k, v in filtered_state_dict.items() if not k.startswith("network.decoder.unpatcher") + } + + return filtered_state_dict + + def load_state_dict(self, state_dict: Any, strict: bool = True) -> None: + own_state = self.state_dict() + filtered_state_dict = {k: v for k, v in state_dict.items() if k in own_state} + + # Load only filtered state dict. 
+ super(TokenizerModel, self).load_state_dict(filtered_state_dict, strict=False) + + # If strict is True, ensure all parameters are loaded (except the excluded ones) + missing_keys = set(own_state.keys()) - set(filtered_state_dict.keys()) + if missing_keys and strict: + raise KeyError(f"Missing keys in state_dict: {missing_keys}") + + def _on_before_network_forward(self, data_batch: dict[str, torch.Tensor]) -> None: + consistency_loss = self.loss.loss_modules[VIDEO_CONSISTENCY_LOSS] + if hasattr(consistency_loss, "enabled") and consistency_loss.enabled: + _input_key = self.get_input_key(data_batch) + if _input_key is self.video_key: + data_batch[_input_key] = consistency_loss.shuffle(data_batch[_input_key]) + return + + def _on_after_network_forward( + self, data_batch: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor] + ) -> None: + consistency_loss = self.loss.loss_modules[VIDEO_CONSISTENCY_LOSS] + if hasattr(consistency_loss, "enabled") and consistency_loss.enabled: + _input_key = self.get_input_key(data_batch) + if _input_key is self.video_key: + data_batch[_input_key] = consistency_loss.unshuffle(data_batch[_input_key]) + output_batch[RECON_CONSISTENCY_KEY] = torch.ones_like(output_batch[RECON_KEY]) * output_batch[RECON_KEY] + output_batch[RECON_KEY] = consistency_loss.unshuffle(output_batch[RECON_KEY]) + return + + def _network_forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + # A callback proxy to modify the input before the forward pass. + self._on_before_network_forward(data_batch) + + # Do the forward pass. + tensor_batch = data_batch[self.get_input_key(data_batch)] + output_batch = self.network(tensor_batch) + output_batch = output_batch if self.network.training else output_batch._asdict() + + # A callback proxy to modify the output after the forward pass. 
+ self._on_after_network_forward(data_batch, output_batch) + return output_batch + + def training_step( + self, + data_batch: dict[str, torch.Tensor], + iteration: int, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + input_images, recon_images = data_batch[_input_key], output_dict[RECON_KEY] + + # pass loss_mask to loss computation + inputs = {INPUT_KEY: input_images, MASK_KEY: data_batch.get("loss_mask", torch.ones_like(input_images))} + + loss_dict, loss_value = self.loss(inputs, output_dict, iteration) + return dict({PREDICTION: recon_images, **loss_dict}), loss_value + + @torch.no_grad() + def validation_step( + self, + data_batch: dict[str, torch.Tensor], + iteration: int, + ema_model: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + input_images, recon_images = data_batch[_input_key], output_dict[RECON_KEY] + + # pass loss_mask to loss computation + inputs = {INPUT_KEY: input_images, MASK_KEY: data_batch.get("loss_mask", torch.ones_like(input_images))} + + loss_dict, loss_value = self.loss(inputs, output_dict, iteration) + metric_dict = self.metric(input_images, output_dict, iteration) + loss_dict.update(metric_dict) + prediction_key = EMA_PREDICTION if ema_model else PREDICTION + return dict({prediction_key: recon_images, **loss_dict}), loss_value + + @torch.inference_mode() + def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + return dict({PREDICTION: output_dict[RECON_KEY]}) diff --git a/cosmos_predict1/tokenizer/training/train.py b/cosmos_predict1/tokenizer/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4c2cf58c1f68327c2739326ead7e557c13fe25 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/train.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os + +from loguru import logger as logging + +from cosmos_predict1.utils.config import Config, pretty_print_overrides +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig + + +@logging.catch(reraise=True) +def launch(config: Config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. 
+ config.freeze() # type: ignore + trainer = config.trainer.type(config) + # Create the model + model = instantiate(config.model) + model.on_model_init_end() + dataloader_train = instantiate(config.dataloader_train) + dataloader_val = instantiate(config.dataloader_val) + # Start training + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + + +if __name__ == "__main__": + # Usage: torchrun --nproc_per_node=1 -m scripts.train --config=projects/tutorials/mnist/config.py + + # Get the config file from the input arguments. + parser = argparse.ArgumentParser(description="Training") + parser.add_argument("--config", help="Path to the config file", required=True) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. Useful for debugging the config.", + ) + args = parser.parse_args() + config_module = get_config_module(args.config) + config = importlib.import_module(config_module).make_config() + config = override(config, args.opts) + if args.dryrun: + logging.info( + "Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True) + ) + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(f"{config.job.path_local}/config.yaml") + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/tokenizer/training/trainer.py b/cosmos_predict1/tokenizer/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e63cb884cbc11d95f40f8a7338ec0e72439bc54e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/trainer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.data + +from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer +from cosmos_predict1.utils import ema, misc +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class TokenizerTrainer(Trainer): + """The tokenizers traine, extended from Trainer. + + It extends model training functionality. + + Attributes: + checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. + training_timer (misc.Timer): Timer object to time code blocks and functions. 
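+
+    A typical launch sketch (mirrors launch() in train.py; the names come from that script):
+
+        trainer = config.trainer.type(config)  # e.g. TokenizerTrainer
+        trainer.train(model, dataloader_train, dataloader_val)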
+ """ + + def __init__(self, config): + super(TokenizerTrainer, self).__init__(config) + self.model_config = config.model.config + self.checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, _ = model.validation_step(data_batch, iteration) + with ema.ema_scope(model, enabled=model.config.ema.enabled): + ema_output_batch, loss = model.validation_step(data_batch, iteration, ema_model=True) + output_batch.update(ema_output_batch) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/utils/__init__.py b/cosmos_predict1/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/utils/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/utils/base_world_generation_pipeline.py b/cosmos_predict1/utils/base_world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..385eb883216f9579433800aaac565b6caffa727c --- /dev/null +++ b/cosmos_predict1/utils/base_world_generation_pipeline.py @@ -0,0 +1,363 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
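+
+# Offloading pattern used throughout this module (a sketch of the existing *_with_offload
+# helpers, not new behavior): heavy submodels are loaded right before use and released
+# right after, e.g.
+#
+#     if self.offload_text_encoder_model:
+#         self._load_text_encoder_model()
+#     embeddings, masks = self._run_text_embedding_on_prompt(prompts)
+#     if self.offload_text_encoder_model:
+#         self._offload_text_encoder_model()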
+ +import gc +import os +from abc import ABC +from typing import Any + +import numpy as np +import torch + +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.auxiliary.t5_text_encoder import CosmosT5TextEncoder, DummyT5TextEncoder + + +class BaseWorldGenerationPipeline(ABC): + def __init__( + self, + inference_type: str | None = None, + checkpoint_dir: str | None = None, + checkpoint_name: str | None = None, + has_text_input: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + disable_prompt_encoder: bool = False, + ): + """Initialize base world generation pipeline. + + This abstract base class provides core functionality for world generation models including: + - Model loading and initialization + - Text encoding and embedding + - Safety checks and content filtering + - Memory management through model offloading + + Args: + inference_type: The type of inference pipeline ("text2world" or "video2world") + checkpoint_dir: Root directory containing model checkpoints + checkpoint_name: Name of the specific checkpoint file to load + has_text_input: Whether the pipeline takes text input for world generation + offload_network: If True, moves main model to CPU after inference + offload_tokenizer: If True, moves tokenizer to CPU after use + offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding + offload_guardrail_models: If True, moves safety models to CPU after checks + disable_guardrail: If True, disable guardrail + disable_prompt_encoder: If True, disable prompt encoder + """ + self.inference_type = inference_type + self.checkpoint_dir = checkpoint_dir + self.checkpoint_name = checkpoint_name + self.has_text_input = has_text_input + + # Add offloading flags + self.offload_network = offload_network + self.offload_tokenizer = offload_tokenizer + self.offload_text_encoder_model = offload_text_encoder_model + self.offload_guardrail_models = offload_guardrail_models + + self.disable_guardrail = disable_guardrail + self.disable_prompt_encoder = disable_prompt_encoder + + # Initialize model instances + self.text_guardrail = None + self.video_guardrail = None + self.text_encoder = None + self.model = None + + self._load_model() + + if not self.offload_text_encoder_model or self.disable_prompt_encoder: + self._load_text_encoder_model() + if not self.disable_guardrail and not self.offload_guardrail_models: + if self.has_text_input: + self._load_text_guardrail() + self._load_video_guardrail() + if not self.offload_network: + self._load_network() + if not self.offload_tokenizer: + self._load_tokenizer() + + def _load_tokenizer(self): + pass + + def _load_network(self): + pass + + def _load_model(self, checkpoint_name: str) -> Any: + """Load the world generation model from a checkpoint. + + This abstract method must be implemented by subclasses to load their specific + model architecture and weights. + + Args: + checkpoint_name: Path to the model checkpoint file + + Returns: + The loaded model instance + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + pass + + def _load_text_encoder_model(self): + """Load the T5 text encoder model. + + Initializes and loads the T5 encoder model used for converting text prompts + into embeddings that condition the world generation model. 
+ + Returns: + Loaded T5 text encoder model instance + """ + if self.disable_prompt_encoder: + self.text_encoder = DummyT5TextEncoder(device="cuda") + else: + self.text_encoder = CosmosT5TextEncoder(cache_dir=os.path.join(self.checkpoint_dir, "google-t5/t5-11b")) + + def _load_text_guardrail(self): + """Load text safety classifier models. + + Initializes models used for checking input prompts against safety policies. + Models are loaded from the specified guardrail directory. + """ + self.text_guardrail = guardrail_presets.create_text_guardrail_runner(checkpoint_dir=self.checkpoint_dir) + + def _load_video_guardrail(self): + """Load video safety classifier models. + + Initializes models used for validating generated video content against + safety policies. Models are loaded from the specified guardrail directory. + """ + self.video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=self.checkpoint_dir) + + def _offload_network(self): + if self.model.model: + del self.model.model + self.model.model = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_tokenizer(self): + if self.model.tokenizer: + del self.model.tokenizer + self.model.tokenizer = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_guardrail_models(self): + """Offload safety classifier models to reduce memory usage. + + Moves safety models to CPU and clears GPU memory if they are no longer needed. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_guardrail: + del self.text_guardrail + self.text_guardrail = None + if self.video_guardrail: + del self.video_guardrail + self.video_guardrail = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_text_encoder_model(self): + """Offload T5 text encoder to reduce memory usage. + + Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_encoder: + del self.text_encoder + self.text_encoder = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world latents using the model. + + This abstract method must be implemented by subclasses to define their specific + generation process. + + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + + Returns: + torch.Tensor: Generated world representation tensor + """ + pass + + def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world representation with memory management. + + Handles loading the model before inference and offloading afterward if enabled. + This helps minimize GPU memory usage during inference. + + Args: + *args: Arguments passed to _run_model + **kwargs: Keyword arguments passed to _run_model + + Returns: + np.ndarray: Generated world representation as numpy array + """ + pass + + def _run_guardrail_on_prompt(self, prompt: str) -> bool: + """Check if prompt meets safety requirements. + + Validates the input prompt against safety policies using loaded guardrail models. + + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail) + + def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool: + """Check prompt safety with memory management. 
+ + Validates prompt safety while handling model loading/offloading to manage memory. + + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + if self.offload_guardrail_models: + self._load_text_guardrail() + + is_safe = self._run_guardrail_on_prompt(prompt) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + + return is_safe + + def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None: + """Check if video meets safety requirements. + + Validates generated video content against safety policies using guardrail models. + + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video if safe, None if unsafe + """ + return guardrail_presets.run_video_guardrail(video, self.video_guardrail) + + def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None: + """Check if generated video meets safety requirements. + + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video frames if safe, None otherwise + + Note: + Guardrail models are offloaded after checks if enabled. + """ + if self.offload_guardrail_models: + self._load_video_guardrail() + + video = self._run_guardrail_on_video(video) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + return video + + def _run_text_embedding_on_prompt( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompts to embeddings. + + Processes text prompts into embedding tensors that condition the generation model. + + Args: + prompts: List of text prompts to encode + **kwargs: Additional arguments for text encoding + + Returns: + tuple containing: + - List of text embedding tensors for each prompt + - List of attention masks for each embedding + """ + + embeddings = [] + masks = [] + for prompt in prompts: + embedding, mask = self.text_encoder.encode_prompts( + [prompt], + **kwargs, + ) + embeddings.append(embedding) + masks.append(mask) + + return embeddings, masks + + def _run_text_embedding_on_prompt_with_offload( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompt into embeddings using T5 encoder. + + Args: + prompt: Processed and validated text prompt + + Returns: + Text embedding tensor to condition diffusion model + + Note: + T5 model is offloaded after encoding if enabled. + """ + if self.offload_text_encoder_model: + self._load_text_encoder_model() + + embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs) + + if self.offload_text_encoder_model: + self._offload_text_encoder_model() + return embeddings, masks + + def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray: + """Decode model outputs into final world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. + + Args: + samples: Raw output tensor from the generation model + + Returns: + np.ndarray: Decoded world representation + """ + pass + + def generate(self, *args: Any, **kwargs: Any): + """Generate world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. 
+ + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + """ + pass diff --git a/cosmos_predict1/utils/callback.py b/cosmos_predict1/utils/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f4b32beed9e4fb9a279a7d5c28fa278bfa4478 --- /dev/null +++ b/cosmos_predict1/utils/callback.py @@ -0,0 +1,403 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +import warnings +from typing import TYPE_CHECKING, Any, Callable, Optional + +import omegaconf +import torch +import torch.utils.data +import tqdm + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.misc import get_local_tensor_if_DTensor + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import Config + from cosmos_predict1.utils.model import Model + from cosmos_predict1.utils.trainer import Trainer + + +class CallBackGroup: + """A class for hosting a collection of callback objects. + + It is used to execute callback functions of multiple callback objects with the same method name. + When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs + self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match. + + Attributes: + _callbacks (list[Callback]): List of callback objects. + """ + + def __init__(self, config: Config, trainer: Trainer) -> None: + """Initializes the list of callback objects. + + Args: + config (Config): The config object for the codebase. + trainer (Trainer): The main trainer. + """ + self._callbacks = [] + callback_configs = config.trainer.callbacks + if callback_configs: + if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig): + warnings.warn( + "The 'config.trainer.callbacks' parameter should be a dict instead of a list. " + "Please update your code", + DeprecationWarning, + stacklevel=2, + ) + callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)} + for callback_name, current_callback_cfg in callback_configs.items(): + if "_target_" not in current_callback_cfg: + log.critical( + f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}" + ) + continue + log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}") + _callback = instantiate(current_callback_cfg) + assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback." + _callback.config = config + _callback.trainer = trainer + self._callbacks.append(_callback) + + def __getattr__(self, method_name: str) -> Callable: + """Loops through the callback objects to call the corresponding callback function. 
+ + Args: + method_name (str): Callback method name. + """ + + def multi_callback_wrapper(*args, **kwargs) -> None: + for callback in self._callbacks: + assert hasattr(callback, method_name) + method = getattr(callback, method_name) + assert callable(method) + _ = method(*args, **kwargs) + + return multi_callback_wrapper + + +class Callback: + """The base class for all callbacks. + + All callbacks should inherit from this class and adhere to the established method names and signatures. + """ + + def __init__(self, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + """Initializes a Callback object. + + Args: + config (Optional[Config]): The configuration object for the codebase, if available. + trainer (Optional[Trainer]): The main trainer handling the training loop, if available. + + Notes: + The config and trainer parameters are optional to maintain backward compatibility. + In future releases, these parameters will be removed. Upon using these parameters, a deprecation + warning will be issued. + + """ + if config is not None or trainer is not None: + warnings.warn( + "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. " + "Please update your code to create Callback instances without these parameters.", + DeprecationWarning, + stacklevel=2, + ) + del config, trainer + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_before_forward(self, iteration: int = 0) -> None: + pass + + def on_after_forward(self, iteration: int = 0) -> None: + pass + + def on_before_backward( + self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 + ) -> None: + pass + + def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: + pass + + def on_before_dataloading(self, iteration: int = 0) -> None: + pass + + def on_after_dataloading(self, iteration: int = 0) -> None: + pass + + def on_optimizer_init_start(self) -> None: + pass + + def on_optimizer_init_end(self) -> None: + pass + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + pass + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + pass + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + pass + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_load_checkpoint_start(self, model: Model) -> None: + pass + + def on_load_checkpoint_end(self, model: Model) -> None: + pass + + def 
on_load_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_save_checkpoint_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_success(self, iteration: int = 0) -> None: + pass + + def on_save_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_train_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_app_end(self) -> None: + pass + + +class EMAModelCallback(Callback): + """The callback class for tracking EMA model weights.""" + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # Set up the EMA model weight tracker. + if model.config.ema.enabled: + assert hasattr(model, "ema"), "EMA should be initialized from Model" + # EMA model must be kept in FP32 precision. + model.ema = model.ema.to(dtype=torch.float32) + else: + assert not hasattr(model, "ema"), "There should be no EMA initialized." + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # Update the EMA model with the new regular weights. + if model.config.ema.enabled: + model.ema.update_average(model, iteration) + + +class ProgressBarCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + + @distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.train_pbar.update() + + @distributed.rank0_only + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + if self.config.trainer.max_val_iter is not None: + num_iter = self.config.trainer.max_val_iter + else: + num_iter = len(dataloader_val) + assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}" + self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False) + + @distributed.rank0_only + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.val_pbar.update() + + @distributed.rank0_only + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + self.val_pbar.close() + + @distributed.rank0_only + def on_train_end(self, model: Model, iteration: int = 0) -> None: + self.trainer.checkpointer.finalize() + self.train_pbar.close() + + +class IterationLoggerCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + self.start_iteration_time = time.time() + self.elapsed_iteration_time = 0 + + @distributed.rank0_only + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + self.start_iteration_time = time.time() + + 
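+    # Note: the step-end hook below accumulates per-iteration wall-clock time
+    # (measured from on_training_step_start) and, every `trainer.logging_iter`
+    # iterations, logs the average iteration time and the current loss before
+    # resetting the accumulator.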
@distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.elapsed_iteration_time += time.time() - self.start_iteration_time + + if iteration % self.config.trainer.logging_iter == 0: + avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter + log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}") + + self.elapsed_iteration_time = 0 + + +class GradClipCallback(Callback): + """The callback class for gradient clipping.""" + + def __init__( + self, + config: Optional["Config"] = None, + trainer: Optional["Trainer"] = None, + grad_clip_norm: float = 1.0, + ): + super().__init__(config, trainer) + self.grad_clip_norm = grad_clip_norm + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) + + +class LowPrecisionCallback(Callback): + """The callback class handling low precision training""" + + def __init__(self, update_iter: int, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + super().__init__(config, trainer) + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + if iteration % self.update_iter == 0: + if getattr(optimizer, "master_weights", False): + params, master_params = [], [] + for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): + for p, p_master in zip(group["params"], group_master["params"]): + params.append(get_local_tensor_if_DTensor(p.data)) + master_params.append(p_master.data) + torch._foreach_copy_(params, master_params) diff --git a/cosmos_predict1/utils/callbacks/grad_clip.py b/cosmos_predict1/utils/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4f320b6f79e1e289117d8190b5f6df52cf64ae --- /dev/null +++ b/cosmos_predict1/utils/callbacks/grad_clip.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callback import Callback + + +@torch.jit.script +def _fused_nan_to_num(params: List[torch.Tensor]): + for param in params: + torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) + + +class GradClip(Callback): + def __init__( + self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False + ): + self.clip_norm = clip_norm + self.force_finite = force_finite + self.model_key = model_key + self.fsdp_enabled = fsdp_enabled + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + + # select sub-network if specified + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + + if self.force_finite: + params = [] + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + # check if FSDP is used + # total_norm + if isinstance(model, FSDP) and self.fsdp_enabled: + model.clip_grad_norm_(self.clip_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) diff --git a/cosmos_predict1/utils/checkpointer.py b/cosmos_predict1/utils/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..91142e0bc9ded0c3d6e6d834c824db6e5551738e --- /dev/null +++ b/cosmos_predict1/utils/checkpointer.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import threading +from typing import TYPE_CHECKING + +import torch + +from cosmos_predict1.utils import callback, distributed, log, misc +from cosmos_predict1.utils.model import Model + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import CheckpointConfig, JobConfig + + +class Checkpointer: + """The checkpointer class. 
Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = f"iter_{iteration:09}.pt" + + if distributed.get_rank() == 0: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+ """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + only_resume_scheduler = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + only_resume_scheduler = self.only_load_scheduler_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + only_resume_scheduler = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. 
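+            # Checkpoints written by save() wrap the weights under the "model" key
+            # (next to "optimizer", "scheduler", "grad_scaler", "iteration"); a bare
+            # weights-only state_dict is also accepted by the fallback branch below.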
+ log.info("- Loading the model...") + if "model" in state_dict: + model.load_state_dict(state_dict["model"], strict=self.strict_resume) + else: + model.load_state_dict(state_dict, strict=self.strict_resume) + if resume or only_resume_scheduler: + iteration = state_dict["iteration"] + assert scheduler + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + else: + iteration = 0 + if resume: + assert optimizer + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() diff --git a/cosmos_predict1/utils/config.py b/cosmos_predict1/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08db8872dd9c3c9f821ae68ff1d1544fe6cc90 --- /dev/null +++ b/cosmos_predict1/utils/config.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
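As a side note on the loading order documented in Checkpointer.load() above (latest_checkpoint.txt first, then config_checkpoint.load_path, otherwise training from scratch), a minimal standalone sketch of that resolution logic might look as follows; the helper name resolve_resume_checkpoint is hypothetical and not part of the codebase:

import os


def resolve_resume_checkpoint(checkpoint_dir: str, load_path: str = "") -> str | None:
    """Illustrative mirror of the resume priority used by Checkpointer.load().

    1. Prefer the file recorded in <checkpoint_dir>/latest_checkpoint.txt.
    2. Otherwise fall back to an explicitly configured load_path.
    3. Otherwise return None, i.e. train from scratch.
    """
    latest_path = os.path.join(checkpoint_dir, "latest_checkpoint.txt")
    if os.path.isfile(latest_path):
        with open(latest_path) as f:
            return os.path.join(checkpoint_dir, f.read().strip())
    return load_path or None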
+ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional, Type, TypeVar, Union + +import attrs +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.utils import callback +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.misc import Color + +T = TypeVar("T") + + +def _is_attrs_instance(obj: object) -> bool: + """ + Helper function to check if an object is an instance of an attrs-defined class. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs-defined class, False otherwise. + """ + return hasattr(obj, "__attrs_attrs__") + + +def make_freezable(cls: T) -> T: + """ + A decorator that adds the capability to freeze instances of an attrs-defined class. + + NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need + to hack on a "_is_frozen" attribute. + + This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. + Once an instance is frozen, its attributes cannot be changed. It also recursively freezes + any attrs-defined objects that are attributes of the class. + + Usage: + @make_freezable + @attrs.define(slots=False) + class MyClass: + attribute1: int + attribute2: str + + obj = MyClass(1, 'a') + obj.freeze() # Freeze the instance + obj.attribute1 = 2 # Raises AttributeError + + Args: + cls: The class to be decorated. + + Returns: + The decorated class with added freezing capability. + """ + + if not hasattr(cls, "__dict__"): + raise TypeError( + "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " + "class was defined with `@attrs.define(slots=False)`" + ) + + original_setattr = cls.__setattr__ + + def setattr_override(self, key, value) -> None: # noqa: ANN001 + """ + Override __setattr__ to allow modifications during initialization + and prevent modifications once the instance is frozen. + """ + if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": + raise AttributeError("Cannot modify frozen instance") + original_setattr(self, key, value) # type: ignore + + cls.__setattr__ = setattr_override # type: ignore + + def freeze(self: object) -> None: + """ + Freeze the instance and all its attrs-defined attributes. + """ + for _, value in attrs.asdict(self, recurse=False).items(): + if _is_attrs_instance(value) and hasattr(value, "freeze"): + value.freeze() + self._is_frozen = True # type: ignore + + cls.freeze = freeze # type: ignore + + return cls + + +def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: + """ + Recursively pretty prints attrs objects with color. 
+ """ + + assert attrs.has(obj.__class__) + + lines: list[str] = [] + for attribute in attrs.fields(obj.__class__): + value = getattr(obj, attribute.name) + if attrs.has(value.__class__): + if use_color: + lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":") + else: + lines.append(" " * indent + "* " + attribute.name + ":") + lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color)) + else: + if use_color: + lines.append( + " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value) + ) + else: + lines.append(" " * indent + "* " + attribute.name + ": " + str(value)) + return "\n".join(lines) + + +def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str: + """ + Pretty prints overrides. + """ + + lines: list[str] = [] + lines.append(Color.cyan("* ") + Color.green("overrides") + ": ") + for override in overrides: + if override == "--": + continue + attribute_name, attribute_value = override.split("=") + if use_color: + lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value)) + else: + lines.append(" " + "* " + attribute_name + ": " + str(attribute_value)) + + return "\n".join(lines) + + +@make_freezable +@attrs.define(slots=False) +class JobConfig: + # Project name. + project: str = "" + # Experiment name. + group: str = "" + # Run/job name. + name: str = "" + + @property + def path(self) -> str: + return f"{self.project}/{self.group}/{self.name}" + + @property + def path_local(self) -> str: + local_root = os.environ.get("OUTPUT_ROOT", "checkpoints") + return f"{local_root}/{self.path}" + + +@make_freezable +@attrs.define(slots=False) +class EMAConfig: + # Enable tracking a set of exponential moving average (EMA) weights. + enabled: bool = False + # EMA decay rate. + beta: float = 0.9999 + # Enable removing "_orig_mod-" from buffer names that is added by torch.compile + torch_compile_buffer_renaming: bool = False + + +@make_freezable +@attrs.define(slots=False) +class DDPConfig: + # Traverse the computation graph to find parameters that don't receive gradients. + find_unused_parameters: bool = False + # Set to True if the computation graph does not change during the whole training loop. + static_graph: bool = True + # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. + broadcast_buffers: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CuDNNConfig: + # Set to True for better reproducibility of the results (only using deterministic cudnn functions). + deterministic: bool = False + # If set to True, cudnn will benchmark several algorithms and pick the fastest one. + benchmark: bool = True + + +@make_freezable +@attrs.define(slots=False) +class JITConfig: + # Enable exporting a JIT compiled model. + enabled: bool = False + # Input tensor shape, for example input. + input_shape: Union[list[int], None] = None + # Device to compile onto. + device: str = "cuda" + # # Data type to compile onto. + dtype: str = "bfloat16" + # Strict mode for PyTorch JIT. + strict: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig: + # possible checkpoint class + type: Optional[Dict] = None + # for dcp, whether to use async mode + dcp_async_mode_enabled: bool = False + # Save the checkpoint every N iterations. + save_iter: int = 999999999 + # Path of model weights to resume the checkpoint from. 
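+    # (Illustrative example: a local .pt file written by the checkpointer above,
+    #  e.g. "<OUTPUT_ROOT>/<project>/<group>/<name>/checkpoints/iter_000001000.pt".)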
+ load_path: str = "" + # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path. + load_training_state: bool = False + # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored. + only_load_scheduler_state: bool = False + # Load state_dict to the models in strict mode. + strict_resume: bool = True + # Print detailed information during checkpoint saving/loading. + verbose: bool = True + # Configs for JIT compiling EMA model. + jit: JITConfig = attrs.field(factory=JITConfig) + # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"] + keys_not_to_resume: list[str] = [] + # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer). + broadcast_via_filesystem: bool = False + load_ema_to_reg: bool = False + async_saving: bool = True + + +@make_freezable +@attrs.define(slots=False) +class TrainerConfig: + from cosmos_predict1.utils.trainer import Trainer + + type: Type[Trainer] = Trainer + # Set the callback class. + # Defaults to the callbacks below. + callbacks: LazyDict = LazyDict( + dict( + ema=L(callback.EMAModelCallback)(), + progress_bar=L(callback.ProgressBarCallback)(), + ) + ) + # distributed parallelism strategy + distributed_parallelism: str = "ddp" + # Distributed data parallel configs. + ddp: DDPConfig = attrs.field(factory=DDPConfig) + # cuDNN configs. + cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig) + # Set the random seed. + seed: int = 0 + # Gradient scaler arguments (for torch.amp.GradScaler). + grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False)) + # Maximum number of iterations to train the model. + max_iter: int = 999999999 + # Maximum number of iterations to validate the model. If None, validate on the entire dataset. + max_val_iter: int | None = None + # How often we log the training stats. + logging_iter: int = 100 + # Whether we want to run the validation routines. + run_validation: bool = True + # How often we evaluate on the validation set. + validation_iter: int = 999999999 + # Kill the process after N seconds since the last iteration (usually means dead job). + timeout_period: int = 999999999 + # Tensor memory organization format. + memory_format: torch.memory_format = torch.preserve_format + # Gradient accumulation (update step every N iteration). + grad_accum_iter: int = 1 + # # Profiling config + # profiling: Profiling = attrs.field(factory=Profiling) + + +@make_freezable +@attrs.define(slots=False) +class Config: + """Config for a job. + + See /README.md/Configuration System for more info. + """ + + # Model configs. + model: LazyDict + # Optimizer configs. + optimizer: LazyDict = LazyDict(dict(dummy=None)) + # Scheduler configs. + scheduler: LazyDict = LazyDict(dict(dummy=None)) + # Training data configs. + dataloader_train: LazyDict = LazyDict(dict(dummy=None)) + # Validation data configs. + dataloader_val: LazyDict = LazyDict(dict(dummy=None)) + + # Training job configs. + job: JobConfig = attrs.field(factory=JobConfig) + + # Trainer configs. + trainer: TrainerConfig = attrs.field(factory=TrainerConfig) + + # Megatron-Core configs + model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig) + + # Checkpointer configs. 
+ checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) + + def pretty_print(self, use_color: bool = False) -> str: + return _pretty_print_attrs_instance(self, 0, use_color) + + # Training job configs. + job: JobConfig = attrs.field(factory=JobConfig) + + def to_dict(self) -> dict[str, Any]: + return attrs.asdict(self) + + def validate(self) -> None: + """Validate that the config has all required fields.""" + assert self.job.project != "", "Project name is required." + assert self.job.group != "", "Group name is required." + assert self.job.name != "", "Job name is required." diff --git a/cosmos_predict1/utils/config_helper.py b/cosmos_predict1/utils/config_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1d21934aa68fdb6bcae35777c38a8dff644d2 --- /dev/null +++ b/cosmos_predict1/utils/config_helper.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +import pkgutil +import sys +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass +from typing import Any, Dict, Optional + +import attr +import attrs +from hydra import compose, initialize +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig, OmegaConf + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import Config + + +def is_attrs_or_dataclass(obj) -> bool: + """ + Check if the object is an instance of an attrs class or a dataclass. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. + """ + return is_dataclass(obj) or attr.has(type(obj)) + + +def get_fields(obj): + """ + Get the fields of an attrs class or a dataclass. + + Args: + obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. + + Returns: + list: A list of field names. + + Raises: + ValueError: If the object is neither an attrs class nor a dataclass. + """ + if is_dataclass(obj): + return [field.name for field in dataclass_fields(obj)] + elif attr.has(type(obj)): + return [field.name for field in attr.fields(type(obj))] + else: + raise ValueError("The object is neither an attrs class nor a dataclass.") + + +def override(config: Config, overrides: Optional[list[str]] = None) -> Config: + """ + :param config: the instance of class `Config` (usually from `make_config`) + :param overrides: list of overrides for config + :return: the composed instance of class `Config` + """ + # Store the class of the config for reconstruction after overriding. + # config_class = type(config) + + # Convert Config object to a DictConfig object + config_dict = attrs.asdict(config) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + # Enforce "--" separator between the script arguments and overriding configs. 
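+    # Example (hypothetical values): overrides = ["--", "job.name=demo", "trainer.max_iter=1000"],
+    # i.e. a leading "--" token followed by Hydra-style key=value overrides.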
+ if overrides: + if overrides[0] != "--": + raise ValueError('Hydra config overrides must be separated with a "--" token.') + overrides = overrides[1:] + # Use Hydra to handle overrides + cs = ConfigStore.instance() + cs.store(name="config", node=config_omegaconf) + with initialize(version_base=None): + config_omegaconf = compose(config_name="config", overrides=overrides) + OmegaConf.resolve(config_omegaconf) + + def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: + """ + Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data + + Args: + ref_instance: The reference instance to determine the type and fields when needed + kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data + + Returns: + Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data + + Raises: + AssertionError: If the fields do not match or if extra keys are found. + Exception: If there is an error constructing the new instance. + """ + is_type = is_attrs_or_dataclass(ref_instance) + if not is_type: + return kwargs + else: + ref_fields = set(get_fields(ref_instance)) + assert isinstance(kwargs, dict) or isinstance( + kwargs, DictConfig + ), "kwargs must be a dictionary or a DictConfig" + keys = set(kwargs.keys()) + + # ref_fields must equal to or include all keys + extra_keys = keys - ref_fields + assert ref_fields == keys or keys.issubset( + ref_fields + ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" + + resolved_kwargs: Dict[str, Any] = {} + for f in keys: + resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) + try: + new_instance = type(ref_instance)(**resolved_kwargs) + except Exception as e: + log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") + log.error(e) + raise e + return new_instance + + config = config_from_dict(config, config_omegaconf) + + return config + + +def get_config_module(config_file: str) -> str: + if not config_file.endswith(".py"): + log.error("Config file cannot be specified as module.") + log.error("Please provide the path to the Python config file (relative to the Cosmos root).") + + cosmos_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + assert os.path.isfile(config_file) or os.path.isfile(os.path.join(cosmos_root, config_file)), \ + f"Cosmos config file ({config_file}) not found." + + # Convert to importable module format. + config_module = config_file.replace("/", ".").replace(".py", "") + return config_module + + +def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: + """ + Import all modules from the specified package path recursively. + + This function is typically used in conjunction with Hydra to ensure that all modules + within a specified package are imported, which is necessary for registering configurations. + + Example usage: + ```python + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True, skip_underscore=False) + ``` + + Args: + package_path (str): The dotted path to the package from which to import all modules. + reload (bool): Flag to determine whether to reload modules if they're already imported. + skip_underscore (bool): If True, skips importing modules that start with an underscore. 
+ """ + log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") + package = importlib.import_module(package_path) + package_directory = package.__path__ + + def import_modules_recursively(directory: str, prefix: str) -> None: + """ + Recursively imports or reloads all modules in the given directory. + + Args: + directory (str): The file system path to the current package directory. + prefix (str): The module prefix (e.g., 'models.diffusion.config'). + """ + for _, module_name, is_pkg in pkgutil.iter_modules([directory]): + if skip_underscore and module_name.startswith("_"): + log.debug(f"Skipping module {module_name} as it starts with an underscore") + continue + + full_module_name = f"{prefix}.{module_name}" + log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") + + if full_module_name in sys.modules and reload: + importlib.reload(sys.modules[full_module_name]) + else: + importlib.import_module(full_module_name) + + if is_pkg: + sub_package_directory = os.path.join(directory, module_name) + import_modules_recursively(sub_package_directory, full_module_name) + + for directory in package_directory: + import_modules_recursively(directory, package_path) diff --git a/cosmos_predict1/utils/ddp_checkpointer.py b/cosmos_predict1/utils/ddp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..593a5c932c7f274973bbc390cf2386c1fa59df2f --- /dev/null +++ b/cosmos_predict1/utils/ddp_checkpointer.py @@ -0,0 +1,436 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from collections import namedtuple +from typing import Any, Dict, Optional, Set, Tuple, Union + +import torch +import torch.distributed +from megatron.core import parallel_state +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.base import AbstractCheckpointer +from cosmos_predict1.utils.checkpointer.safe_broadcast import broadcast_object +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + +StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) + + +class Checkpointer(AbstractCheckpointer): + """ + Checkpointer for DDP. + Note: This implementation only supports local filesystem. + """ + + KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] + KEYS_TO_POSTFIX = { + "model": "model", + "optim": "optim", + "scheduler": "scheduler", + "trainer": "", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + ep_world_size = parallel_state.get_expert_model_parallel_world_size() + assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." 
+ assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." + self.mp_world_size = parallel_state.get_model_parallel_group().size() + if self.mp_world_size > 1 and self.__class__ == Checkpointer: + raise NotImplementedError( + "Model Parallelism (MP) is enabled - " + "you should use TensorParallel Checkpointer instead of DDP Checkpointer." + ) + # DDP rank (with context parallelism considered) + self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) + # Context parallelism rank + self.cp_rank = parallel_state.get_context_parallel_rank() + # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) + self.mp_rank = parallel_state.get_model_parallel_group().rank() + # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() + if self.broadcast_via_filesystem: + log.info("Broadcasting checkpoint data via the local filesystem.") + if not self.strict_resume: + log.warning("Strict resume mode is off. Some model parameters may not be loaded.") + + # collect ranks of all model parallel groups + all_ranks = [None for _ in range(distributed.get_world_size())] + torch.distributed.all_gather_object( + all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) + ) + all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) + for ranks in all_ranks: + group = torch.distributed.new_group(list(ranks), backend="gloo") + if distributed.get_rank() in ranks: + self.mp_gloo_pg = group + + self.print("Checkpointer Initialized.") + + def print(self, message: str): + """ + Print message to the console. Include the parallelism rank information when verbose is set to True. + """ + if self.verbose: + log.info( + f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", + rank0_only=False, + ) + else: + log.info(message, rank0_only=True) + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + del model + assert key in self.KEYS_TO_SAVE + post_fix = self.KEYS_TO_POSTFIX[key] + + if post_fix: + _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") + else: + _ckpt_path = checkpoint_path + return _ckpt_path + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = self.format_checkpoint_filename(model, iteration) + state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) + state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) + if state_dict: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. 
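+            # The join() above keeps at most one background save in flight, and
+            # daemon=False means the interpreter will wait for this writer thread
+            # rather than exit while a checkpoint is still being written.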
+ self.save_thread = threading.Thread( + target=self._save_worker, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: + new_dict = {} + for key, _state_dict in state_dict.items(): + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) + checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) + new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) + return new_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). + + Args: + state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=True, # optional for fast backend, cpu heavy + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) + + def format_checkpoint_filename(self, model: Model, iteration: int) -> str: + """Generate the checkpoint file name. + + Args: + iteration (int): The current iteration number. + + Returns: + checkpoint_file (str): The checkpoint file name. 
+ """ + del self, model + return f"iter_{iteration:09}.pt" + + @misc.timer("generate saving state dict") + def generate_save_state_dict( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> Optional[Dict[str, Any]]: + state_dict = {} + + if self.rank_dp_w_cp == 0: + trainer_state = dict( + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + model_state = model.state_dict() + optim_state = optimizer.state_dict() + scheduler_state = scheduler.state_dict() + self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) + + trainer_state, model_state, optim_state, scheduler_state = misc.to( + [trainer_state, model_state, optim_state, scheduler_state], device="cpu" + ) + + state_dict = { + "model": model_state, + "optim": optim_state, + "scheduler": scheduler_state, + } + if distributed.get_rank() == 0: # only rank 0 saves trainer state + state_dict["trainer"] = trainer_state + return state_dict + return state_dict + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast. + + The main steps are: + 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. + + This approach ensures that each MP rank loads its specific part of the model, which is + crucial for Model Parallelism where different parts of the model are distributed across + multiple GPUs. + + When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can + be set to True. This allows each rank to load its specific checkpoint from the local filesystem + instead of receiving it via network broadcast, which could be more efficient in some cases. + + For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. 
Loading...") + _state_dict = easy_io.load(local_cache_path, fast_backend=True) + else: + _state_dict = easy_io.load(_ckpt_path, fast_backend=True) + self.print(f"Downloading checkpoint from: {_ckpt_path}") + if self.broadcast_via_filesystem: + # Save the checkpoint to the local filesystem + easy_io.dump(_state_dict, local_cache_path, fast_backend=True) + state_dict[key] = _state_dict + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) + else: + # Broadcast the checkpoint to all GPUs of the current DDP rank + group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) + min_rank = min(get_process_group_ranks(group)) + + _state_dict = broadcast_object( + state_dict[key] if self.rank_dp_w_cp == 0 else None, + min_rank, + group=group, + device=torch.device(torch.cuda.current_device()), + ) + if self.rank_dp_w_cp == 0: + self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') + else: + state_dict[key] = _state_dict + self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') + + return state_dict + + def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: + latest_checkpoint_file = self._read_latest_checkpoint_file() + + resume_keys = [] + + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) + resume_keys.extend(self.KEYS_TO_SAVE) + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + if self.load_training_state: + resume_keys.extend(self.KEYS_TO_SAVE) + else: + resume_keys.append("model") + if self.only_load_scheduler_state: + resume_keys.append("scheduler") + else: + checkpoint_path = None + if len(self.keys_not_to_resume) > 0: + for key in self.keys_not_to_resume: + assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" + resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] + return set(resume_keys), checkpoint_path + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. 
If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + resume_keys, checkpoint_path = self.keys_to_resume_during_load() + + iteration = 0 + + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) + + if "trainer" in state_dict: + trainer_state = state_dict["trainer"] + log.critical(state_dict.keys(), rank0_only=False) + log.critical(trainer_state, rank0_only=False) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(trainer_state["grad_scaler"]) + self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) + iteration = trainer_state["iteration"] + if "optim" in state_dict: + assert optimizer + optimizer_state = state_dict["optim"] + log.info("- Loading the optimizer...") + optimizer.load_state_dict(optimizer_state) + if "scheduler" in state_dict: + assert scheduler + scheduler_state = state_dict["scheduler"] + log.info("- Loading the scheduler...") + scheduler.load_state_dict(scheduler_state) + scheduler.last_epoch = iteration + if "model" in state_dict: + model_state = state_dict["model"] + log.info("- Loading the model...") + # model.load_state_dict(model_state) + if self.strict_resume: + log.info("\t Strict resume mode is on.") + else: + log.info("\t Strict resume mode is off.") + model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) + log.info(f"\t {model_load_info}") + self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: + """Write json file to save number of seen samples and number of iterations. + + Args: + checkpoint_file (str): iteration number for the saved checkpoint + trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. + """ + # filename: iter_xxxxxxxxx_trained_data_record.json + checkpoint_path = os.path.join( + self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" + ) + easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos_predict1/utils/device.py b/cosmos_predict1/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..db486afabd4ae0bf11feb05d8a4efd96690ce64b --- /dev/null +++ b/cosmos_predict1/utils/device.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os + +import pynvml + + +class Device: + """A class to handle NVIDIA GPU device operations using NVML. + + This class provides an interface to access and manage NVIDIA GPU devices, + including retrieving device information and CPU affinity settings. + + Attributes: + _nvml_affinity_elements (int): Number of 64-bit elements needed to represent CPU affinity + """ + + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore + + def __init__(self, device_idx: int): + """Initialize a Device instance for a specific GPU. + + Args: + device_idx (int): Index of the GPU device to manage + + Raises: + NVMLError: If the device cannot be found or initialized + """ + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def get_cpu_affinity(self) -> list[int]: + """Get the CPU affinity mask for this GPU device. + + Retrieves the CPU affinity mask indicating which CPU cores are assigned + to this GPU device. The affinity is returned as a list of CPU core indices. + + Returns: + list[int]: List of CPU core indices that have affinity with this GPU + + Raises: + NVMLError: If the CPU affinity information cannot be retrieved + + Example: + >>> device = Device(0) + >>> device.get_cpu_affinity() + [0, 1, 2, 3] # Shows this GPU has affinity with CPU cores 0-3 + """ + affinity_string = "" + for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): + # assume nvml returns list of 64 bit ints + affinity_string = "{:064b}".format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + return [i for i, e in enumerate(affinity_list) if e != 0] diff --git a/cosmos_predict1/utils/distributed.py b/cosmos_predict1/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..b827c3bacab1e093dbb586bf1dcffe86ae5fa825 --- /dev/null +++ b/cosmos_predict1/utils/distributed.py @@ -0,0 +1,445 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
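Before the distributed helpers below, a minimal usage sketch (not part of the diff) of the Device class defined above: it pins the current process to the CPU cores that NVML reports as local to its GPU, mirroring the sched_setaffinity call that distributed.init() leaves commented out. Linux-only; assumes pynvml is installed and at least one NVIDIA GPU is visible.

import os

import pynvml

from cosmos_predict1.utils.device import Device

pynvml.nvmlInit()
local_rank = int(os.getenv("LOCAL_RANK", 0))
gpu = Device(local_rank)
# Pin this process to the NUMA-local cores reported by NVML for its GPU.
os.sched_setaffinity(0, gpu.get_cpu_affinity())
pynvml.nvmlShutdown()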
+ +from __future__ import annotations + +import collections +import collections.abc +import ctypes +import functools +import os +from contextlib import contextmanager +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Callable, Container, Optional + +import pynvml +import torch +import torch.distributed as dist +from torch.distributed import get_process_group_ranks + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.device import Device + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import DDPConfig + +if dist.is_available(): + from torch.distributed.distributed_c10d import _get_default_group + from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes + + +try: + from megatron.core import parallel_state +except ImportError: + print("Megatron-core is not installed.") + + +def init() -> int | None: + """Initialize distributed training.""" + # Set GPU affinity. + pynvml.nvmlInit() + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = Device(local_rank) + # os.sched_setaffinity(0, device.get_cpu_affinity()) + # Set up NCCL communication. + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + if dist.is_available(): + if dist.is_initialized(): + return torch.cuda.current_device() + torch.cuda.set_device(local_rank) + # Get the timeout value from environment variable + timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) + # Convert the timeout to an integer (if it isn't already) and then to a timedelta + timeout_timedelta = timedelta(seconds=int(timeout_seconds)) + dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) + log.critical( + f"Initialized distributed program with local rank {local_rank} with timeout {timeout_seconds}", + rank0_only=False, + ) + # Increase the L2 fetch granularity for faster speed. + _libcudart = ctypes.CDLL("libcudart.so") + # Set device limit on the current device. + p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) + log.info(f"Running with {get_world_size()} GPUs.") + + +def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: + """Get world size. How many GPUs are available in this job. + + Returns: + world_size (int): The total number of GPUs available in this job. + """ + world_size = 1 + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size(group) + return world_size + + +def is_rank0() -> bool: + """Check if current process is the master GPU. + + Returns: + (bool): True if this function is called from the master GPU, else False. + """ + return get_rank() == 0 + + +def is_local_rank0() -> bool: + """Check if current process is the local master GPU in the current node. + + Returns: + (bool): True if this function is called from the local master GPU, else False. 
+ """ + return torch.cuda.current_device() == 0 + + +def device_with_rank(device: str) -> str: + """If the device is 'cuda' and parallelism over GPUs is enabled, returns + Otherwise, returns the device as-is.""" + if device == 'cuda': + return f'cuda:{get_rank()}' + return device + + +def rank0_only(func: Callable) -> Callable: + """Apply this function only to the master GPU. + + Example usage: + @rank0_only + def func(x): + return x + 3 + + Args: + func (Callable): a function. + + Returns: + (Callable): A function wrapper executing the function only on the master GPU. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + return func(*args, **kwargs) + else: + return None + + return wrapper + + +def barrier() -> None: + """Barrier for all GPUs.""" + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def rank0_first(func: Callable) -> Callable: + """run the function on rank 0 first, then on other ranks.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + result = func(*args, **kwargs) + barrier() + if not is_rank0(): + result = func(*args, **kwargs) + return result + + return wrapper + + +def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: + """Wraps the model to enable data parallalism for training across multiple GPU devices. + + Args: + config_ddp (DDPConfig): The data parallel config. + model (torch.nn.Module): The PyTorch module. + + Returns: + model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper + if distributed environment is available, otherwise return the original model. + """ + if dist.is_available() and dist.is_initialized(): + local_rank = int(os.getenv("LOCAL_RANK", 0)) + try: + ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + except Exception as e: + log.info(e) + log.info("parallel_state not initialized, treating all GPUs equally for DDP") + ddp_group = None + + model = DistributedDataParallel( + model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=config_ddp.find_unused_parameters, + static_graph=config_ddp.static_graph, + broadcast_buffers=config_ddp.broadcast_buffers, + process_group=ddp_group, + ) + return model + + +class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): + """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). + + This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that + model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling + model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> + training_step), allowing us to preserve the function names and signatures. + """ + + def __init__(self, model: torch.nn.Module, *args, **kwargs): + super().__init__(model, *args, **kwargs) + self.show_sync_grad_static_graph_warning = True + + def training_step(self, *args, **kwargs) -> Any: + # Cache the original model.forward() method. + original_forward = self.module.forward + + def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 + # Unpatch immediately before calling training_step() because itself may want to call the real forward. + self.module.forward = original_forward + # The actual .training_step(). 
+ return self.module.training_step(*_args, **_kwargs) + + # Patch the original_module's forward so we can redirect the arguments back to the real method. + self.module.forward = wrapped_training_step + # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). + # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. + return self(*args, **kwargs) + + +@contextmanager +def ddp_sync_grad(model, enabled): + r""" + Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. + Modified from: + https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync + Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. + + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + assert isinstance(model, torch.nn.Module) + if isinstance(model, DistributedDataParallel): + old_require_backward_grad_sync = model.require_backward_grad_sync + if model.static_graph and model.require_backward_grad_sync != enabled: + if model.show_sync_grad_static_graph_warning: + log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") + model.show_sync_grad_static_graph_warning = False + else: + model.require_backward_grad_sync = enabled + try: + yield + finally: + if isinstance(model, DistributedDataParallel): + model.require_backward_grad_sync = old_require_backward_grad_sync + + +def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: + """Aggregate the list of data batches from all devices and process the results. + + This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler. + It will return the data/output of the entire validation set in its original index order. The sizes of data_batches + in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be + created before calling dis.all_gather(). + + Args: + data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where + leaf entries are tensors. + + Returns: + data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where + leaf entries are concatenated tensors. + """ + if isinstance(data_batches[0], torch.Tensor): + # Concatenate the local data batches. + data_concat = torch.cat(data_batches, dim=0) # type: ignore + # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. + max_num_local_samples = torch.tensor(len(data_concat), device="cuda") + dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) + if len(data_concat) < max_num_local_samples: + assert len(data_concat) + 1 == max_num_local_samples + dummy = torch.empty_like(data_concat[:1]) + data_concat = torch.cat([data_concat, dummy], dim=0) + dummy_count = torch.tensor(1, device="cuda") + else: + dummy_count = torch.tensor(0, device="cuda") + # Get all concatenated batches from all ranks and concatenate again. 
+ dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) + data_concat = all_gather_tensor(data_concat.contiguous()) + data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) + # Remove the dummy samples. + if dummy_count > 0: + data_collate = data_collate[:-dummy_count] + elif isinstance(data_batches[0], collections.abc.Mapping): + data_collate = dict() + for key in data_batches[0].keys(): + data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore + else: + raise TypeError + return data_collate + + +@torch.no_grad() +def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: + """Gather the corresponding tensor from all GPU devices to a list. + + Args: + tensor (torch.Tensor): Pytorch tensor. + + Returns: + tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. + """ + tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] + dist.all_gather(tensor_list, tensor) + return tensor_list + + +def broadcast(tensor, src, group=None, async_op=False): + world_size = get_world_size() + if world_size < 2: + return tensor + dist.broadcast(tensor, src=src, group=group, async_op=async_op) + + +def sync_model_states( + model: torch.nn.Module, + process_group: Optional[dist.ProcessGroup] = None, + src: int = 0, + params_and_buffers_to_ignore: Optional[Container[str]] = None, + broadcast_buffers: bool = True, +): + """ + Modify based on DDP source code + Synchronizes the parameters and buffers of a model across different processes in a distributed setting. + + This function ensures that all processes in the specified process group have the same initial parameters and + buffers from the source rank, typically rank 0. It is useful when different processes start with different model + states and a synchronization is required to ensure consistency across all ranks. + + Args: + model (nn.Module): The model whose parameters and buffers are to be synchronized. + process_group (dist.ProcessGroup, optional): The process group for communication. If None, + the default group is used. Defaults to None. + src (int, optional): The source rank from which parameters and buffers will be broadcasted. + Defaults to 0. + params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer + names to exclude from synchronization. Defaults to None, which means all parameters and buffers are + included. + broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True. + + Side Effects: + This function modifies the state of the model in-place to synchronize it with the source rank's model state. + + Raises: + RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised. + + Examples: + >>> # downloading duplicated model weights from s3 in each rank and save network bandwidth + >>> # useful and save our time when model weights are huge + >>> if dist.get_rank == 0: + >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path)) + >>> dist.barrir() + >>> sync_model_states(model) # sync rank0 weights to other ranks + """ + if process_group is None: + process_group = _get_default_group() + if not params_and_buffers_to_ignore: + params_and_buffers_to_ignore = set() + + log.info( + f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}." + ) + + # Build tuple of (module, parameter) for all parameters that require grads. 
+ modules_and_parameters = [ + (module, parameter) + for module_name, module in model.named_modules() + for parameter in [ + param + # Note that we access module.named_parameters instead of + # parameters(module). parameters(module) is only needed in the + # single-process multi device case, where it accesses replicated + # parameters through _former_parameters. + for param_name, param in module.named_parameters(recurse=False) + if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore + # if param.requires_grad + # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore + ] + ] + + # Deduplicate any parameters that might be shared across child modules. + memo = set() + modules_and_parameters = [ + # "p not in memo" is the deduplication check. + # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. + (m, p) + for m, p in modules_and_parameters + if p not in memo and not memo.add(p) # type: ignore[func-returns-value] + ] + + # Build list of parameters. + parameters = [parameter for _, parameter in modules_and_parameters] + if len(parameters) == 0: + return + + _verify_param_shape_across_processes(process_group, parameters) + + _sync_module_states( + module=model, + process_group=process_group, + broadcast_bucket_size=int(250 * 1024 * 1024), + src=src, + params_and_buffers_to_ignore=params_and_buffers_to_ignore, + broadcast_buffers=broadcast_buffers, + ) + + +def dist_reduce_tensor(tensor, rank=0, reduce="mean"): + r"""Reduce to rank 0""" + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.reduce(tensor, dst=rank) + if get_rank() == rank: + if reduce == "mean": + tensor /= world_size + elif reduce == "sum": + pass + else: + raise NotImplementedError + return tensor diff --git a/cosmos_predict1/utils/easy_io/__init__.py b/cosmos_predict1/utils/easy_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
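A minimal sketch (not part of the diff) of the rank-0-loads-then-broadcast pattern described in the sync_model_states() docstring above. It assumes weights_path points at a torch-saved state dict; only rank 0 pays the loading cost, and the other ranks receive the weights over NCCL.

import torch

from cosmos_predict1.utils import distributed
from cosmos_predict1.utils.distributed import sync_model_states


def load_once_and_sync(model: torch.nn.Module, weights_path: str) -> None:
    if distributed.get_rank() == 0:
        # Only rank 0 reads the checkpoint from disk (or a mounted remote path).
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    # Make sure rank 0 has finished loading before broadcasting.
    distributed.barrier()
    # Broadcast rank 0's parameters and buffers to every other rank.
    sync_model_states(model, src=0)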
diff --git a/cosmos_predict1/utils/easy_io/backends/__init__.py b/cosmos_predict1/utils/easy_io/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c85c64735eb12bc6ed1e45b7681684efa0dbace --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/__init__.py @@ -0,0 +1,13 @@ +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_predict1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_predict1.utils.easy_io.backends.local_backend import LocalBackend +from cosmos_predict1.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend + +__all__ = [ + "BaseStorageBackend", + "LocalBackend", + "HTTPBackend", + "register_backend", + "backends", + "prefix_to_backends", +] diff --git a/cosmos_predict1/utils/easy_io/backends/base_backend.py b/cosmos_predict1/utils/easy_io/backends/base_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2db3b921f0b6fdb3aaea867c0bb3cafdb5e59888 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/base_backend.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +from abc import ABCMeta, abstractmethod + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def has_method(obj, method): + return hasattr(obj, method) and callable(getattr(obj, method)) + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: :meth:`get()` and + :meth:`get_text()`. + + - :meth:`get()` reads the file as a byte stream. + - :meth:`get_text()` reads the file as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + # This attribute will be deprecated in future. + _allow_symlink = False + + @property + def allow_symlink(self): + return self._allow_symlink + + @property + def name(self): + return self.__class__.__name__ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass diff --git a/cosmos_predict1/utils/easy_io/backends/http_backend.py b/cosmos_predict1/utils/easy_io/backends/http_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4c14c481c91e8c551ca898237bae39229ecd82 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/http_backend.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator, Union
+from urllib.request import urlopen
+
+from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend
+
+
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+    def get(self, filepath: str) -> bytes:
+        """Read bytes from a given ``filepath``.
+
+        Args:
+            filepath (str): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+
+        Examples:
+            >>> backend = HTTPBackend()
+            >>> backend.get('http://path/of/file')
+            b'hello world'
+        """
+        return urlopen(filepath).read()
+
+    def get_text(self, filepath, encoding="utf-8") -> str:
+        """Read text from a given ``filepath``.
+
+        Args:
+            filepath (str): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Defaults to 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+
+        Examples:
+            >>> backend = HTTPBackend()
+            >>> backend.get_text('http://path/of/file')
+            'hello world'
+        """
+        return urlopen(filepath).read().decode(encoding)
+
+    @contextmanager
+    def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath`` to a local temporary directory,
+        and return the temporary path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with the ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Yields:
+            Iterable[str]: Only yield one temporary path.
+
+        Examples:
+            >>> backend = HTTPBackend()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with backend.get_local_path('http://path/of/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.get(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
diff --git a/cosmos_predict1/utils/easy_io/backends/local_backend.py b/cosmos_predict1/utils/easy_io/backends/local_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d712bb53ddd844e20350c236dd5cfb999b60fc9
--- /dev/null
+++ b/cosmos_predict1/utils/easy_io/backends/local_backend.py
@@ -0,0 +1,550 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
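Before the local backend below, a short usage sketch (not part of the diff) of the HTTPBackend defined above; the URLs are placeholders.

from cosmos_predict1.utils.easy_io.backends import HTTPBackend

backend = HTTPBackend()
raw = backend.get("https://example.com/config.json")        # bytes
text = backend.get_text("https://example.com/config.json")  # str
# get_local_path() downloads to a temporary file and deletes it on exit.
with backend.get_local_path("https://example.com/weights.bin") as path:
    print(path)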
+ +import io +import os +import os.path as osp +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist + + +class LocalBackend(BaseStorageBackend): + """Raw local storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get(filepath) + b'hello world' + """ + with open(filepath, "rb") as f: + value = f.read() + return value + + def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + with open(filepath, encoding=encoding) as f: + text = f.read() + return text + + def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put(b'hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + if isinstance(obj, io.BytesIO): + obj.seek(0) + obj = obj.getvalue() + with open(filepath, "wb") as f: + f.write(obj) + + def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_text('hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, "w", encoding=encoding) as f: + f.write(obj) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.exists(filepath) + True + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
+ + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/dir' + >>> backend.isdir(filepath) + True + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.isfile(filepath) + True + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + + Examples: + >>> backend = LocalBackend() + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> backend.join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Only for unified API and does nothing. + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> backend = LocalBackend() + >>> with backend.get_local_path('abc/def.jpg') as path: + ... # do something here + """ + yield filepath + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile(src, dst) + '/path1/of/dir/file' + """ + return shutil.copy(src, dst) + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. 
+ + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree(src, dst) + '/path/of/dir2' + """ + return shutil.copytree(src, dst) + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a local file src to dst and return the destination file. Same + as :meth:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_from_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_from_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. Same as + :meth:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + dst_type: Optional[str] = None, + ) -> str: + """Copy the file src to local dst and return the destination file. Same + as :meth:`copyfile`. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_to_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. 
+ + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.remove(filepath) + """ + if not self.exists(filepath): + raise FileNotFoundError(f"filepath {filepath} does not exist") + + if self.isdir(filepath): + raise IsADirectoryError("filepath should be a file") + + os.remove(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> dir_path = '/path/of/dir' + >>> backend.rmtree(dir_path) + """ + shutil.rmtree(dir_path) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src + to dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + + Returns: + bool: Return True if successfully create a symbolic link pointing + to src. Otherwise, return False. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> backend.copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> backend.copy_if_symlink_fails(src, dst) + True + """ + try: + os.symlink(src, dst) + return True + except Exception: + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file( + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = LocalBackend() + >>> dir_path = '/path/of/dir' + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... 
print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ # noqa: E501 + if list_dir and suffix is not None: + raise TypeError("`suffix` should be None when `list_dir` is True") + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError("`suffix` must be a string or tuple of strings") + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_predict1/utils/easy_io/backends/registry_utils.py b/cosmos_predict1/utils/easy_io/backends/registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd70c3e548455f362b36cb9803693fa1ab5fbdbe --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/registry_utils.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Optional, Type, Union + +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_predict1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_predict1.utils.easy_io.backends.local_backend import LocalBackend + +backends: dict = {} +prefix_to_backends: dict = {} + + +def _register_backend( + name: str, + backend: Type[BaseStorageBackend], + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (BaseStorageBackend): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. 
+ """ + global backends, prefix_to_backends + + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class, but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + + if name in backends and not force: + raise ValueError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + for prefix in prefixes: + if prefix in prefix_to_backends and not force: + raise ValueError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + prefix_to_backends[prefix] = backend + + +def register_backend( + name: str, + backend: Optional[Type[BaseStorageBackend]] = None, + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + + This method can be used as a normal method or a decorator. + + Examples: + + >>> class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + >>> register_backend('new', NewBackend) + + >>> @register_backend('new') + ... class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + """ + if backend is not None: + _register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + _register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + +register_backend("local", LocalBackend, prefixes="") +register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/cosmos_predict1/utils/easy_io/easy_io.py b/cosmos_predict1/utils/easy_io/easy_io.py new file mode 100644 index 0000000000000000000000000000000000000000..de7189abf9def860d77bbbc778eb76658df41a2a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/easy_io.py @@ -0,0 +1,1066 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import warnings +from contextlib import contextmanager +from io import BytesIO, StringIO +from pathlib import Path +from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends import backends, prefix_to_backends +from cosmos_predict1.utils.easy_io.file_client import FileClient +from cosmos_predict1.utils.easy_io.handlers import file_handlers + +backend_instances: dict = {} + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +def _parse_uri_prefix(uri: Union[str, Path]) -> str: + """Parse the prefix of uri. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> _parse_uri_prefix('/home/path/of/your/file') + '' + >>> _parse_uri_prefix('s3://path/of/your/file') + 's3' + >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') + 's3' + + Returns: + str: Return the prefix of uri if the uri contains '://'. Otherwise, + return ''. + """ + assert is_filepath(uri) + uri = str(uri) + # if uri does not contains '://', the uri will be handled by + # LocalBackend by default + if "://" not in uri: + return "" + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + +def _get_file_backend(prefix: str, backend_args: dict): + """Return a file backend based on the prefix or backend_args. + + Args: + prefix (str): Prefix of uri. + backend_args (dict): Arguments to instantiate the corresponding + backend. + """ + # backend name has a higher priority + if "backend" in backend_args: + # backend_args should not be modified + backend_args_bak = backend_args.copy() + backend_name = backend_args_bak.pop("backend") + backend = backends[backend_name](**backend_args_bak) + else: + backend = prefix_to_backends[prefix](**backend_args) + return backend + + +def get_file_backend( + uri: Union[str, Path, None] = None, + *, + backend_args: Optional[dict] = None, + enable_singleton: bool = False, + backend_key: Optional[str] = None, +): + """Return a file backend based on the prefix of uri or backend_args. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + enable_singleton (bool): Whether to enable the singleton pattern. + If it is True, the backend created will be reused if the + signature is same with the previous one. Defaults to False. + backend_key: str: The key to register the backend. Defaults to None. + + Returns: + BaseStorageBackend: Instantiated Backend object. 
+ + Examples: + >>> # get file backend based on the prefix of uri + >>> uri = 's3://path/of/your/file' + >>> backend = get_file_backend(uri) + >>> # get file backend based on the backend_args + >>> backend = get_file_backend(backend_args={'backend': 's3'}) + >>> # backend name has a higher priority if 'backend' in backend_args + >>> backend = get_file_backend(uri, backend_args={'backend': 's3'}) + """ + global backend_instances + if backend_key is not None: + if backend_key in backend_instances: + return backend_instances[backend_key] + + if backend_args is None: + backend_args = {} + + if uri is None and "backend" not in backend_args and backend_key is None: + raise ValueError( + 'uri should not be None when "backend" does not exist in ' "backend_args and backend_key is None" + ) + + if uri is not None: + prefix = _parse_uri_prefix(uri) + else: + prefix = "" + + if enable_singleton: + unique_key = f"{prefix}:{json.dumps(backend_args)}" + if unique_key in backend_instances: + return backend_instances[unique_key] + + backend = _get_file_backend(prefix, backend_args) + backend_instances[unique_key] = backend + if backend_key is not None: + backend_instances[backend_key] = backend + return backend + else: + backend = _get_file_backend(prefix, backend_args) + return backend + + +def get( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> filepath = '/path/of/file' + >>> get(filepath) + b'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get(filepath) + + +def get_text( + filepath: Union[str, Path], + encoding="utf-8", + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> filepath = '/path/of/file' + >>> get_text(filepath) + 'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get_text(filepath, encoding) + + +def put( + obj: bytes, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> put(b'hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put(obj, filepath) + + +def put_text( + obj: str, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + ``filepath``. Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Examples: + >>> filepath = '/path/of/file' + >>> put_text('hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put_text(obj, filepath) + + +def exists( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> exists(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.exists(filepath) + + +def isdir( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/dir' + >>> isdir(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isdir(filepath) + + +def isfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> isfile(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isfile(filepath) + + +def join_path( + filepath: Union[str, Path], + *filepaths: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + *filepaths (str or Path): Other paths to be concatenated. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The result of concatenation. + + Examples: + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.join_path(filepath, *filepaths) + + +@contextmanager +def get_local_path( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself and it will + not be released (removed). + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: Only yield one path. + + Examples: + >>> with get_local_path('abc/def.jpg') as path: + ... # do something here + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + with backend.get_local_path(str(filepath)) as local_path: + yield local_path + + +def copyfile( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError will + be raised. 
+ + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> copyfile(src, dst) + '/path1/of/dir/file' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile(src, dst) + + +def copytree( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will be + raised. + + Examples: + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> copytree(src, dst) + '/path/of/dir2' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree(src, dst) + + +def copyfile_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a local file src to dst and return the destination file. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = 's3://openmmlab/mmengine/file1' + >>> # src will be copied to 's3://openmmlab/mmengine/file1' + >>> copyfile_from_local(src, dst) + s3://openmmlab/mmengine/file1 + + >>> # dst is a directory + >>> dst = 's3://openmmlab/mmengine' + >>> # src will be copied to 's3://openmmlab/mmengine/file'' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/file' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_from_local(src, dst) + + +def copytree_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. 
+ + Examples: + >>> src = '/path/of/dir' + >>> dst = 's3://openmmlab/mmengine/dir' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/dir' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_from_local(src, dst) + + +def copyfile_to_local( + src: Union[str, Path], + dst: Union[str, Path], + dst_type: str, # Choose from ["file", "dir"] + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = 's3://openmmlab/mmengine/file' + >>> dst = '/path/of/file' + >>> # src will be copied to '/path/of/file' + >>> copyfile_to_local(src, dst) + '/path/of/file' + + >>> # dst is a directory + >>> dst = '/path/of/dir' + >>> # src will be copied to '/path/of/dir/file' + >>> copyfile_to_local(src, dst) + '/path/of/dir/file' + """ + assert dst_type in ["file", "dir"] + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_to_local(src, dst, dst_type=dst_type) + + +def copytree_to_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> src = 's3://openmmlab/mmengine/dir' + >>> dst = '/path/of/dir' + >>> copytree_to_local(src, dst) + '/path/of/dir' + """ + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_to_local(src, dst) + + +def remove( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Raises: + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> remove(filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.remove(filepath) + + +def rmtree( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> dir_path = '/path/of/dir' + >>> rmtree(dir_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.rmtree(dir_path) + + +def copy_if_symlink_fails( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directory copy src to + dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return True if successfully create a symbolic link pointing to + src. Otherwise, return False. + + Examples: + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> copy_if_symlink_fails(src, dst) + True + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copy_if_symlink_fails(src, dst) + + +def list_dir( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +): + """List all folders in an S3 bucket with a given prefix. + + Args: + dir_path (str | Path): Path of the directory. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir(dir_path): + ... print(file_path) + """ + if not dir_path.endswith("/"): + dir_path += "/" + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + + return backend.list_dir(dir_path) + + +def list_dir_or_file( + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir_or_file(dir_path): + ... 
print(file_path) + >>> # list those files and directories in current directory + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) + + +def generate_presigned_url( + url: str, + client_method: str = "get_object", + expires_in: int = 3600, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on s3 backend. + + Note: + Now only work on s3 backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Defaults to 'get_object'. + expires_in (int): expires, in seconds. Defaults to 3600. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Generated presigned url. + """ + backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.generate_presigned_url(url, client_method, expires_in) + + +def load( + file: Union[str, Path, IO[Any]], + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + ``load`` supports loading data from serialized files those can be storaged + in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in s3 + + Returns: + The content from the file. 
+ """ + if isinstance(file, Path): + file = str(file) + if file_format is None and isinstance(file, str): + file_format = file.split(".")[-1] + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO(file_backend.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + if fast_backend: + if hasattr(file_backend, "fast_get"): + with BytesIO(file_backend.fast_get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + warnings.warn( + f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get" + ) + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, "read"): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump( + obj: Any, + file: Union[str, Path, IO[Any], None] = None, + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + ``dump`` supports dumping data as strings or to files which is saved to + different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + backend_key: str: The key to register the backend. Defaults to None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or s3 + + Returns: + bool: True for success, False otherwise. 
+ """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if isinstance(file, str): + file_format = file.split(".")[-1] + elif file is None: + raise ValueError("file_format must be specified since file is None") + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args" and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_backend.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + if fast_backend: + if hasattr(file_backend, "fast_put"): + file_backend.fast_put(f, file) + else: + warnings.warn("fast_backend is not supported by the backend, fallback to normal put") + file_backend.put(f, file) + else: + file_backend.put(f, file) + elif hasattr(file, "write"): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') diff --git a/cosmos_predict1/utils/easy_io/file_client.py b/cosmos_predict1/utils/easy_io/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..8c963e39515e494a9e4df3288baf68d90769d292 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/file_client.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +class HardDiskBackend(LocalBackend): + """Raw hard disks storage backend.""" + + @property + def name(self): + return self.__class__.__name__ + + +class FileClient: + """A general file client to access files in different backends. + + The client loads a file or text in a specified backend from its path + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. 
Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object will be returned. + + Warning: + `FileClient` will be deprecated in future. Please use io functions + in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "memcached", "lmdb", "http" and "s3". Defaults to None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Defaults to None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='s3') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='s3', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='s3') + >>> file_client1 is file_client + True + + Attributes: + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + "disk": HardDiskBackend, + "http": HTTPBackend, + } + + _prefix_to_backends: dict = { + "http": HTTPBackend, + "https": HTTPBackend, + } + + _instances: dict = {} + + client: Any + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = "disk" + if backend is not None and backend not in cls._backends: + raise ValueError( + f"Backend {backend} is not supported. Currently supported ones" f" are {list(cls._backends.keys())}" + ) + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f"prefix {prefix} is not supported. Currently supported ones " + f"are {list(cls._prefix_to_backends.keys())}" + ) + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f"{backend}:{prefix}" + for key, value in kwargs.items(): + arg_key += f":{key}:{value}" + + # if a backend was overridden, it will create a new object + if arg_key in cls._instances: + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + + cls._instances[arg_key] = _instance + + return _instance + + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' else + ``None``. 
+ """ + assert is_filepath(uri) + uri = str(uri) + if "://" not in uri: + return None + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + @classmethod + def infer_client( + cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None, + ) -> "FileClient": + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Defaults to None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Defaults to None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 's3'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + if not force and name in cls._backends: + raise KeyError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + + if name in cls._backends and force: + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, cls._backends[name]): + cls._instances.pop(arg_key) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + overridden_backend = cls._prefix_to_backends[prefix] + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, overridden_backend): + cls._instances.pop(arg_key) + else: + raise KeyError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. 
+ force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Defaults to None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Defaults to 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. 
+ + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file( # pylint: disable=too-many-arguments + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_predict1/utils/easy_io/handlers/__init__.py b/cosmos_predict1/utils/easy_io/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2d900319026db39d87cc206881c40df9aedb97 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
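To make the relationship between `FileClient` and the module-level helpers documented earlier concrete, here is a small usage sketch. It assumes the package re-exports those helpers through an `easy_io` namespace (as the JSONL handler's `__main__` block later in this diff does); the `/tmp` path is purely illustrative, and `FileClient` itself is marked as deprecated in its own docstring.

```python
# Sketch only: the FileClient facade next to the module-level helpers it wraps.
from cosmos_predict1.utils.easy_io.file_client import FileClient
from cosmos_predict1.utils.easy_io import easy_io  # assumed namespace, see jsonl_handler __main__

# FileClient facade: disk backend by default, singleton per argument set.
FileClient.parse_uri_prefix("s3://bucket/key")          # -> 's3'
client = FileClient(backend="disk")                      # HardDiskBackend under the hood
assert client is FileClient.infer_client(uri="/tmp/x")   # no '://' prefix -> disk backend

client.put_text("hello", "/tmp/easy_io_demo.txt")
print(client.get_text("/tmp/easy_io_demo.txt"))          # 'hello'

# Equivalent module-level calls documented earlier in this diff.
easy_io.put_text("hello", "/tmp/easy_io_demo.txt")
assert easy_io.exists("/tmp/easy_io_demo.txt") and easy_io.isfile("/tmp/easy_io_demo.txt")
easy_io.remove("/tmp/easy_io_demo.txt")
```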
+ +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_predict1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_predict1.utils.easy_io.handlers.registry_utils import file_handlers, register_handler +from cosmos_predict1.utils.easy_io.handlers.yaml_handler import YamlHandler + +__all__ = [ + "BaseFileHandler", + "JsonHandler", + "PickleHandler", + "YamlHandler", + "register_handler", + "file_handlers", +] diff --git a/cosmos_predict1/utils/easy_io/handlers/base.py b/cosmos_predict1/utils/easy_io/handlers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5dcbcabc40807706eeb43d1a598571c51922a8 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/base.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode="r", **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/csv_handler.py b/cosmos_predict1/utils/easy_io/handlers/csv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..58d6493be50257de285669727f62a61372b1cf0a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/csv_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
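Since `BaseFileHandler` above is the extension point for new formats, a minimal custom handler might look like the sketch below. It uses the `register_handler` decorator defined in `handlers/registry_utils.py` later in this diff; the `.lst` extension and `LineListHandler` name are invented for illustration.

```python
# Illustrative only: a one-item-per-line text handler registered for a made-up
# ".lst" extension via register_handler (defined in registry_utils.py below).
from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler
from cosmos_predict1.utils.easy_io.handlers.registry_utils import register_handler


@register_handler("lst")
class LineListHandler(BaseFileHandler):
    str_like = True  # text format -> routed through StringIO by load()/dump()

    def load_from_fileobj(self, file, **kwargs):
        # One list entry per line, trailing newlines stripped.
        return [line.rstrip("\n") for line in file]

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(self.dump_to_str(obj, **kwargs))

    def dump_to_str(self, obj, **kwargs):
        return "\n".join(str(item) for item in obj)
```

After registration, `load('names.lst')` and `dump(['a', 'b'], 'names.lst')` dispatch to this handler through the extension-keyed `file_handlers` dict.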
+ +import csv +from io import StringIO + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class CsvHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + reader = csv.reader(file) + return list(reader) + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + writer = csv.writer(file) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + output = StringIO() + writer = csv.writer(output) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + return output.getvalue() diff --git a/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py b/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..205f6abb2a002438bd072dd86e21ea845c35c8bd --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import pickle +from io import BytesIO +from typing import Any + +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler + + +class GzipHandler(PickleHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="rb") as f: + return pickle.load(f) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="wb") as f: + pickle.dump(obj, f) diff --git a/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py b/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..30551d176fc842e71d91e9a62c336a60244f289a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
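The CSV and gzip handlers above plug into the extension-based dispatch of `load`/`dump`. The round trip below is a sketch, not a test: it assumes the helpers are reachable through the `easy_io` namespace used elsewhere in this diff, and the filenames are placeholders.

```python
# Sketch: CsvHandler expects a list of rows (each row itself a list);
# GzipHandler stores a gzip-compressed pickle. Filenames are placeholders.
from cosmos_predict1.utils.easy_io import easy_io  # assumed namespace

rows = [["name", "score"], ["a", "1"], ["b", "2"]]
easy_io.dump(rows, "scores.csv")               # '.csv' -> CsvHandler
assert easy_io.load("scores.csv") == rows      # values come back as strings

easy_io.dump({"weights": [0.1, 0.2]}, "blob.gz")   # '.gz' -> GzipHandler
print(easy_io.load("blob.gz"))                     # {'weights': [0.1, 0.2]}
```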
+ +from typing import IO + +import numpy as np +import torch + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + +try: + import imageio +except ImportError: + imageio = None + + +class ImageioVideoHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs): + """ + Load video from a file-like object using imageio with specified format and color mode. + + Parameters: + file (IO[bytes]): A file-like object containing video data. + format (str): Format of the video file (default 'mp4'). + mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). + + Returns: + tuple: A tuple containing an array of video frames and metadata about the video. + """ + file.seek(0) + video_reader = imageio.get_reader(file, format, **kwargs) + + video_frames = [] + for frame in video_reader: + if mode == "gray": + import cv2 # Convert frame to grayscale if mode is gray + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent + video_frames.append(frame) + + return np.array(video_frames), video_reader.get_meta_data() + + def dump_to_fileobj( + self, + obj: np.ndarray | torch.Tensor, + file: IO[bytes], + format: str = "mp4", # pylint: disable=redefined-builtin + fps: int = 17, + quality: int = 5, + **kwargs, + ): + """ + Save an array of video frames to a file-like object using imageio. + + Parameters: + obj (np.ndarray): An array of frames to be saved as video. + file (IO[bytes]): A file-like object to which the video data will be written. + format (str): Format of the video file (default 'mp4'). + fps (int): Frames per second of the output video (default 30). + + """ + if isinstance(obj, torch.Tensor): + assert obj.dtype == torch.uint8 + obj = obj.cpu().numpy() + h, w = obj.shape[1:-1] + kwargs = { + "fps": fps, + "quality": quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{w}x{h}"], + "output_params": ["-f", "mp4"], + } + imageio.mimsave(file, obj, format, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/json_handler.py b/cosmos_predict1/utils/easy_io/handlers/json_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe9ffbe2aa20c4ea467bcdec29c8e7c2917c473 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/json_handler.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) 
into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonHandler(BaseFileHandler): + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("default", set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("default", set_default) + return json.dumps(obj, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py b/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b30ce6b1959b05268a842409274e1251a4765672 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import IO + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonlHandler(BaseFileHandler): + """Handler for JSON lines (JSONL) files.""" + + def load_from_fileobj(self, file: IO[bytes]): + """Load JSON objects from a newline-delimited JSON (JSONL) file object. + + Returns: + A list of Python objects loaded from each JSON line. + """ + data = [] + for line in file: + line = line.strip() + if not line: + continue # skip empty lines if any + data.append(json.loads(line)) + return data + + def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) file object. + + Args: + obj: A list (or iterable) of objects to dump line by line. 
+ """ + kwargs.setdefault("default", set_default) + for item in obj: + file.write(json.dumps(item, **kwargs) + "\n") + + def dump_to_str(self, obj, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" + kwargs.setdefault("default", set_default) + lines = [json.dumps(item, **kwargs) for item in obj] + return "\n".join(lines) + + +if __name__ == "__main__": + from cosmos_predict1.utils.easy_io import easy_io + + easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) + easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) diff --git a/cosmos_predict1/utils/easy_io/handlers/np_handler.py b/cosmos_predict1/utils/easy_io/handlers/np_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..29cb8d55c181c14a06022d97523166b3763cb753 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/np_handler.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BytesIO +from typing import IO, Any + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class NumpyHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: + """ + Load a NumPy array from a file-like object. + + Parameters: + file (IO[bytes]): The file-like object containing the NumPy array data. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return np.load(file, **kwargs) + + def load_from_path(self, filepath: str, **kwargs) -> Any: + """ + Load a NumPy array from a file path. + + Parameters: + filepath (str): The path to the file to load. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: + """ + Serialize a NumPy array to a string in binary format. + + Parameters: + obj (np.ndarray): The NumPy array to serialize. + **kwargs: Additional keyword arguments passed to `np.save`. + + Returns: + str: The serialized NumPy array as a string. + """ + with BytesIO() as f: + np.save(f, obj, **kwargs) + return f.getvalue() + + def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): + """ + Dump a NumPy array to a file-like object. + + Parameters: + obj (np.ndarray): The NumPy array to dump. + file (IO[bytes]): The file-like object to which the array is dumped. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + np.save(file, obj, **kwargs) + + def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): + """ + Dump a NumPy array to a file path. + + Parameters: + obj (np.ndarray): The NumPy array to dump. 
+ filepath (str): The file path where the array should be saved. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + with open(filepath, "wb") as f: + np.save(f, obj, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py b/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..cdcac6e6eb82f8e92c79c4fef16e1f5b68dbd82c --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip + + +class PandasHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pd.read_csv(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + obj.to_csv(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError("PandasHandler does not support dumping to str") diff --git a/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py b/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..66bc10d5f2da24785cef4216e2961747c52eb756 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
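The NumPy and pandas handlers just defined cover `.npy` arrays and CSV-backed dataframes. A usage sketch follows, again assuming the `easy_io` namespace and with invented filenames; note that the pandas handler is not tied to an extension in the registry, so `file_format="pandas"` is passed explicitly.

```python
# Sketch: '.npy' dispatches to NumpyHandler; the "pandas" handler is only
# reachable by passing file_format explicitly. Filenames are illustrative.
import numpy as np

from cosmos_predict1.utils.easy_io import easy_io  # assumed namespace

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
easy_io.dump(arr, "arr.npy")                            # NumpyHandler -> np.save
np.testing.assert_allclose(easy_io.load("arr.npy"), arr)

easy_io.put_text("name,score\na,1\nb,2\n", "scores.csv")
df = easy_io.load("scores.csv", file_format="pandas")   # pd.read_csv under the hood
print(df["score"].sum())                                # 3
```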
+ +import pickle +from io import BytesIO +from typing import Any + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("protocol", 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + kwargs.setdefault("protocol", 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + with open(filepath, "wb") as f: + pickle.dump(obj, f, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/pil_handler.py b/cosmos_predict1/utils/easy_io/handlers/pil_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a473bf486ce87df49e9f48afc6efec81abe8cbc4 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pil_handler.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO, Optional, Tuple, Union + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + +try: + from PIL import Image +except ImportError: + Image = None + + +class PILHandler(BaseFileHandler): + format: str + str_like = False + + def load_from_fileobj( + self, + file: IO[bytes], + fmt: str = "pil", + size: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): + """ + Load an image from a file-like object and return it in a specified format. + + Args: + file (IO[bytes]): A file-like object containing the image data. + fmt (str): The format to convert the image into. Options are \ + 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ + 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). + size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ + or a tuple of (width, height). If specified, the image is resized accordingly. + **kwargs: Additional keyword arguments that can be passed to conversion functions. + + Returns: + Image data in the format specified by `fmt`. + + Raises: + IOError: If the image cannot be loaded or processed. + ValueError: If the specified format is unsupported. 
+ """ + try: + img = Image.open(file) + img.load() # Explicitly load the image data + if size is not None: + if isinstance(size, int): + size = ( + size, + size, + ) # create a tuple if only one integer is provided + img = img.resize(size, Image.ANTIALIAS) + + # Return the image in the requested format + if fmt in ["numpy", "np", "npy"]: + return np.array(img, **kwargs) + if fmt == "pil": + return img + if fmt in ["th", "torch"]: + import torch + + # Convert to tensor + img_tensor = torch.from_numpy(np.array(img, **kwargs)) + # Convert image from HxWxC to CxHxW + if img_tensor.ndim == 3: + img_tensor = img_tensor.permute(2, 0, 1) + return img_tensor + raise ValueError( + "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." + ) + except Exception as e: + raise IOError(f"Unable to load image: {e}") from e + + def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): + if "format" not in kwargs: + kwargs["format"] = self.format + kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() + obj.save(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/registry_utils.py b/cosmos_predict1/utils/easy_io/handlers/registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7edccc3d7c445d7034c028950cd823a17aef8b --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/registry_utils.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_predict1.utils.easy_io.handlers.csv_handler import CsvHandler +from cosmos_predict1.utils.easy_io.handlers.gzip_handler import GzipHandler +from cosmos_predict1.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler +from cosmos_predict1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_predict1.utils.easy_io.handlers.jsonl_handler import JsonlHandler +from cosmos_predict1.utils.easy_io.handlers.np_handler import NumpyHandler +from cosmos_predict1.utils.easy_io.handlers.pandas_handler import PandasHandler +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_predict1.utils.easy_io.handlers.pil_handler import PILHandler +from cosmos_predict1.utils.easy_io.handlers.tarfile_handler import TarHandler +from cosmos_predict1.utils.easy_io.handlers.torch_handler import TorchHandler +from cosmos_predict1.utils.easy_io.handlers.torchjit_handler import TorchJitHandler +from cosmos_predict1.utils.easy_io.handlers.txt_handler import TxtHandler +from cosmos_predict1.utils.easy_io.handlers.yaml_handler import YamlHandler + +file_handlers = { + "json": JsonHandler(), + "yaml": YamlHandler(), + "yml": YamlHandler(), + "pickle": PickleHandler(), + "pkl": PickleHandler(), + "tar": TarHandler(), + "jit": TorchJitHandler(), + "npy": NumpyHandler(), + "txt": TxtHandler(), + "csv": CsvHandler(), + "pandas": PandasHandler(), + "gz": GzipHandler(), + "jsonl": JsonlHandler(), +} + +for torch_type in ["pt", "pth", "ckpt"]: + file_handlers[torch_type] = TorchHandler() +for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: + file_handlers[img_type] = PILHandler() + file_handlers[img_type].format = img_type +for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: + file_handlers[video_type] = ImageioVideoHandler() + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") + if isinstance(file_formats, str): + file_formats = [file_formats] + if not all([isinstance(item, str) for item in file_formats]): + raise TypeError("file_formats must be a str or a list of str") + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py b/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..687e8a8a3adafeaf88b8a2644472943698dffe26 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tarfile + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TarHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, mode="r|*", **kwargs): + return tarfile.open(fileobj=file, mode=mode, **kwargs) + + def load_from_path(self, filepath, mode="r|*", **kwargs): + return tarfile.open(filepath, mode=mode, **kwargs) + + def dump_to_fileobj(self, obj, file, mode="w", **kwargs): + with tarfile.open(fileobj=file, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with tarfile.open(filepath, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/torch_handler.py b/cosmos_predict1/utils/easy_io/handlers/torch_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f64eafe59a9593c4e8aed6513092e9604faae378 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/torch_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py b/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..71598cdaf2679ed47293dcc0410be28b9b4b0a91 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchJitHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.jit.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.jit.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/txt_handler.py b/cosmos_predict1/utils/easy_io/handlers/txt_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..007d7f661d0d5ad001207a0424c35e51aea9e1a9 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/txt_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TxtHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + return file.read() + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + file.write(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + return obj diff --git a/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py b/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ede0280f41d7fa29468d46141ff31e038dbcad4c --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
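Beyond the built-in extension mapping, the registry_utils module earlier in this patch exposes a register_handler decorator for wiring additional extensions; a minimal sketch (the "md" extension and MarkdownHandler are hypothetical):

    from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler
    from cosmos_predict1.utils.easy_io.handlers.registry_utils import register_handler

    @register_handler("md")  # hypothetical extension
    class MarkdownHandler(BaseFileHandler):
        def load_from_fileobj(self, file, **kwargs):
            return file.read()

        def dump_to_fileobj(self, obj, file, **kwargs):
            file.write(str(obj))

        def dump_to_str(self, obj, **kwargs):
            return str(obj)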
+
+import yaml
+
+try:
+    from yaml import CDumper as Dumper  # type: ignore
+    from yaml import CLoader as Loader  # type: ignore
+except ImportError:
+    from yaml import Loader, Dumper  # type: ignore
+
+from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+    def load_from_fileobj(self, file, **kwargs):
+        kwargs.setdefault("Loader", Loader)
+        return yaml.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault("Dumper", Dumper)
+        yaml.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault("Dumper", Dumper)
+        return yaml.dump(obj, **kwargs)
diff --git a/cosmos_predict1/utils/ema.py b/cosmos_predict1/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d883648a2ca969fb11e61591c01e85f8d488513
--- /dev/null
+++ b/cosmos_predict1/utils/ema.py
@@ -0,0 +1,327 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
+
+import numpy as np
+import torch
+from megatron.core import parallel_state
+
+from cosmos_predict1.utils import distributed, log
+
+if TYPE_CHECKING:
+    from cosmos_predict1.utils.model import Model
+
+
+class FastEmaModelUpdater:
+    """
+    Updates a target (EMA) model from a source (regular) model given a decay rate ``beta``.
+    The method interface mimics :class:`EMAModelTracker` and :class:`PowerEMATracker`.
+    Unlike those two classes, this class does not maintain the EMA weights as buffers; it expects the user to keep two modules with the same architecture and weight shapes.
+    This class is intended for FSDP models, where the two classes above do not work as expected. It also avoids registering model weights as buffers and the associated name mangling done in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moving forward, prefer this class over the other two.
+    """
+
+    def __init__(self):
+        # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
+        self.is_cached = False
+
+    def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None:
+        target_list = []
+        source_list = []
+        for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
+            assert (
+                tgt_params.dtype == torch.float32
+            ), f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead."
+ target_list.append(tgt_params) + source_list.append(src_params.data) + torch._foreach_mul_(target_list, beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) + + def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: + for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + tgt_params.data.copy_(src_params.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + +def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str: + """ + This function creates buffer name used by EMA from parameter's name + + Args: + param_name (str): Model's parameter name + Returns: + buffer_name (str): buffer name to be used for given parameter name + """ + + buffer_name = param_name.replace(".", "-") + + if torch_compile_buffer_renaming: + # torch.compile() adds _orig_mod to state dict names, this way we get original name + buffer_name = buffer_name.replace("_orig_mod-", "") + + return buffer_name + + +class EMAModelTracker(torch.nn.Module): + """This is a class to track the EMA model weights. + + The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the + regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's + implementation of EMA. There are no optimizable parameters. + + Attributes: + collected_params (list): temporarily stores the regular weights while in EMA mode. + beta (float): EMA decay rate. (default: 0.9999). + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + + def __init__(self, model: Model, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + beta (float): EMA decay rate. (default: 0.9999). + """ + super().__init__() + self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming + if not 0.0 <= beta <= 1.0: + raise ValueError("Decay must be between 0 and 1") + self.beta = beta + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + self.register_buffer(buffer_name, param.clone().detach().data) + self.collected_params = [] + # Flag to indicate whether the cache is taken or not. 
Useful to avoid cache overwrite + self.is_cached = False + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + del iteration + target_list = [] + source_list = [] + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead." + target_list.append(buffer) + source_list.append(param.data) + torch._foreach_mul_(target_list, self.beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta) + + def copy_to(self, model: Model) -> None: + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + param.data.copy_(buffer.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: Union[float, List[float]], num: int = 1, enabled: bool = True + ) -> Optional[EMAModelTracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model to be tracked. + rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided, + it corresponds to rates for different ranks. + num (int, optional): The number of leading ranks to consider for different rates. + Defaults to 1. + enabled (bool, optional): Flag to enable or disable the creation of the tracker. + If False, returns None. Defaults to True. + + Returns: + Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None. + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = EMAModelTracker.initialize_ema_from_settings(model, rate=[0.1, 0.2], num=2) + >>> print(tracker) + + Notes: + If `rate` is a list and the current rank is less than `num`, the rate for the current rank + is used. If the current rank exceeds `num`, the first rate in the list is used by default. 
+ """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + rate = rate if isinstance(rate, list) else [rate] + num = min(num, len(rate)) + rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0] + if cur_dp_rank < num: + print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}") + return cls(model, rate) + + +class PowerEMATracker(EMAModelTracker): + def __init__(self, model: Model, s: float = 0.1, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + s (float): EMA decay rate. See EDM2 paper + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming) + self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.exp + 1) + self.beta = beta + + super().update_average(model, iteration) + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True + ) -> Optional[PowerEMATracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model for which the EMA tracker is being set up. + num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged. + rate (float): The base decay rate for the EMA calculation. + enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None. + Defaults to True. + + Returns: + Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None. + + Raises: + None + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99) + >>> print(tracker) + + Notes: + The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`. + If the rank is greater than or equal to `num`, the base rate is used without modification. This approach + allows higher ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization + in a distributed training scenario. + """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. 
DP RANK: {cur_dp_rank}", rank0_only=False) + + divider = 2**cur_dp_rank if cur_dp_rank < num else 1 + if cur_dp_rank < num: + print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}") + return cls(model, rate / divider) + + +@contextmanager +def ema_scope(model: Model, enabled: bool = False) -> Generator[None, None, None]: + """Context manager for switching between regular and EMA model weights. + + Args: + model (Model): The PyTorch model. + enabled (bool): Whether switching to EMA weights is enabled (default: False). + """ + if enabled: + assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker)) + model.ema.cache(model.parameters()) + model.ema.copy_to(model) + log.info("EMA: switched to EMA weights.") + try: + yield None + finally: + if enabled: + model.ema.restore(model.parameters()) + log.info("EMA: restored regular weights.") diff --git a/cosmos_predict1/utils/env_parsers/cred_env_parser.py b/cosmos_predict1/utils/env_parsers/cred_env_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa34d8402cfd80d86bed87be31863b3afd0ceaa --- /dev/null +++ b/cosmos_predict1/utils/env_parsers/cred_env_parser.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
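The ema_scope context manager defined at the end of ema.py above is the intended entry point for validating with EMA weights; a minimal sketch, assuming model.ema is an EMAModelTracker (or PowerEMATracker) and run_validation is a hypothetical helper:

    from cosmos_predict1.utils.ema import ema_scope

    with ema_scope(model, enabled=True):
        run_validation(model)  # sees the EMA weights
    # On exit, the cached regular weights are restored automatically.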
+ +from cosmos_predict1.utils.env_parsers.env_parser import EnvParser +from cosmos_predict1.utils.validator import String + + +class CredentialEnvParser(EnvParser): + APP_ENV = String(default="") + PROD_FT_AWS_CREDS_ACCESS_KEY_ID = String(default="") + PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY = String(default="") + PROD_FT_AWS_CREDS_ENDPOINT_URL = String(default="https://s3.us-west-2.amazonaws.com") + PROD_FT_AWS_CREDS_REGION_NAME = String(default="us-west-2") + + PROD_S3_CHECKPOINT_ACCESS_KEY_ID = String(default="") + PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY = String(default="") + PROD_S3_CHECKPOINT_ENDPOINT_URL = String(default="") + PROD_S3_CHECKPOINT_REGION_NAME = String(default="") + + PROD_TEAM_DIR_ACCESS_KEY_ID = String(default="") + PROD_TEAM_DIR_SECRET_ACCESS_KEY = String(default="") + PROD_TEAM_DIR_ENDPOINT_URL = String(default="") + PROD_TEAM_DIR_REGION_NAME = String(default="") + + PICASSO_AUTH_MODEL_REGISTRY_API_KEY = String(default="") + PICASSO_API_ENDPOINT_URL = String(default="https://meeocvslt2.execute-api.us-west-2.amazonaws.com") + + +CRED_ENVS = CredentialEnvParser() +CRED_ENVS_DICT = { + "PROD_FT_AWS_CREDS": { + "aws_access_key_id": CRED_ENVS.PROD_FT_AWS_CREDS_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_FT_AWS_CREDS_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_FT_AWS_CREDS_REGION_NAME, + }, + "PROD_S3_CHECKPOINT": { + "aws_access_key_id": CRED_ENVS.PROD_S3_CHECKPOINT_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_S3_CHECKPOINT_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_S3_CHECKPOINT_REGION_NAME, + }, + "PROD_TEAM_DIR": { + "aws_access_key_id": CRED_ENVS.PROD_TEAM_DIR_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_TEAM_DIR_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_TEAM_DIR_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_TEAM_DIR_REGION_NAME, + }, +} diff --git a/cosmos_predict1/utils/env_parsers/env_parser.py b/cosmos_predict1/utils/env_parsers/env_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1579cbddb1edd8eff1209aa48e7b5cd674f71e2a --- /dev/null +++ b/cosmos_predict1/utils/env_parsers/env_parser.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.validator import JsonDict, Validator + +""" +Base class for parsing environment variables using validators. +Class will go through its list of validators and retrieve values from same named environment variables. +Validators provide: +- default value +- typed parsing +- enforments of mandatory values + +Additionally the environment variables can be passed as single base64 encoded string. + +we cannot enforce that a component isn't directly using the environment variables. 
+so evaluation of params should throw error to make sure actual env var is correct. +""" + + +class EnvParser: + def __init__(self, b64_str=None): + if b64_str: + log.critical(f"b64_str recieved: {b64_str}") + self.from_b64(b64_str) + else: + self.from_env() + + def from_env(self): + validators = self.get_val_dict() + for key in validators.keys(): + val = os.getenv(key.upper()) + log.debug(f"getting env var {key.upper()}: {val}") + if val: + setattr(self, key, val) + self.check_mandatory_values() + + def from_json(self, file_name): + with open(file_name, "r") as f: + log.info(f"Reading env params from {file_name}") + dict = json.load(f) + for key, value in dict.items(): + setattr(self, key, value) + self.check_mandatory_values() + + def to_b64(self): + json_str = self.to_json() + # create bytes-like object for b64 encoder + json_str_bytes = json_str.encode() + b64_str = base64.b64encode(json_str_bytes).decode() + + print(b64_str) + return b64_str + + def from_b64(self, b64_str): + json_str = base64.b64decode(b64_str).decode() + dict = json.loads(json_str) + for key, value in dict.items(): + setattr(self, key, value) + self.check_mandatory_values() + + def check_mandatory_values(self): + for key, validator in self.get_val_dict().items(): + if getattr(self, key) is None and validator.default is None: + raise ValueError(f"Missing mandatory env var: {key}") + + @classmethod + def get_val_dict(cls): + log.debug(f"getting val dict of {cls.__name__}") + val_dict = {} + val_dict.update({key: value for key, value in cls.__dict__.items() if isinstance(value, Validator)}) + + return val_dict + + def dump_validators(self): + validators = self.get_val_dict() + for key, value in validators.items(): + log.debug(f"{key}: {value.__get__(self)}") + + def to_json(self, file_name=None): + dict = { + key.upper(): value.__get__(self) + for key, value in EnvParser.__dict__.items() + if isinstance(value, Validator) + } + json_str = json.dumps(dict, indent=4) + print(json_str) + + if file_name: + with open(file_name, "w") as f: + log.info(f"Writing env params to {file_name}") + f.write(json_str) + + return json_str + + def to_string_dict(self): + result = {} + for key, validator in self.get_val_dict().items(): + value = getattr(self, key) + if value is None: + value = validator.default + if isinstance(validator, JsonDict): + value = json.dumps(value) + else: + value = str(value) + result[key] = value + return result + + def __str__(self): + return ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) diff --git a/cosmos_predict1/utils/fsdp_checkpointer.py b/cosmos_predict1/utils/fsdp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba4d670f8f7083d19737dc013bf59715205d266 --- /dev/null +++ b/cosmos_predict1/utils/fsdp_checkpointer.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
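A minimal sketch of building on the EnvParser base class above (the MY_ENDPOINT variable and its default are hypothetical; real parsers such as CredentialEnvParser follow the same pattern):

    import os

    from cosmos_predict1.utils.env_parsers.env_parser import EnvParser
    from cosmos_predict1.utils.validator import String

    class MyEnvParser(EnvParser):
        MY_ENDPOINT = String(default="https://example.com")

    os.environ["MY_ENDPOINT"] = "https://internal.example.com"
    envs = MyEnvParser()  # reads MY_ENDPOINT from the environment, else falls back to the default
    print(envs.MY_ENDPOINT)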
+ +from __future__ import annotations + +import gc +import os +import threading + +import torch +from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + +from cosmos_predict1.utils import callback, distributed, log, misc +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.fsdp_optim_fix import scatter_full_optim_state_dict +from cosmos_predict1.utils.model import Model + + +class FSDPCheckpointer: + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path + self.load_training_state = config_checkpoint.load_training_state + self.save_thread = None + self.config_checkpoint = config_checkpoint + + def _load_ckpt_file_during_init(self): + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}") + log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}") + log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)") + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}") + if resume: + log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)") + else: + log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)") + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + log.critical("[Checkpoint] No checkpoint path specified") + log.critical("[Checkpoint] Starting fresh training with random initialization") + return checkpoint_path, resume + + @misc.timer("FSDP.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + if ema_id > 0: + assert is_ema, "ema_id should be used with is_ema=True" + checkpoint_path, _ = self._load_ckpt_file_during_init() + if checkpoint_path is not None: + tag = "reg" if not is_ema else "ema" + default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if not os.path.exists(default_checkpoint_path): + default_checkpoint_path = checkpoint_path # starting from the release checkpoint + log.warning(f"is_ema={is_ema} model is not found. 
Loading from {default_checkpoint_path}") + if tag == "ema" and ema_id > 0: + _checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt") + _checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if self._check_checkpoint_exists(_checkpoint_path, is_raise=False): + default_checkpoint_path = _checkpoint_path + else: + print( + f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} " + f"(fallback to {default_checkpoint_path})" + ) + checkpoint_path = default_checkpoint_path + self._check_checkpoint_exists(checkpoint_path) + + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the model...") + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + else: + log.info(f"is_ema={is_ema} model is not found and loaded.") + + @misc.timer("FSDP.load_optim_scheduler_during_init") + def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler): + checkpoint_path, resume = self._load_ckpt_file_during_init() + log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}") + if checkpoint_path is not None: + if resume: + checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt") + self._check_checkpoint_exists(checkpoint_path) + if distributed.get_rank() == 0: + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load( + checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False + ) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the optimizer (FSDP scatter)...") + else: + state_dict = { + "optimizer": None, + "scheduler": None, + } + distributed.barrier() + sharded_optimizer_state_dict = scatter_full_optim_state_dict( # <---- FSDP + state_dict["optimizer"], + fsdp_model, + ) + log.info("- Loading the optimizer (FSDP load_state_dict)...") + log.info(optimizer.load_state_dict(sharded_optimizer_state_dict)) + log.critical("Skip loading the scheduler...") + return + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + + @misc.timer("FSDP get_optim_scheduler_state") + def get_optim_scheduler_state(self, optim, fsdp_model, scheduler): + with FSDP.state_dict_type( + fsdp_model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + scheduler_statedict = scheduler.state_dict() + return { + "optimizer": optim_statedict, + "scheduler": scheduler_statedict, + } + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + async_saving: bool = True, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. 
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + model_state_dict = model.state_dict_model() + optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler) + torch.cuda.empty_cache() + state_dict = dict( + iteration=iteration, + ) + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + + postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix() + if replicate_idx == 0 and shard_idx == 0: + pass # save whole; it is rank0 + elif replicate_idx < total_ema_num and shard_idx == 0: + model_state_dict["model"] = None # only save ema + optim_scheduler_state_dict = None + state_dict = None + else: + return + + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + if async_saving: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=( + model_state_dict, + optim_scheduler_state_dict, + state_dict, + checkpoint_file, + distributed.get_rank(), + ), + ) + self.save_thread.start() + log.info("checkpoint saving from an async thread") + else: + torch.cuda.empty_cache() + # Run the checkpoint saver in the current thread. + self._save_worker_local( + model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank() + ) + log.info("checkpoint saved within the main thread") + del model_state_dict, optim_scheduler_state_dict, state_dict + gc.collect() + torch.cuda.empty_cache() + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local( + self, + model_state_dict: dict[str, torch.Tensor], + optim_scheduler_state_dict: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + checkpoint_file: str, + rank: int = 0, + ) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+ """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"] + if model_state_dict is not None: + torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt")) + if ema_model_state_dict is not None: + torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt")) + if optim_scheduler_state_dict is not None: + torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt")) + if state_dict is not None: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (FSDPDiffModle): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + del optimizer, grad_scaler + checkpoint_path, resume = self._load_ckpt_file_during_init() + iteration = 0 + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + if resume: + iteration = state_dict["iteration"] + log.success("Done with loading the checkpoint.") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + if scheduler is not None: + scheduler.last_epoch = iteration + log.critical(f"resume scheduler from {iteration}", rank0_only=False) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. 
+ """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + if checkpoint_file is None: + log.warning(f"Latest ckpt file not found: {latest_path}") + else: + log.info(f"Found latest checkpoint: {checkpoint_file}") + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + if is_raise: + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + return False + return True + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + +class FSDPInferenceCheckpointer: + def __init__( + self, + ckpt_path: str, + strict_resume: bool = True, + ): + self.ckpt_path = ckpt_path + self.strict_resume = strict_resume + + @misc.timer("FSDPInferenceCheckpointer.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + del ema_id + if is_ema: + log.warning("EMA model is not supported in inference mode.") + return + assert easy_io.exists(self.ckpt_path) + log.info(f"Loading from {self.ckpt_path}") + state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + + def load_optim_scheduler_during_init(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def save(self, *args, **kwargs): + """ + We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def load(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + return 0 diff --git a/cosmos_predict1/utils/fsdp_optim_fix.py b/cosmos_predict1/utils/fsdp_optim_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..a08aa943828d4c9c385d25873528edae0a84ec24 --- /dev/null +++ b/cosmos_predict1/utils/fsdp_optim_fix.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# isort: skip_file + +""" +torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode +torch impl uses state.rank and dist.rank() inconsistently +The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode +Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2 +""" + +import copy +import warnings +from typing import Any, Dict, Iterable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._optim_utils import ( + _flatten_optim_state, + _FSDPState, + _get_fqn_to_fsdp_param_info, + _get_param_to_fqns, + _OptimStateKey, + _PosDimTensorInfo, + _shard_orig_param_state, + tree_map_only, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: Dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> Dict[str, Any]: + objects: List[Any] = [None] + if fsdp_state.rank == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank() == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any: + if dist.get_rank() == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + assert state.dim() == 0, ( + "For non-zero ranks, a tensor state should have zero dimension, " + "but got the state with shape {state.shape()}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _flatten_optim_state_dict( + optim_state_dict: Dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. 
However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError('`optim_state_dict` must have the keys "state"' "to be a valid optimizer state dict") + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn].keys(): + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}." + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn(f"optim_state[{key}] is not on rank{fsdp_state.rank}.") + + else: + raise RuntimeError(f"The state of {key} is empty. This should happen when " "use_orig_params=True.") + else: # do not flatten non-FSDP parameters' states + assert len(fqns) == 1 + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. 
+ unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _optim_state_dict_to_load_impl( + optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + The internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + assert optim_input is None and not rank0_only + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params + assert all( + use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model) + ), "Not all FSDP modules have the same _use_orig_params value" + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + +def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[Dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, +) -> Dict[str, Any]: + """ + Scatters the full optimizer state dict from rank 0 to all other ranks, + returning the sharded optimizer state dict on each rank. The return + value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... 
+ >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict("scatter_full_optim_state_dict", "optim_state_dict_to_load") + return _optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) diff --git a/cosmos_predict1/utils/fused_adam.py b/cosmos_predict1/utils/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..192e29552f8fcaec5b50f325d35c32b6807948f1 --- /dev/null +++ b/cosmos_predict1/utils/fused_adam.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
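+
+# Illustrative usage sketch (not executed here): this optimizer is intended as a
+# drop-in replacement for ``torch.optim.AdamW``. The names ``model`` and ``loss``
+# below are placeholders, not defined in this module.
+#
+#   from cosmos_predict1.utils.fused_adam import FusedAdam
+#
+#   optimizer = FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.01,
+#                         capturable=True, master_weights=True)
+#   loss.backward()
+#   optimizer.step()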
+ +import torch +from apex.multi_tensor_apply import multi_tensor_applier + +from cosmos_predict1.utils import distributed, log + + +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Currently GPU-only. Requires Apex to be installed via + ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters + into one or a few kernel launches. + + :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adam_w_mode=False``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + ... + opt.step() + + :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, + you may choose any ``opt_level``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") + ... + opt.step() + + In general, ``opt_level="O1"`` is recommended. + + + .. warning:: + A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. + These additional arguments are now deprecated and unnecessary. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + capturable (bool, optional): whether to use the version of the optimizer + that can be used with CUDA Graphs. (default: False) + master_weights (bool, optional): whether to maintain FP32 master weights + in the optimizer with FP16 mixed precision training, currently can + only be used with capturable set to True. (default: False) + + .. _Adam - A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adam_w_mode=True, + weight_decay=0.0, + amsgrad=False, + capturable=False, + master_weights=False, + ): + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + if master_weights and not capturable: + raise RuntimeError("Master weights is currently only supported with the capturable version.") + # If the optimizer is capturable then LR should be a tensor (on GPU) + log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}") + lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adam_w_mode = 1 if adam_w_mode else 0 + + self.capturable = capturable + self.master_weights = master_weights + + self.param_groups_master = None + + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + for item in ["lr"]: + if isinstance(group[item], float): + group[item] = torch.tensor(group[item], dtype=torch.float32) + self.param_groups[idx][item] = group[item].to(device=device) + + self._step_supports_amp_scaling = True + + if multi_tensor_applier.available: + import amp_C + + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + self.multi_tensor_adam = amp_C.multi_tensor_adam + self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable + self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master + else: + raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions") + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. " + "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." 
+ ) + loss = None + if closure is not None: + loss = closure() + + if self.param_groups_master is None: + # Create full precision master weights + self.param_groups_master = [] + for i, pg in enumerate(self.param_groups): + param_list = pg["params"] + self.param_groups_master.append( + { + "params": [p.clone().detach().float() if self.master_weights else None for p in param_list], + } + ) + + for group, group_master in zip(self.param_groups, self.param_groups_master): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + bias_correction = 1 if "bias_correction" in group and group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + if self.capturable: + group["step"] = ( + group["step"].to(device=device) + if isinstance(group["step"], torch.Tensor) + else torch.tensor(group["step"], dtype=torch.int32, device=device) + ) + group["step"] += (self._dummy_overflow_buf != 1).to(torch.int) + else: + group["step"] += 1 + else: + group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) + + if self.capturable: + group["lr"] = ( + group["lr"].to(device=device) + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32, device=device) + ) + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_bf, p_bf, m_bf, v_bf = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + p_16_master = [] + p_32_master = [] + bf16_master = [] + + for p, p_master in zip(group["params"], group_master["params"]): + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).float() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).float() + + if p.dtype == torch.float16: + if self.master_weights: + p_16_master.append(p_master.data) + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) + elif p.dtype == torch.bfloat16: + if self.master_weights: + bf16_master.append(p_master.data) + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state["exp_avg"]) + v_bf.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + if self.master_weights: + p_32_master.append(p_master.data) + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) + else: + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + if self.capturable: + # overflow check of gradients + found_inf = ( + grad_scaler._check_inf_per_device(self)[device] + if grad_scaler is not None + else torch.zeros((1,), device=device) + ) + self._dummy_overflow_buf.copy_(found_inf) + + # get unscale scale factor + scale, inv_scale = None, None + if grad_scaler: + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + scale = torch.ones((1,), device=device, dtype=torch.float32) + inv_scale = torch.ones((1,), device=device, 
dtype=torch.float32) + + if len(g_16) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_bf) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_32) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + else: + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_bf) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + for group in self.param_groups: + if self.capturable: + group["lr"] = ( + group["lr"].cuda() + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32).cuda() + ) + + if "step" in group: + if self.capturable: + if distributed.get_rank() == 0: + step = ( + group["step"].cuda() + if isinstance(group["step"], torch.Tensor) + else torch.tensor([group["step"]], dtype=torch.int32).cuda() + ) + else: + step = torch.zeros(1, dtype=torch.int32).cuda() + # make it compatible with FSDP optimizer + distributed.broadcast(step, 0) + group["step"] = step + elif isinstance(group["step"], torch.Tensor): + group["step"] = group["step"].item() + for p in group["params"]: + state = self.state[p] + if "exp_avg" in state: + state["exp_avg"] = state["exp_avg"].float() + state["exp_avg_sq"] = state["exp_avg_sq"].float() diff --git a/cosmos_predict1/utils/io.py b/cosmos_predict1/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..c877aa41fd6b90638281f048bac23fc8214b84be --- /dev/null +++ b/cosmos_predict1/utils/io.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from io import BytesIO +from typing import Dict, List + +import imageio +import numpy as np + + +def read_prompts_from_file(prompt_file: str) -> List[Dict[str, str]]: + """Read prompts from a JSONL file where each line is a dict with 'prompt' key and optionally 'visual_input' key. + + Args: + prompt_file (str): Path to JSONL file containing prompts + + Returns: + List[Dict[str, str]]: List of prompt dictionaries + """ + prompts = [] + with open(prompt_file, "r") as f: + for line in f: + prompt_dict = json.loads(line.strip()) + prompts.append(prompt_dict) + return prompts + + +def save_video(video, fps, H, W, video_save_quality, video_save_path): + """Save video frames to file. + + Args: + grid (np.ndarray): Video frames array [T,H,W,C] + fps (int): Frames per second + H (int): Frame height + W (int): Frame width + video_save_quality (int): Video encoding quality (0-10) + video_save_path (str): Output video file path + """ + kwargs = { + "fps": fps, + "quality": video_save_quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{W}x{H}"], + "output_params": ["-f", "mp4"], + } + imageio.mimsave(video_save_path, video, "mp4", **kwargs) + + +def load_from_fileobj(filepath: str, format: str = "mp4", mode: str = "rgb", **kwargs): + """ + Load video from a file-like object using imageio with specified format and color mode. + + Parameters: + file (IO[bytes]): A file-like object containing video data. + format (str): Format of the video file (default 'mp4'). + mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). + + Returns: + tuple: A tuple containing an array of video frames and metadata about the video. + """ + with open(filepath, "rb") as f: + value = f.read() + with BytesIO(value) as f: + f.seek(0) + video_reader = imageio.get_reader(f, format, **kwargs) + + video_frames = [] + for frame in video_reader: + if mode == "gray": + import cv2 # Convert frame to grayscale if mode is gray + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent + video_frames.append(frame) + + return np.array(video_frames), video_reader.get_meta_data() diff --git a/cosmos_predict1/utils/lazy_config/__init__.py b/cosmos_predict1/utils/lazy_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3df830db623db39690c68ae09fa7e576cea5d0c --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
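+
+# Illustrative usage sketch of the exports below, mirroring the LazyCall docstring
+# (``torch.nn`` is assumed to be available in the caller, not imported here):
+#
+#   import torch.nn as nn
+#   from cosmos_predict1.utils.lazy_config import LazyCall, instantiate
+#
+#   layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32, kernel_size=3)
+#   layer_cfg.out_channels = 64      # configs stay editable until instantiation
+#   layer = instantiate(layer_cfg)   # builds the actual nn.Conv2d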
+import os + +from omegaconf import DictConfig, OmegaConf + +from cosmos_predict1.utils.lazy_config.instantiate import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyCall, LazyConfig +from cosmos_predict1.utils.lazy_config.omegaconf_patch import to_object + +OmegaConf.to_object = to_object + +PLACEHOLDER = None +LazyDict = DictConfig + +__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] + + +DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py + + +def fixup_module_metadata(module_name, namespace, keys=None): + """ + Fix the __qualname__ of module members to be their exported api name, so + when they are referenced in docs, sphinx can find them. Reference: + https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 + """ + if not DOC_BUILDING: + return + seen_ids = set() + + def fix_one(qualname, name, obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + + mod = getattr(obj, "__module__", None) + if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): + obj.__module__ = module_name + # Modules, unlike everything else in Python, put fully-qualitied + # names into their __name__ attribute. We check for "." to avoid + # rewriting these. + if hasattr(obj, "__name__") and "." not in obj.__name__: + obj.__name__ = name + obj.__qualname__ = qualname + if isinstance(obj, type): + for attr_name, attr_value in obj.__dict__.items(): + fix_one(objname + "." + attr_name, attr_name, attr_value) + + if keys is None: + keys = namespace.keys() + for objname in keys: + if not objname.startswith("_"): + obj = namespace[objname] + fix_one(objname, objname, obj) + + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/cosmos_predict1/utils/lazy_config/file_io.py b/cosmos_predict1/utils/lazy_config/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..d9caf0081976dd08ab6ea1c04ad53304bc51d05d --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/file_io.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
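+
+# Illustrative usage sketch: ``PathManager`` below is an iopath PathManagerBase
+# with HTTP(S) and OneDrive handlers registered, so local paths and URLs can be
+# opened uniformly. The path used here is a placeholder.
+#
+#   from cosmos_predict1.utils.lazy_config.file_io import PathManager
+#
+#   with PathManager.open("configs/example.py", "r") as f:
+#       content = f.read()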
+ +from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler +from iopath.common.file_io import PathManager as PathManagerBase + +__all__ = ["PathManager", "PathHandler"] + + +PathManager = PathManagerBase() +PathManager.register_handler(HTTPURLHandler()) +PathManager.register_handler(OneDrivePathHandler()) diff --git a/cosmos_predict1/utils/lazy_config/instantiate.py b/cosmos_predict1/utils/lazy_config/instantiate.py new file mode 100644 index 0000000000000000000000000000000000000000..3c87b7a555b2292468360b015b8100d09cf5cc59 --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/instantiate.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc as abc +import dataclasses +import logging +from typing import Any + +import attrs + +from cosmos_predict1.utils.lazy_config.registry import _convert_target_to_string, locate + +__all__ = ["dump_dataclass", "instantiate"] + + +def is_dataclass_or_attrs(target): + return dataclasses.is_dataclass(target) or attrs.has(target) + + +def dump_dataclass(obj: Any): + """ + Dump a dataclass recursively into a dict that can be later instantiated. + + Args: + obj: a dataclass object + + Returns: + dict + """ + assert dataclasses.is_dataclass(obj) and not isinstance( + obj, type + ), "dump_dataclass() requires an instance of a dataclass." + ret = {"_target_": _convert_target_to_string(type(obj))} + for f in dataclasses.fields(obj): + v = getattr(obj, f.name) + if dataclasses.is_dataclass(v): + v = dump_dataclass(v) + if isinstance(v, (list, tuple)): + v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] + ret[f.name] = v + return ret + + +def instantiate(cfg, *args, **kwargs): + """ + Recursively instantiate objects defined in dictionaries by + "_target_" and arguments. + + Args: + cfg: a dict-like object with "_target_" that defines the caller, and + other keys that define the arguments + args: Optional positional parameters pass-through. + kwargs: Optional named parameters pass-through. + + Returns: + object instantiated by cfg + """ + from omegaconf import DictConfig, ListConfig, OmegaConf + + if isinstance(cfg, ListConfig): + lst = [instantiate(x) for x in cfg] + return ListConfig(lst, flags={"allow_objects": True}) + if isinstance(cfg, list): + # Specialize for list, because many classes take + # list[objects] as arguments, such as ResNet, DatasetMapper + return [instantiate(x) for x in cfg] + + # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), + # instantiate it to the actual dataclass. 
+ if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): + return OmegaConf.to_object(cfg) + + if isinstance(cfg, abc.Mapping) and "_target_" in cfg: + # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, + # but faster: https://github.com/facebookresearch/hydra/issues/1200 + cfg = {k: instantiate(v) for k, v in cfg.items()} + cls = cfg.pop("_target_") + cls = instantiate(cls) + + if isinstance(cls, str): + cls_name = cls + cls = locate(cls_name) + assert cls is not None, cls_name + else: + try: + cls_name = cls.__module__ + "." + cls.__qualname__ + except Exception: + # target could be anything, so the above could fail + cls_name = str(cls) + assert callable(cls), f"_target_ {cls} does not define a callable object" + try: + # override config with kwargs + instantiate_kwargs = {} + instantiate_kwargs.update(cfg) + instantiate_kwargs.update(kwargs) + return cls(*args, **instantiate_kwargs) + except TypeError: + logger = logging.getLogger(__name__) + logger.error(f"Error when instantiating {cls_name}!") + raise + return cfg # return as-is if don't know what to do diff --git a/cosmos_predict1/utils/lazy_config/lazy.py b/cosmos_predict1/utils/lazy_config/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..24063a35533ffcdba4b435dc388d87535ad4d330 --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/lazy.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
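+
+# Illustrative usage sketch of LazyConfig defined below (the config path and key
+# name are placeholders):
+#
+#   from cosmos_predict1.utils.lazy_config.lazy import LazyConfig
+#
+#   cfg = LazyConfig.load("projects/example/config.py")               # all config objects
+#   model_cfg = LazyConfig.load("projects/example/config.py", "model")  # a single key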
+ +import ast +import builtins +import collections.abc as abc +import importlib +import inspect +import logging +import os +import pickle +import uuid +from collections import OrderedDict +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import is_dataclass +from typing import Any, Dict, List, Tuple, Union + +import attrs +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +from cosmos_predict1.utils.lazy_config.file_io import PathManager +from cosmos_predict1.utils.lazy_config.registry import _convert_target_to_string + +try: + import dill as dill_pickle +except ImportError: + dill_pickle = None +try: + import cloudpickle +except ImportError: + cloudpickle = None + +__all__ = ["LazyCall", "LazyConfig"] + + +def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]: + return OrderedDict(sorted(d.items(), key=lambda x: x[0])) + + +def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode: + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]: + if isinstance(obj, dict): + return sort_dict({k: sort_recursive(v) for k, v in obj.items()}) + elif isinstance(obj, list): + return [sort_recursive(item) for item in obj] + return obj + + +yaml.add_representer(OrderedDict, dict_representer) + + +def get_default_params(cls_or_func): + if callable(cls_or_func): + # inspect signature for function + signature = inspect.signature(cls_or_func) + else: + # inspect signature for class + signature = inspect.signature(cls_or_func.__init__) + params = signature.parameters + default_params = { + name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty + } + return default_params + + +class LazyCall: + """ + Wrap a callable so that when it's called, the call will not be executed, + but returns a dict that describes the call. + + LazyCall object has to be called with only keyword arguments. Positional + arguments are not yet supported. + + Examples: + :: + from detectron2.config import instantiate, LazyCall + + layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) + layer_cfg.out_channels = 64 # can edit it afterwards + layer = instantiate(layer_cfg) + """ + + def __init__(self, target): + if not (callable(target) or isinstance(target, (str, abc.Mapping))): + raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}") + self._target = target + + def __call__(self, **kwargs): + if is_dataclass(self._target) or attrs.has(self._target): + # omegaconf object cannot hold dataclass type + # https://github.com/omry/omegaconf/issues/784 + target = _convert_target_to_string(self._target) + else: + target = self._target + kwargs["_target_"] = target + + _final_params = get_default_params(self._target) + _final_params.update(kwargs) + + return DictConfig(content=_final_params, flags={"allow_objects": True}) + + +def _visit_dict_config(cfg, func): + """ + Apply func recursively to all DictConfig in cfg. 
+ """ + if isinstance(cfg, DictConfig): + func(cfg) + for v in cfg.values(): + _visit_dict_config(v, func) + elif isinstance(cfg, ListConfig): + for v in cfg: + _visit_dict_config(v, func) + + +def _validate_py_syntax(filename): + # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py + with PathManager.open(filename, "r") as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError(f"Config file {filename} has syntax error!") from e + + +def _cast_to_config(obj): + # if given a dict, return DictConfig instead + if isinstance(obj, dict): + return DictConfig(obj, flags={"allow_objects": True}) + return obj + + +_CFG_PACKAGE_NAME = "detectron2._cfg_loader" +""" +A namespace to put all imported config into. +""" + + +def _random_package_name(filename): + # generate a random package name when loading config files + return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) + + +@contextmanager +def _patch_import(): + """ + Enhance relative import statements in config files, so that they: + 1. locate files purely based on relative location, regardless of packages. + e.g. you can import file without having __init__ + 2. do not cache modules globally; modifications of module states has no side effect + 3. support other storage system through PathManager, so config files can be in the cloud + 4. imported dict are turned into omegaconf.DictConfig automatically + """ + old_import = builtins.__import__ + + def find_relative_file(original_file, relative_import_path, level): + # NOTE: "from . import x" is not handled. Because then it's unclear + # if such import should produce `x` as a python module or DictConfig. + # This can be discussed further if needed. + relative_import_err = """ +Relative import of directories is not allowed within config files. +Within a config file, relative import can only import other config files. +""".replace( + "\n", " " + ) + if not len(relative_import_path): + raise ImportError(relative_import_err) + + cur_file = os.path.dirname(original_file) + for _ in range(level - 1): + cur_file = os.path.dirname(cur_file) + cur_name = relative_import_path.lstrip(".") + for part in cur_name.split("."): + cur_file = os.path.join(cur_file, part) + if not cur_file.endswith(".py"): + cur_file += ".py" + if not PathManager.isfile(cur_file): + cur_file_no_suffix = cur_file[: -len(".py")] + if PathManager.isdir(cur_file_no_suffix): + raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) + else: + raise ImportError( + f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist." 
+ ) + return cur_file + + def new_import(name, globals=None, locals=None, fromlist=(), level=0): + if ( + # Only deal with relative imports inside config files + level != 0 + and globals is not None + and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) + ): + cur_file = find_relative_file(globals["__file__"], name, level) + _validate_py_syntax(cur_file) + spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file) + module = importlib.util.module_from_spec(spec) + module.__file__ = cur_file + with PathManager.open(cur_file) as f: + content = f.read() + exec(compile(content, cur_file, "exec"), module.__dict__) + for name in fromlist: # turn imported dict into DictConfig automatically + val = _cast_to_config(module.__dict__[name]) + module.__dict__[name] = val + return module + return old_import(name, globals, locals, fromlist=fromlist, level=level) + + builtins.__import__ = new_import + yield new_import + builtins.__import__ = old_import + + +class LazyConfig: + """ + Provide methods to save, load, and overrides an omegaconf config object + which may contain definition of lazily-constructed objects. + """ + + @staticmethod + def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Similar to :meth:`load()`, but load path relative to the caller's + source file. + + This has the same functionality as a relative import, except that this method + accepts filename as a string, so more characters are allowed in the filename. + """ + caller_frame = inspect.stack()[1] + caller_fname = caller_frame[0].f_code.co_filename + assert caller_fname != "", "load_rel Unable to find caller" + caller_dir = os.path.dirname(caller_fname) + filename = os.path.join(caller_dir, filename) + return LazyConfig.load(filename, keys) + + @staticmethod + def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Load a config file. + + Args: + filename: absolute path or relative path w.r.t. the current working directory + keys: keys to load and return. If not given, return all keys + (whose values are config objects) in a dict. + """ + has_keys = keys is not None + filename = filename.replace("/./", "/") # redundant + if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: + raise ValueError(f"Config file {filename} has to be a python or yaml file.") + if filename.endswith(".py"): + _validate_py_syntax(filename) + + with _patch_import(): + # Record the filename + module_namespace = { + "__file__": filename, + "__package__": _random_package_name(filename), + } + with PathManager.open(filename) as f: + content = f.read() + # Compile first with filename to: + # 1. make filename appears in stacktrace + # 2. 
make load_rel able to find its parent's (possibly remote) location + exec(compile(content, filename, "exec"), module_namespace) + + ret = module_namespace + else: + with PathManager.open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + + if has_keys: + if isinstance(keys, str): + return _cast_to_config(ret[keys]) + else: + return tuple(_cast_to_config(ret[a]) for a in keys) + else: + if filename.endswith(".py"): + # when not specified, only load those that are config objects + ret = DictConfig( + { + name: _cast_to_config(value) + for name, value in ret.items() + if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_") + }, + flags={"allow_objects": True}, + ) + return ret + + @staticmethod + def save_pkl(cfg, filename: str) -> str: + """ + Saves a Config object to a file using pickle serialization. This method is typically used + when the configuration object contains complex objects, such as lambdas, that are not supported by + simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration + object before serialization to ensure that the original object remains unmodified. + + Args: + cfg: A Config object to be serialized and saved. + filename: The path and name of the file where the configuration should be saved. The function + assumes the file extension indicates a pickle format (e.g., .pkl). + + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location + or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using pickle. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cfg, f) + logger.warning(f"Config is saved using pickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead") + if dill_pickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(dill_pickle.dumps(cfg, recurse=True), f) + logger.warning(f"Config is saved using dill at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + if cloudpickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cloudpickle.dumps(cfg), f) + logger.warning(f"Config is saved using cloudpickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + else: + logger.error("cloudpickle is not available. Cannot save the config.") + raise e + + return filename + + @staticmethod + def save_yaml(cfg, filename: str) -> str: + """ + Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization. + + Args: + cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types. + filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'. 
+ + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using YAML. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + # Define a function to check if an item is serializable to YAML + def is_serializable(item): + try: + OmegaConf.to_yaml(item) + return True + except Exception as e: + return False + + # Function to convert unserializable items to strings + def serialize_config(config): + if isinstance(config, DictConfig): + for key, value in config.items(): + if isinstance(value, (DictConfig, ListConfig)): + try: + if "_target_" in value: + default_params = get_default_params(value["_target_"]) + for default_key, default_v in default_params.items(): + if default_key not in value: + value[default_key] = default_v + except Exception as e: + logger.error(f"Failed to add default argument values: {e}") + + serialize_config(value) + else: + if not is_serializable(value) and value is not None: + config[key] = str(value) + elif isinstance(config, ListConfig): + for i, item in enumerate(config): + if isinstance(item, (DictConfig, ListConfig)): + serialize_config(item) + else: + if not is_serializable(item) and item is not None: + config[i] = str(item) + else: + raise NotImplementedError("Input config must be a DictConfig or ListConfig.") + return config + + # Convert Config object to a DictConfig object. + config_dict = attrs.asdict(cfg) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + + # Serialize the DictConfig object by converting non-serializable objects to strings. + config_omegaconf = serialize_config(config_omegaconf) + + config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True) + sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict) + with open(filename, "w") as f: + yaml.dump(sorted_config, f, default_flow_style=False) + logger.warning(f"Config is saved using omegaconf at {filename}.") + return filename diff --git a/cosmos_predict1/utils/lazy_config/omegaconf_patch.py b/cosmos_predict1/utils/lazy_config/omegaconf_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..39dca42a0a71383de919b750cedf2606faae206d --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/omegaconf_patch.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
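+
+# Illustrative usage sketch, mirroring the docstring examples below. The patched
+# ``to_object`` is installed by ``cosmos_predict1.utils.lazy_config.__init__`` via
+# ``OmegaConf.to_object = to_object``; "Model" is a placeholder target name.
+#
+#   from omegaconf import DictConfig
+#
+#   to_object(DictConfig({"list": [1, 2, 3]}))     # -> {'list': [1, 2, 3]}
+#   to_object(DictConfig({"_target_": "Model"}))   # -> returned unchanged as a DictConfig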
+ +from typing import Any, Dict, List, Union + +from omegaconf import OmegaConf +from omegaconf.base import DictKeyType, SCMode +from omegaconf.dictconfig import DictConfig # pragma: no cover + + +def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: + """ + Converts an OmegaConf configuration object to a native Python container (dict or list), unless + the configuration is specifically created by LazyCall, in which case the original configuration + is returned directly. + + This function serves as a modification of the original `to_object` method from OmegaConf, + preventing DictConfig objects created by LazyCall from being automatically converted to Python + dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended + structure and behavior. + + Differences from OmegaConf's original `to_object`: + - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. + + Reference: + - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 + + Args: + cfg (Any): The OmegaConf configuration object to convert. + + Returns: + Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if + `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. + + Examples: + >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) + >>> to_object(cfg) + DictConfig({"key": "value", "_target_": "Model"}) + + >>> cfg = DictConfig({"list": [1, 2, 3]}) + >>> to_object(cfg) + {'list': [1, 2, 3]} + """ + if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): + return cfg + + return OmegaConf.to_container( + cfg=cfg, + resolve=True, + throw_on_missing=True, + enum_to_str=False, + structured_config_mode=SCMode.INSTANTIATE, + ) diff --git a/cosmos_predict1/utils/lazy_config/registry.py b/cosmos_predict1/utils/lazy_config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7c09eb428a97927d5f0407e2328a3f43afbf38fc --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/registry.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pydoc +from typing import Any + +""" +`locate` provide ways to map a string (typically found +in config files) to callable objects. +""" + +__all__ = ["locate"] + + +def _convert_target_to_string(t: Any) -> str: + """ + Inverse of ``locate()``. + + Args: + t: any object with ``__module__`` and ``__qualname__`` + """ + module, qualname = t.__module__, t.__qualname__ + + # Compress the path to this object, e.g. ``module.submodule._impl.class`` + # may become ``module.submodule.class``, if the later also resolves to the same + # object. This simplifies the string, and also is less affected by moving the + # class implementation. 
+ module_parts = module.split(".") + for k in range(1, len(module_parts)): + prefix = ".".join(module_parts[:k]) + candidate = f"{prefix}.{qualname}" + try: + if locate(candidate) is t: + return candidate + except ImportError: + pass + return f"{module}.{qualname}" + + +def locate(name: str) -> Any: + """ + Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, + such as "module.submodule.class_name". + + Raise Exception if it cannot be found. + """ + obj = pydoc.locate(name) + + # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly + # by pydoc.locate. Try a private function from hydra. + if obj is None: + try: + # from hydra.utils import get_method - will print many errors + from hydra.utils import _locate + except ImportError as e: + raise ImportError(f"Cannot dynamically locate object {name}!") from e + else: + obj = _locate(name) # it raises if fails + + return obj diff --git a/cosmos_predict1/utils/log.py b/cosmos_predict1/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..45f98624193c5551c6c390dd0110d1440a610133 --- /dev/null +++ b/cosmos_predict1/utils/log.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +from typing import Any, Optional + +import torch.distributed as dist +from loguru._logger import Core, Logger +from tqdm import tqdm + +RANK0_ONLY = True +LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") + +logger = Logger( + core=Core(), + exception=None, + depth=1, + record=False, + lazy=False, + colors=False, + raw=False, + capture=True, + patchers=[], + extra={}, +) + +atexit.register(logger.remove) + + +def _add_relative_path(record: dict[str, Any]) -> None: + start = os.getcwd() + record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) + + +*options, _, extra = logger._options # type: ignore +logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore + + +def init_loguru_stdout() -> None: + logger.remove() + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + lambda msg: tqdm.write(msg, end=""), # stdout is replaced with tqdm.write to avoid tqdm log pollution.. 
+ level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + filter=_rank0_only_filter, + ) + + +def init_loguru_file(path: str) -> None: + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + path, + encoding="utf8", + level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + rotation="100 MB", + filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, + enqueue=True, + ) + + +def get_machine_format() -> str: + node_id = os.environ.get("NGC_ARRAY_INDEX", "0") + num_nodes = int(os.environ.get("NGC_ARRAY_SIZE", "1")) + machine_format = "" + rank = 0 + if dist.is_available(): + if not RANK0_ONLY and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + machine_format = ( + f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " + ) + return machine_format + + +def get_message_format() -> str: + message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" + return message_format + + +def _rank0_only_filter(record: Any) -> bool: + is_rank0 = record["extra"].get("rank0_only", True) + if _get_rank() == 0 and is_rank0: + return True + if not is_rank0: + record["message"] = f"[RANK {_get_rank()}] " + record["message"] + return not is_rank0 + + +def trace(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) + + +def debug(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) + + +def info(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) + + +def success(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) + + +def warning(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) + + +def error(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) + + +def critical(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) + + +def exception(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) + + +def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +# Execute at import time. +init_loguru_stdout() diff --git a/cosmos_predict1/utils/misc.py b/cosmos_predict1/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..ce923cc656093ff4edd6c2a8bb01568c033795c9 --- /dev/null +++ b/cosmos_predict1/utils/misc.py @@ -0,0 +1,557 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +import collections.abc +import functools +import json +import os +import random +import time +from contextlib import ContextDecorator +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, TypeVar +from urllib.parse import urlparse + +import boto3 +import numpy as np +import termcolor +import torch +from torch import nn +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed._tensor.api import DTensor + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.easy_io import easy_io + + +def to( + data: Any, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + memory_format: torch.memory_format = torch.preserve_format, +) -> Any: + """Recursively cast data into the specified device, dtype, and/or memory_format. + + The input data can be a tensor, a list of tensors, a dict of tensors. + See the documentation for torch.Tensor.to() for details. + + Args: + data (Any): Input data. + device (str | torch.device): GPU device (default: None). + dtype (torch.dtype): data type (default: None). + memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). + + Returns: + data (Any): Data cast to the specified device, dtype, and/or memory_format. + """ + assert ( + device is not None or dtype is not None or memory_format is not None + ), "at least one of device, dtype, memory_format should be specified" + if isinstance(data, torch.Tensor): + is_cpu = (isinstance(device, str) and device == "cpu") or ( + isinstance(device, torch.device) and device.type == "cpu" + ) + data = data.to( + device=device, + dtype=dtype, + memory_format=memory_format, + non_blocking=(not is_cpu), + ) + return data + elif isinstance(data, collections.abc.Mapping): + return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) + else: + return data + + +def serialize(data: Any) -> Any: + """Serialize data by hierarchically traversing through iterables. + + Args: + data (Any): Input data. + + Returns: + data (Any): Serialized data. + """ + if isinstance(data, collections.abc.Mapping): + return type(data)({key: serialize(data[key]) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([serialize(elem) for elem in data]) + else: + try: + json.dumps(data) + except TypeError: + data = str(data) + return data + + +def print_environ_variables(env_vars: list[str]) -> None: + """Print a specific list of environment variables. + + Args: + env_vars (list[str]): List of specified environment variables. 
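+
+    Example (the environment variable names below are illustrative):
+        print_environ_variables(["MASTER_ADDR", "MASTER_PORT"])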
+ """ + for env_var in env_vars: + if env_var in os.environ: + log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") + else: + log.warning(f"Environment variable {Color.green(env_var)} not set!") + + +def set_random_seed(seed: int, by_rank: bool = False) -> None: + """Set random seed. This includes random, numpy, Pytorch. + + Args: + seed (int): Random seed. + by_rank (bool): if true, each GPU will use a different random seed. + """ + if by_rank: + seed += distributed.get_rank() + log.info(f"Using random seed {seed}.") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # sets seed on the current CPU & all GPUs + + +def arch_invariant_rand( + shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None +): + """Produce a GPU-architecture-invariant randomized Torch tensor. + + Args: + shape (list or tuple of ints): Output tensor shape. + dtype (torch.dtype): Output tensor type. + device (torch.device): Device holding the output. + seed (int): Optional randomization seed. + + Returns: + tensor (torch.tensor): Randomly-generated tensor. + """ + # Create a random number generator, optionally seeded + rng = np.random.RandomState(seed) + + # # Generate random numbers using the generator + random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution + + # Convert to torch tensor and return + return torch.from_numpy(random_array).to(dtype=dtype, device=device) + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class timer(ContextDecorator): # noqa: N801 + """Simple timer for timing the execution of code. + + It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. + + Example: + def func_a(): + time.sleep(1) + with timer("func_a"): + func_a() + + @timer("func_b) + def func_b(): + time.sleep(1) + func_b() + """ + + def __init__(self, context: str, debug: bool = False): + self.context = context + self.debug = debug + + def __enter__(self) -> None: + self.tic = time.time() + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + time_spent = time.time() - self.tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + + def __call__(self, func: T) -> T: + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + tic = time.time() + result = func(*args, **kwargs) + time_spent = time.time() - tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + return result + + return wrapper # type: ignore + + +class TrainingTimer: + """Timer for timing the execution of code, aggregating over multiple training iterations. + + It is used as a context manager to measure the execution time of code and store the timing results + for each function. The context managers can be nested. + + Attributes: + results (dict): A dictionary to store timing results for various code. 
+ + Example: + timer = Timer() + for i in range(100): + with timer("func_a"): + func_a() + avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) + print(f"func_a() took {avg_time} seconds.") + """ + + def __init__(self) -> None: + self.results = dict() + self.average_results = dict() + self.start_time = [] + self.func_stack = [] + self.reset() + + def reset(self) -> None: + self.results = {key: [] for key in self.results} + + def __enter__(self) -> TrainingTimer: + self.start_time.append(time.time()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + end_time = time.time() + result = end_time - self.start_time.pop() + key = self.func_stack.pop() + self.results.setdefault(key, []) + self.results[key].append(result) + + def __call__(self, func_name: str) -> TrainingTimer: + self.func_stack.append(func_name) + return self + + def __getattr__(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def nested(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def compute_average_results(self) -> dict[str, float]: + results = dict() + for key, value_list in self.results.items(): + results[key] = sum(value_list) / len(value_list) + return results + + +def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: + # What to do when the process gets stuck. For now, we simply end the process. + error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." + raise TimeoutError(error_message) + + +class Color: + """A convenience class to colorize strings in the console. + + Example: + import + print("This is {Color.red('important')}.") + """ + + @staticmethod + def red(x: str) -> str: + return termcolor.colored(str(x), color="red") + + @staticmethod + def green(x: str) -> str: + return termcolor.colored(str(x), color="green") + + @staticmethod + def cyan(x: str) -> str: + return termcolor.colored(str(x), color="cyan") + + @staticmethod + def yellow(x: str) -> str: + return termcolor.colored(str(x), color="yellow") + + +class BufferCnt: + """ + Buffer counter which keeps track of the condition when called and returns True when the condition in met "thres" + amount of times, otherwise returns False. + + Example usage: + buf = BufferCnt(thres=3) + for _ in range(5): + if buf(random.random() > 0.5): + print("We got lucky 3 times out of 5.") + + Args: + thres (int): The amount of times the expression needs to be True before returning True. + reset_over_thres (bool): Whether to reset the buffer after returning True. 
+ """ + + def __init__(self, thres=10, reset_over_thres=False): + self._cnt = 0 + self.thres = thres + self.reset_over_thres = reset_over_thres + + def __call__(self, expre, thres=None): + if expre is True: + self._cnt += 1 + else: + self._cnt = 0 + + if thres is None: + thres = self.thres + + if self._cnt >= thres: + if self.reset_over_thres: + self.reset() + return True + + return False + + @property + def cnt(self): + return self._cnt + + def reset(self): + self._cnt = 0 + + +def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: + if isinstance(tensor, DTensor): + local = tensor.to_local() + # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish + # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local + if isinstance(local, AsyncCollectiveTensor): + return local.wait() + else: + return local + return tensor + + +def disabled_train(self: Any, mode: bool = True) -> Any: + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def count_params(model: nn.Module, verbose=False) -> int: + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def expand_dims_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def download_from_s3_with_cache( + s3_path: str, + cache_fp: Optional[str] = None, + cache_dir: Optional[str] = None, + rank_sync: bool = True, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """download data from S3 with optional caching. + + This function first attempts to load the data from a local cache file. If + the cache file doesn't exist, it downloads the data from S3 to the cache + location. Caching is performed in a rank-aware manner + using `distributed.barrier()` to ensure only one download occurs across + distributed workers (if `rank_sync` is True). + + Args: + s3_path (str): The S3 path of the data to load. + cache_fp (str, optional): The path to the local cache file. If None, + a filename will be generated based on `s3_path` within `cache_dir`. + cache_dir (str, optional): The directory to store the cache file. If + None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "/tmp") will be used. + rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. + backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. + + Returns: + cache_fp (str): The path to the local cache file. + + Raises: + FileNotFoundError: If the data cannot be found in S3 or the cache. 
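+
+    Example (illustrative; the S3 path and cache directory are placeholders):
+        local_path = download_from_s3_with_cache(
+            "s3://my-bucket/checkpoints/model.pt", cache_dir="~/.cache/cosmos"
+        )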
+ """ + cache_dir = os.environ.get("TORCH_HOME") if cache_dir is None else cache_dir + cache_dir = ( + os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir + ) + cache_dir = os.path.expanduser(cache_dir) + if cache_fp is None: + cache_fp = os.path.join(cache_dir, s3_path.replace("s3://", "")) + if not cache_fp.startswith("/"): + cache_fp = os.path.join(cache_dir, cache_fp) + + if distributed.get_rank() == 0: + if os.path.exists(cache_fp): + # check the size of cache_fp + if os.path.getsize(cache_fp) < 1: + os.remove(cache_fp) + log.warning(f"Removed empty cache file {cache_fp}.") + + if rank_sync: + if not os.path.exists(cache_fp): + log.critical(f"Local cache {cache_fp} Not exist! Downloading {s3_path} to {cache_fp}.") + log.info(f"backend_args: {backend_args}") + log.info(f"backend_key: {backend_key}") + + easy_io.copyfile_to_local( + s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key + ) + log.info(f"Downloaded {s3_path} to {cache_fp}.") + else: + log.info(f"Local cache {cache_fp} already exist! {s3_path} -> {cache_fp}.") + + distributed.barrier() + else: + if not os.path.exists(cache_fp): + easy_io.copyfile_to_local( + s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key + ) + + log.info(f"Downloaded {s3_path} to {cache_fp}.") + return cache_fp + + +def load_from_s3_with_cache( + s3_path: str, + cache_fp: Optional[str] = None, + cache_dir: Optional[str] = None, + rank_sync: bool = True, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + easy_io_kwargs: Optional[dict] = None, +) -> Any: + """Loads data from S3 with optional caching. + + This function first attempts to load the data from a local cache file. If + the cache file doesn't exist, it downloads the data from S3 to the cache + location and then loads it. Caching is performed in a rank-aware manner + using `distributed.barrier()` to ensure only one download occurs across + distributed workers (if `rank_sync` is True). + + Args: + s3_path (str): The S3 path of the data to load. + cache_fp (str, optional): The path to the local cache file. If None, + a filename will be generated based on `s3_path` within `cache_dir`. + cache_dir (str, optional): The directory to store the cache file. If + None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "/tmp") will be used. + rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. + backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. + + Returns: + Any: The loaded data from the S3 path or cache file. + + Raises: + FileNotFoundError: If the data cannot be found in S3 or the cache. + """ + cache_fp = download_from_s3_with_cache(s3_path, cache_fp, cache_dir, rank_sync, backend_args, backend_key) + + if easy_io_kwargs is None: + easy_io_kwargs = {} + return easy_io.load(cache_fp, **easy_io_kwargs) + + +def sync_s3_dir_to_local( + s3_dir: str, + s3_credential_path: str, + cache_dir: Optional[str] = None, + rank_sync: bool = True, +) -> str: + """ + Download an entire directory from S3 to the local cache directory. + + Args: + s3_dir (str): The AWS S3 directory to download. + s3_credential_path (str): The path to the AWS S3 credentials file. 
+ rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + cache_dir (str, optional): The cache folder to sync the S3 directory to. + If None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "~/.cache/cosmos") will be used. + + Returns: + local_dir (str): The path to the local directory. + """ + if not s3_dir.startswith("s3://"): + # If the directory exists locally, return the local path + assert os.path.exists(s3_dir), f"{s3_dir} is not a S3 path or a local path." + return s3_dir + + # Load AWS credentials from the file + with open(s3_credential_path, "r") as f: + credentials = json.load(f) + + # Create an S3 client + s3 = boto3.client( + "s3", + **credentials, + ) + + # Parse the S3 URL + parsed_url = urlparse(s3_dir) + source_bucket = parsed_url.netloc + source_prefix = parsed_url.path.lstrip("/") + + # If the local directory is not specified, use the default cache directory + cache_dir = ( + os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir + ) + cache_dir = os.path.expanduser(cache_dir) + Path(cache_dir).mkdir(parents=True, exist_ok=True) + + # List objects in the bucket with the given prefix + response = s3.list_objects_v2(Bucket=source_bucket, Prefix=source_prefix) + # Download each matching object + for obj in response.get("Contents", []): + if obj["Key"].startswith(source_prefix): + # Create the full path for the destination file, preserving the directory structure + rel_path = os.path.relpath(obj["Key"], source_prefix) + dest_path = os.path.join(cache_dir, source_prefix, rel_path) + + # Ensure the directory exists + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + + # Check if the file already exists + if os.path.exists(dest_path): + continue + else: + log.info(f"Downloading {obj['Key']} to {dest_path}") + # Download the file + if not rank_sync or distributed.get_rank() == 0: + s3.download_file(source_bucket, obj["Key"], dest_path) + if rank_sync: + distributed.barrier() + local_dir = os.path.join(cache_dir, source_prefix) + return local_dir diff --git a/cosmos_predict1/utils/model.py b/cosmos_predict1/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e6add316d05e17ca23ec6b9d7fb56ec744bbb220 --- /dev/null +++ b/cosmos_predict1/utils/model.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from cosmos_predict1.utils.lazy_config import LazyDict, instantiate + + +class Model(torch.nn.Module): + """The base model class. It is inherited from torch.nn.Module. + + All models should inherit Model. It should include the implementions for all the + computation graphs. 
All inheriting child classes should implement the following methods: + - training_step(): The training step of the model, including the loss computation. + - validation_step(): The validation step of the model, including the loss computation. + - forward(): The computation graph for model inference. + The following methods have default implementations in Model: + - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model. + """ + + def __init__(self) -> None: + super().__init__() + self.on_model_init_start(set_barrier=False) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer_config.params = self.parameters() + optimizer = instantiate(optimizer_config) + scheduler_config.optimizer = optimizer + scheduler = instantiate(scheduler_config) + return optimizer, scheduler + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The training step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch. + loss (torch.Tensor): The total loss for backprop (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.no_grad() + def validation_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The validation step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch. + loss (torch.Tensor): The total loss (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.inference_mode() + def forward(self, *args: Any, **kwargs: Any) -> Any: + """The computation graph for model inference. + + Args: + *args: Whatever you decide to pass into the forward method. + **kwargs: Keyword arguments are also possible. + + Return: + Your model's output. + """ + raise NotImplementedError + + def on_model_init_start(self, set_barrier=False) -> None: + return + + def on_model_init_end(self, set_barrier=False) -> None: + return + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + """The model preparation before the training is launched + + Args: + memory_format (torch.memory_format): Memory format of the model. + """ + pass + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + pass + + def on_after_backward(self, iteration: int = 0) -> None: + """Hook after loss.backward() is called. 
+ + This method is called immediately after the backward pass, allowing for custom operations + or modifications to be performed on the gradients before the optimizer step. + + Args: + iteration (int): Current iteration number. + """ + pass diff --git a/cosmos_predict1/utils/parallel_state_helper.py b/cosmos_predict1/utils/parallel_state_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f531ab00c9d45a7dbf5015a43147bf635d72c5ec --- /dev/null +++ b/cosmos_predict1/utils/parallel_state_helper.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import parallel_state + + +def is_tp_cp_pp_rank0(): + return ( + parallel_state.get_tensor_model_parallel_rank() == 0 + and parallel_state.get_pipeline_model_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ) diff --git a/cosmos_predict1/utils/scheduler.py b/cosmos_predict1/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d344b990dc55920947645b52d70e201c164f86d7 --- /dev/null +++ b/cosmos_predict1/utils/scheduler.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import List + +import torch + + +class WarmupLambdaLR(torch.optim.lr_scheduler.LambdaLR): + def __init__(self, optimizer, warmup, last_epoch=-1, verbose=False): + # Define the lambda function based on the warmup period + self.warmup = warmup + + def lr_lambda(epoch): + # Increase lr linearly for the first 'warmup' epochs + if epoch < warmup: + return float(epoch + 1) / warmup + # After 'warmup' epochs, keep lr constant + return 1.0 + + # Initialize the parent class with the generated lr_lambda + super(WarmupLambdaLR, self).__init__(optimizer, lr_lambda, last_epoch, verbose) + + +# cosine lr decay scheduler with warmup from https://github.com/karpathy/nanoGPT/blob/master/train.py#L228 +class WarmupCosineLR(torch.optim.lr_scheduler.LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_iters: int, + lr_decay_iters: int, + min_lr: float, + last_epoch: int = -1, + ): + self.warmup_iters = warmup_iters + self.lr_decay_iters = lr_decay_iters + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + # 1) linear warmup for warmup_iters steps + if self.last_epoch < self.warmup_iters: + return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs] + # 2) if it > lr_decay_iters, return min learning rate + if self.last_epoch > self.lr_decay_iters: + return [self.min_lr for _ in self.base_lrs] + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (self.last_epoch - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return [self.min_lr + coeff * (base_lr - self.min_lr) for base_lr in self.base_lrs] diff --git a/cosmos_predict1/utils/trainer.py b/cosmos_predict1/utils/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca0b8349a0dc61dcedf2e603a3fbdf231d1badc --- /dev/null +++ b/cosmos_predict1/utils/trainer.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import signal + +import torch +import torch.distributed as dist +import torch.utils.data +from megatron.core import parallel_state + +from cosmos_predict1.utils import callback, distributed, ema, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer +from cosmos_predict1.utils.lazy_config import LazyConfig, instantiate +from cosmos_predict1.utils.model import Model + + +class Trainer: + """The base trainer class. + + All trainers should inherit Trainer. It contains the basic functionality for model training + (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA), + mixed-precision training (fp16/bf16). 
+ + Attributes: + checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. + training_timer (misc.Timer): Timer object to time code blocks and functions. + """ + + def __init__(self, config): + """Constructor of the trainer. + + Args: + config (Config): The config object for the codebase. + """ + super().__init__() + self.config = config + # Set up the distributed computing environment. + with misc.timer("init_distributed"): + distributed.init() + # Set up parallel states. + if hasattr(config.model, "context_parallel_size"): + if config.model_parallel.context_parallel_size > 1: + raise ValueError( + "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. " + "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size." + ) + else: + log.critical( + "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead." + ) + config.model_parallel.context_parallel_size = config.model.context_parallel_size + parallel_state.initialize_model_parallel( + pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, + tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, + context_parallel_size=config.model_parallel.context_parallel_size, + ) + # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism. + # It is not part of the original `parallel_state` API, so we need to set it manually. + parallel_state.sequence_parallel = config.model_parallel.sequence_parallel + if parallel_state.sequence_parallel: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Create the local job directory, save the config file, and pipe to a local log. + if distributed.is_rank0(): + os.makedirs(config.job.path_local, exist_ok=True) + # Save the config as .pkl for reproducibility. + LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl") + # Save the config as .yaml for reading or parsing experiment hyperparameters. + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + dist.barrier() + log.init_loguru_file(f"{config.job.path_local}/stdout.log") + if distributed.is_rank0(): + # Print important environment variables and the effective config. + log.info("Config:\n" + config.pretty_print(use_color=True)) + misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"]) + # Set the random seed. If multi-GPU, different ranks are set with different seeds. + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + # Initialize cuDNN. + torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic + torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark + # Floating-point precision settings. + torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True + # Initialize the callback functions. + self.callbacks = callback.CallBackGroup(config=config, trainer=self) + # Initialize the model checkpointer. + if config.checkpoint.type is None: + self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + self.checkpointer: Checkpointer = instantiate( + config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks + ) + # Initialize the timer for speed benchmarking. + self.training_timer = misc.TrainingTimer() + # Send a TimeoutError if a training step takes over timeout_period seconds. 
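+        # The alarm itself is re-armed via signal.alarm() after every successful training
+        # iteration in train() below, so a single step that hangs past timeout_period
+        # raises TimeoutError via misc.timeout_handler.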
+ signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore + + def train( + self, + model: Model, + dataloader_train: torch.utils.data.DataLoader, + dataloader_val: torch.utils.data.DataLoader, + ) -> None: + """The training function. + + Args: + model (Model): The PyTorch model. + dataloader_train (torch.utils.data.DataLoader): The training data loader. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + """ + # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. + model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore + model.on_train_start(self.config.trainer.memory_format) + + # Initialize the optimizer, scheduler, and grad_scaler. + self.callbacks.on_optimizer_init_start() + optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) + grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) + self.callbacks.on_optimizer_init_end() + # Load the model checkpoint and get the starting iteration number. + iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) + grad_accum_iter = 0 + log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + if self.config.trainer.distributed_parallelism == "ddp": + # Create a DDP model wrapper. + model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) + elif self.config.trainer.distributed_parallelism == "fsdp": + model_ddp = model + else: + raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + log.info("Starting training...") + self.callbacks.on_train_start(model, iteration=iteration) + # Initial validation. + if self.config.trainer.run_validation and iteration == 0: + self.validate(model, dataloader_val, iteration=iteration) + _end_training = False + while True: + dataloader_train_iter = iter(dataloader_train) + while True: + self.callbacks.on_before_dataloading(iteration) + with self.training_timer("dataloader_train"): + try: + data_batch = next(dataloader_train_iter) + for k in data_batch.keys(): + if torch.is_tensor(data_batch[k]): + data_batch[k] = data_batch[k].cuda() + except StopIteration: + break + self.callbacks.on_after_dataloading(iteration) + # If max_iter is reached, exit the training loop. + if iteration >= self.config.trainer.max_iter: + _end_training = True + break + # Move all tensors in the data batch to GPU device. + data_batch = misc.to(data_batch, device="cuda") + # The actual training step. + self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) + if not model.training: + model_ddp.train() + assert model_ddp.training, "model_ddp is not in training mode." + assert model.training, "model is not in training mode." + output_batch, loss, grad_accum_iter = self.training_step( + model_ddp, + optimizer, + scheduler, + grad_scaler, + data_batch, + iteration=iteration, + grad_accum_iter=grad_accum_iter, + ) + # Do the following when an actual optimizer (update) step has been made. + iteration += 1 + # Save checkpoint. 
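+                # (Written every `config.checkpoint.save_iter` optimizer steps; saving may be
+                # asynchronous depending on `config.checkpoint.async_saving`.)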
+ if iteration % self.config.checkpoint.save_iter == 0: + async_saving = getattr(self.config.checkpoint, "async_saving", True) + self.checkpointer.save( + model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving + ) + self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) + # Validation. + if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: + self.validate(model, dataloader_val, iteration=iteration) + # This iteration is successful; reset the timeout signal. + signal.alarm(self.config.trainer.timeout_period) + if _end_training: + break + log.success("Done with training.") + if iteration % self.config.checkpoint.save_iter != 0: + async_saving = getattr(self.config.checkpoint, "async_saving", True) + self.checkpointer.save( + model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving + ) + self.callbacks.on_train_end(model, iteration=iteration) + self.checkpointer.finalize() + distributed.barrier() + self.callbacks.on_app_end() + + def training_step( + self, + model_ddp: torch.nn.Module | distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + data: dict[str, torch.Tensor], + iteration: int = 0, + grad_accum_iter: int = 0, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: + """The training step. + + Args: + model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare + module, depending on whether distributed training is enabled or not. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + grad_accum_iter (int): Number of gradient accumulation iterations. + + Returns: + output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). + loss (torch.Tensor): The total loss of the training data batch. 
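+            grad_accum_iter (int): Updated gradient accumulation counter; reset to 0 once an
+                optimizer step has been taken.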
+ """ + # Only let DDP sync gradient at the last iteration of the gradient accumulation window + with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): + with self.training_timer("forward"): + output_batch, loss = model_ddp.training_step(data, iteration) + self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) + with self.training_timer("backward"): + loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) + loss_scaled.backward() + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_after_backward() + else: + model_ddp.on_after_backward() + self.callbacks.on_after_backward(model_ddp, iteration=iteration) + grad_accum_iter += 1 + if grad_accum_iter == self.config.trainer.grad_accum_iter: + with self.training_timer("optimizer_step"): + self.callbacks.on_before_optimizer_step( + model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration + ) + grad_scaler.step(optimizer) + grad_scaler.update() + scheduler.step() + self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + else: + model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + optimizer.zero_grad(set_to_none=True) + grad_accum_iter = 0 + return output_batch, loss, grad_accum_iter + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + with ema.ema_scope(model, enabled=model.config.ema.enabled): + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, loss = model.validation_step(data_batch, iteration) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/utils/validator.py b/cosmos_predict1/utils/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..837564a84d68079afb90870ed7181937e2f4df73 --- /dev/null +++ b/cosmos_predict1/utils/validator.py @@ -0,0 +1,503 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
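+
+# Illustrative sketch (comments only): the validators in this module are descriptors meant to be
+# declared as class attributes of a parameter class such as ModelParams (apis/utils/model_params.py);
+# values are validated and type-cast on assignment. The field names below are hypothetical:
+#
+#     class ModelParams:
+#         num_steps = Int(default=35, min=1, max=100)
+#         guidance = Float(default=7.0, min=0.0, max=20.0)
+#         output_format = OneOf(default="mp4", options=["mp4", "jpg"])
+#
+#     params = ModelParams()
+#     params.num_steps = "50"   # cast from str to int, then range-checked
+#     params.guidance = 25.0    # raises ValueError (above the configured max)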
+ +import ast +import base64 +import itertools +import json +import os +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, List, Union + + +# from https://docs.python.org/3/howto/descriptor.html#validator-class +# For usage of hidden flag see the ModelParams class in apis/utils/model_params.py +class Validator(ABC): + # set name is called when the validator is created as class variable + # name is the name of the variable in the owner class, so here we create the name for the backing variable + def __set_name__(self, owner, name): + self.private_name = "_" + name + + def __get__(self, obj, objtype=None): + return getattr(obj, self.private_name, self.default) + + def __set__(self, obj, value): + value = self.validate(value) + setattr(obj, self.private_name, value) + + @abstractmethod + def validate(self, value): + pass + + def json(self): + pass + + +class MultipleOf(Validator): + def __init__(self, default: int, multiple_of: int, type_cast=None, hidden=False, tooltip=None): + if type(multiple_of) is not int: + raise ValueError(f"Expected {multiple_of!r} to be an int") + self.multiple_of = multiple_of + self.default = default + self.type_cast = type_cast + + # For usage of hidden flag see the ModelParams class in apis/utils/model_params.py + # if a parameter is hidden then probe() can't expose the param + # and the param can't be set anymore + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if self.type_cast: + try: + value = self.type_cast(value) + except ValueError: + raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") + + if value % self.multiple_of != 0: + raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}") + + return value + + def get_range_iterator(self): + return itertools.count(0, self.multiple_of) + + def __repr__(self) -> str: + return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})" + + def json(self): + return { + "type": MultipleOf.__name__, + "default": self.default, + "multiple_of": self.multiple_of, + "tooltip": self.tooltip, + } + + +class OneOf(Validator): + def __init__(self, default, options, type_cast=None, hidden=False, tooltip=None): + self.options = set(options) + self.default = default + self.type_cast = type_cast # Cast the value to this type before checking if it's in options + self.tooltip = tooltip + self.hidden = hidden + + def validate(self, value): + if self.type_cast: + try: + value = self.type_cast(value) + except ValueError: + raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") + + if value not in self.options: + raise ValueError(f"Expected {value!r} to be one of {self.options!r}") + + return value + + def get_range_iterator(self): + return self.options + + def __repr__(self) -> str: + return f"OneOf({self.private_name=} {self.options=} {self.hidden=})" + + def json(self): + return { + "type": OneOf.__name__, + "default": self.default, + "values": list(self.options), + "tooltip": self.tooltip, + } + + +class HumanAttributes(Validator): + def __init__(self, default, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + # hard code the options for now + # we extend this to init parameter as needed + valid_attributes = { + "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"], + "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"], + "gender": ["male", "female"], + "age 
group": [ + "young", + "teen", + "adult early twenties", + "adult late twenties", + "adult early thirties", + "adult late thirties", + "adult middle aged", + "older adult", + ], + } + + def get_range_iterator(self): + # create a list of all possible combinations + l1 = self.valid_attributes["emotion"] + l2 = self.valid_attributes["race"] + l3 = self.valid_attributes["gender"] + l4 = self.valid_attributes["age group"] + all_combinations = list(itertools.product(l1, l2, l3, l4)) + return iter(all_combinations) + + def validate(self, value): + human_attributes = value.lower() + if human_attributes not in ["none", "random"]: + # In this case, we need for custom attribute string + + attr_string = human_attributes + for attr_key in ["emotion", "race", "gender", "age group"]: + attr_detected = False + for attr_label in self.valid_attributes[attr_key]: + if attr_string.startswith(attr_label): + attr_string = attr_string[len(attr_label) + 1 :] # noqa: E203 + attr_detected = True + break + + if attr_detected is False: + raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}") + + return value + + def __repr__(self) -> str: + return f"HumanAttributes({self.private_name=} {self.hidden=})" + + def json(self): + return { + "type": HumanAttributes.__name__, + "default": self.default, + "values": self.valid_attributes, + "tooltip": self.tooltip, + } + + +class Bool(Validator): + def __init__(self, default, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, int): + value = value != 0 + elif isinstance(value, str): + value = value.lower() + if value in ["true", "1"]: + value = True + elif value in ["false", "0"]: + value = False + else: + raise ValueError(f"Expected {value!r} to be one of ['True', 'False', '1', '0']") + elif not isinstance(value, bool): + raise TypeError(f"Expected {value!r} to be an bool") + + return value + + def get_range_iterator(self): + return [True, False] + + def __repr__(self) -> str: + return f"Bool({self.private_name=} {self.default=} {self.hidden=})" + + def json(self): + return { + "type": bool.__name__, + "default": self.default, + "tooltip": self.tooltip, + } + + +class Int(Validator): + def __init__(self, default, min=None, max=None, step=1, hidden=False, tooltip=None): + self.min = min + self.max = max + self.default = default + self.step = step + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, str): + value = int(value) + elif not isinstance(value, int): + raise TypeError(f"Expected {value!r} to be an int") + + if self.min is not None and value < self.min: + raise ValueError(f"Expected {value!r} to be at least {self.min!r}") + if self.max is not None and value > self.max: + raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") + return value + + def get_range_iterator(self): + iter_min = self.min if self.min is not None else self.default + iter_max = self.max if self.max is not None else self.default + return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) + + def __repr__(self) -> str: + return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": int.__name__, + "default": self.default, + "min": self.min, + "max": self.max, + "step": self.step, + "tooltip": self.tooltip, + } + + +class Float(Validator): + def __init__(self, default=0.0, min=None, max=None, step=0.5, 
hidden=False, tooltip=None): + self.min = min + self.max = max + self.default = default + self.step = step + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, str) or isinstance(value, int): + value = float(value) + elif not isinstance(value, float): + raise TypeError(f"Expected {value!r} to be float") + + if self.min is not None and value < self.min: + raise ValueError(f"Expected {value!r} to be at least {self.min!r}") + if self.max is not None and value > self.max: + raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") + return value + + def get_range_iterator(self): + iter_min = self.min if self.min is not None else self.default + iter_max = self.max if self.max is not None else self.default + return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) + + def __repr__(self) -> str: + return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": float.__name__, + "default": self.default, + "min": self.min, + "max": self.max, + "step": self.step, + "tooltip": self.tooltip, + } + + +class String(Validator): + def __init__(self, default="", min=None, max=None, predicate=None, hidden=False, tooltip=None): + self.min = min + self.max = max + self.predicate = predicate + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be an str") + if self.min is not None and len(value) < self.min: + raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}") + if self.max is not None and len(value) > self.max: + raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}") + if self.predicate is not None and not self.predicate(value): + raise ValueError(f"Expected {self.predicate} to be true for {value!r}") + return value + + def get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": str.__name__, + "default": self.default, + "tooltip": self.tooltip, + } + + +class Path(Validator): + def __init__(self, default="", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be an str") + if not os.path.exists(value): + raise ValueError(f"Expected {value!r} to be a valid path") + + return value + + def get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"String({self.private_name=} {self.default=}, {self.hidden=})" + + +class InputImage(Validator): + def __init__(self, default="", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + valid_formats = { + "JPEG": ["jpeg", "jpg"], + "JPEG2000": ["jp2"], + "PNG": ["png"], + "GIF": ["gif"], + "BMP": ["bmp"], + } + + valid_extensions = {vi: k for k, v in valid_formats.items() for vi in v} + + def validate(self, value): + _, ext = os.path.splitext(value).lower() + image_format = InputImage.valid_extensions[ext] + + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be an str") + if not os.path.exists(value): + raise ValueError(f"Expected {value!r} to be a valid path") + return value + + def 
get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"String({self.private_name=} {self.default=} {self.hidden=})" + + def json(self): + return { + "type": InputImage.__name__, + "default": self.default, + "values": self.valid_formats, + "tooltip": self.tooltip, + } + + +class MeshFormat(Validator): + """ + Validator class for mesh formats. Valid inputs are either: + - single valid format such as "glb", "obj" + - or a list of valid formats such as "[obj, ply, usdz]" + """ + + valid_formats = {"glb", "usdz", "obj", "ply"} + + def __init__(self, default="glb", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value: str) -> Union[str, List[str]]: + try: + # Attempt to parse the input as a Python list + if value.startswith("[") and value.endswith("]"): + formats = ast.literal_eval(value) + if not all(fmt in MeshFormat.valid_formats for fmt in formats): + raise ValueError(f"Each item must be one of {MeshFormat.valid_formats}") + return formats + elif value in MeshFormat.valid_formats: + return value + else: + raise ValueError(f"Expected {value!r} to be one of {MeshFormat.valid_formats} or a list of them") + except (SyntaxError, ValueError) as e: + # Handle case where the input is neither a valid single format nor a list of valid formats + raise ValueError(f"Invalid format specification: {value}. Error: {str(e)}") + + def __repr__(self) -> str: + return f"MeshFormat(default={self.default}, hidden={self.hidden})" + + def json(self): + return { + "type": MeshFormat.__name__, + "default": self.default, + "values": self.valid_formats, + "tooltip": self.tooltip, + } + + +class JsonDict(Validator): + """ + JSON stringified version of a python dict. + Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}' + """ + + def __init__(self, default="", hidden=False): + self.default = default + self.hidden = hidden + + def validate(self, value): + if not value: + return {} + try: + dict = json.loads(value) + return dict + except json.JSONDecodeError as e: + raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") + + def __repr__(self) -> str: + return f"Dict({self.default=} {self.hidden=})" + + +class BytesIOType(Validator): + """ + Validator class for BytesIO. 
Valid inputs are either: + - bytes + - objects of class BytesIO + - str which can be successfully decoded into BytesIO + """ + + def __init__(self, default=None, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value: Any) -> BytesIO: + if isinstance(value, str): + try: + # Decode the Base64 string + decoded_bytes = base64.b64decode(value) + # Create a BytesIO stream from the decoded bytes + return BytesIO(decoded_bytes) + except (base64.binascii.Error, ValueError) as e: + raise ValueError(f"Invalid Base64 encoded string: {e}") + elif isinstance(value, bytes): + return BytesIO(value) + elif isinstance(value, BytesIO): + return value + else: + raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO") + + def __repr__(self) -> str: + return f"BytesIOValidator({self.default=}, {self.hidden=})" + + def json(self): + return { + "type": BytesIO.__name__, + "default": self.default, + "tooltip": self.tooltip, + } diff --git a/cosmos_predict1/utils/visualize/video.py b/cosmos_predict1/utils/visualize/video.py new file mode 100644 index 0000000000000000000000000000000000000000..89fa021ccf8e85aa177f0f30293791c68531d8d0 --- /dev/null +++ b/cosmos_predict1/utils/visualize/video.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO, Any, Union + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image as PILImage +from torch import Tensor + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.easy_io import easy_io + +try: + import ffmpegcv +except Exception as e: # ImportError cannot catch all problems + log.info(e) + ffmpegcv = None + + +def save_video(grid, video_name, fps=30): + grid = (grid * 255).astype(np.uint8) + grid = np.transpose(grid, (1, 2, 3, 0)) + with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer: + for frame in grid: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + writer.write(frame) + + +def save_img_or_video(sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24) -> None: + """ + Save a tensor as an image or video file based on shape + + Args: + sample_C_T_H_W_in01 (Tensor): Input tensor with shape (C, T, H, W) in [0, 1] range. + save_fp_wo_ext (Union[str, IO[Any]]): File path without extension or file-like object. + fps (int): Frames per second for video. Default is 24. 
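+
+    Example (illustrative; the tensor contents and output path are placeholders):
+        frames = torch.rand(3, 121, 704, 1280)  # (C, T, H, W) in [0, 1]
+        save_img_or_video(frames, "outputs/sample", fps=24)  # written to outputs/sample.mp4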
+ """ + assert sample_C_T_H_W_in01.ndim == 4, "Only support 4D tensor" + assert isinstance(save_fp_wo_ext, str) or hasattr( + save_fp_wo_ext, "write" + ), "save_fp_wo_ext must be a string or file-like object" + + if torch.is_floating_point(sample_C_T_H_W_in01): + sample_C_T_H_W_in01 = sample_C_T_H_W_in01.clamp(0, 1) + else: + assert sample_C_T_H_W_in01.dtype == torch.uint8, "Only support uint8 tensor" + sample_C_T_H_W_in01 = sample_C_T_H_W_in01.float().div(255) + + if sample_C_T_H_W_in01.shape[1] == 1: + save_obj = PILImage.fromarray( + rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c 1 h w -> h w c").astype(np.uint8), + mode="RGB", + ) + ext = ".jpg" if isinstance(save_fp_wo_ext, str) else "" + easy_io.dump( + save_obj, + f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, + file_format="jpg", + format="JPEG", + quality=85, + ) + else: + save_obj = rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c t h w -> t h w c").astype(np.uint8) + ext = ".mp4" if isinstance(save_fp_wo_ext, str) else "" + easy_io.dump( + save_obj, + f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, + file_format="mp4", + format="mp4", + fps=fps, + ) diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf6e230528e0eca4899141db639951f9dab37ea9 --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,3 @@ +### Datasets directory + +Datasets used to post-train cosmos models will be saved in this directory. diff --git a/gui/.editorconfig b/gui/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..ea4a729985c1964bf7f567d8fd91a4c003348c8e --- /dev/null +++ b/gui/.editorconfig @@ -0,0 +1,16 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_style = tab +indent_size = 4 +trim_trailing_whitespace = true +max_line_length = 140 + +[*.md] +trim_trailing_whitespace = false + +[*.{clangd,nix,yml}] +indent_style = space +indent_size = 2 diff --git a/gui/.envrc b/gui/.envrc new file mode 100644 index 0000000000000000000000000000000000000000..0d1a669b3d0d817a52eb7e0fe9bf8ba6ed3caec8 --- /dev/null +++ b/gui/.envrc @@ -0,0 +1,5 @@ +strict_env +watch_file ./*.nix +use flake +layout python + diff --git a/gui/.github/workflows/main.yml b/gui/.github/workflows/main.yml new file mode 100644 index 0000000000000000000000000000000000000000..654e1c3215c690580f512f6fd9d6110d342cedf9 --- /dev/null +++ b/gui/.github/workflows/main.yml @@ -0,0 +1,155 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build_linux: + name: Build on linux systems + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-24.04 + cuda: "12.8" + arch: 120 + python: "3.12" + - os: ubuntu-24.04 + cuda: "12.8" + arch: 100 + python: "3.12" + - os: ubuntu-24.04 + cuda: "12.6" + arch: 89 + python: "3.12" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 86 + python: "3.11" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 75 + python: "3.10" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 70 + python: "3.9" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 61 + python: "3.8" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 53 + python: "3.7" + - os: ubuntu-22.04 + cuda: "11.7" + arch: 37 + python: "3.7" + env: + build_dir: "build" + config: "Release" + TCNN_CUDA_ARCHITECTURES: ${{ matrix.arch }} + steps: + - name: Install dependencies + run: sudo apt-get update && sudo apt-get install build-essential python3-dev libglfw3-dev libglew-dev 
libxinerama-dev libxcursor-dev libxi-dev + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - run: pip install -r requirements.txt + - name: Install CUDA + env: + cuda: ${{ matrix.cuda }} + run: ./dependencies/tiny-cuda-nn/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_ubuntu.sh + shell: bash + - name: Install Vulkan SDK + uses: humbletim/install-vulkan-sdk@c2aa128094d42ba02959a660f03e0a4e012192f9 + - name: CMake + run: cmake . -B ${{ env.build_dir }} ${{ matrix.cmake_flags }} -DCMAKE_BUILD_TYPE=${{ env.config }} + - name: Build + working-directory: ${{ env.build_dir }} + run: cmake --build . --target all --verbose -j `nproc` + + build_windows: + name: Build on Windows + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: windows-2025 + visual_studio: "Visual Studio 17 2022" + cuda: "12.8.0" + arch: 120 + python: "3.12" + recommended_gpus: "RTX-5000" + - os: windows-2025 + visual_studio: "Visual Studio 17 2022" + cuda: "12.6.3" + arch: 89 + python: "3.12" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 86 + python: "3.11" + recommended_gpus: "RTX-3000-and-4000" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 75 + python: "3.10" + recommended_gpus: "RTX-2000" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 70 + python: "3.9" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 61 + python: "3.8" + recommended_gpus: "GTX-1000" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 53 + python: "3.7" + - os: windows-2019 + visual_studio: "Visual Studio 16 2019" + cuda: "11.5.1" + arch: 37 + python: "3.7" + env: + build_dir: "build" + config: "Release" + TCNN_CUDA_ARCHITECTURES: ${{ matrix.arch }} + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - run: pip install -r requirements.txt + - name: Install CUDA + env: + cuda: ${{ matrix.cuda }} + visual_studio: ${{ matrix.visual_studio }} + shell: powershell + run: .\dependencies\tiny-cuda-nn\dependencies\cuda-cmake-github-actions\scripts\actions\install_cuda_windows.ps1 + - name: Install Vulkan SDK + uses: humbletim/install-vulkan-sdk@c2aa128094d42ba02959a660f03e0a4e012192f9 + - name: CMake + run: cmake . -B ${{ env.build_dir }} ${{ matrix.cmake_flags }} -G "${{ matrix.visual_studio }}" -A x64 + - name: Build + working-directory: ${{ env.build_dir }} + run: cmake --build . 
--config ${{ env.config }} --target ALL_BUILD --verbose diff --git a/gui/.gitignore b/gui/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..48d2f86ca4393cf2fa2d8720b75200880741bda9 --- /dev/null +++ b/gui/.gitignore @@ -0,0 +1,27 @@ +/.cache +/.vscode +/.vs +/.direnv +/build* +/external +/figures +/out +/logs +/results +/tmp +/venv +/video +/rtc +/outputs +/gen3c-gui +/*.dll +/*.so +/*.so.* +/*.exe +/*.json +*.ingp +*.msgpack +*.training +__pycache__ +.DS_Store +imgui.ini diff --git a/gui/.gitmodules b/gui/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..c37ecc532395a13997090f6c7fdb95e182b349bb --- /dev/null +++ b/gui/.gitmodules @@ -0,0 +1,27 @@ +[submodule "dependencies/pybind11"] + path = dependencies/pybind11 + url = https://github.com/Tom94/pybind11 +[submodule "dependencies/glfw"] + path = dependencies/glfw + url = https://github.com/Tom94/glfw +[submodule "dependencies/args"] + path = dependencies/args + url = https://github.com/Taywee/args +[submodule "dependencies/tinylogger"] + path = dependencies/tinylogger + url = https://github.com/Tom94/tinylogger +[submodule "dependencies/imgui"] + path = dependencies/imgui + url = https://github.com/ocornut/imgui.git +[submodule "dependencies/dlss"] + path = dependencies/dlss + url = https://github.com/NVIDIA/DLSS +[submodule "dependencies/OpenXR-SDK"] + path = dependencies/OpenXR-SDK + url = https://github.com/KhronosGroup/OpenXR-SDK.git +[submodule "dependencies/zlib"] + path = dependencies/zlib + url = https://github.com/Tom94/zlib +[submodule "dependencies/fmt"] + path = dependencies/fmt + url = https://github.com/fmtlib/fmt diff --git a/gui/README.md b/gui/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a9f94d9185c739601064b605df6af670869aa8ce --- /dev/null +++ b/gui/README.md @@ -0,0 +1,169 @@ +# GEN3C Graphical User Interface + +GEN3C's GUI helps visualize and author novel camera trajectories to be generated by GEN3C. +The GUI runs on your local machine, while the actual inference takes place either on the same or a remote machine. +This repository contains all code needed: model loading & inference, inference server, and local client GUI. + +
+ GEN3C interactive GUI +
+ + +## Starting the GEN3C inference server + +On the machine that will run inference, start by following the general installation instruction of GEN3C: [INSTALL.md](../INSTALL.md). +Then, install a few additional dependencies for the inference server: + +```bash +conda activate cosmos-predict1 +cd GEN3C/gui +pip install -r ./requirements.txt +``` + +Finally, start the inference server while optionally setting some parameters via `GEN3C_*` environment variables: + +```bash +# If model checkpoints were not downloaded to `GEN3C/checkpoints`, set the paths to the +# checkpoints directory: +# export GEN3C_CKPT_PATH="/path/to/checkpoints" +# Set if you would like to control the number of GPUs used by the inference server. +# By default, it will use as many as are available. +# export GEN3C_GPU_COUNT=1 + +CUDA_HOME=$CONDA_PREFIX fastapi dev --no-reload ./api/server.py --host 0.0.0.0 +``` + +It may take a while to load the model weights. The server is ready when "Uvicorn running on ..." is printed. + + +### SSH tunnel + +If the inference server is running on a remote machine, you may need to open an SSH tunnel on your local machine: + +```bash +# Usage: +# ssh -NL :: + +# Example 1: bind port 8000 of to your local port 8000 +ssh -NL 8000:localhost:8000 + +# Example 2: if is only accessible through , +# bind port 8000 of to your local port <8000>, going through . +ssh -NL 8000::8000 +``` + + +## Starting the GEN3C GUI on your local machine + +**Pre-requisites**: the GUI was written with CUDA, and therefore requires an NVIDIA GPU as well as the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) (version 11 or above). + +On your local machine, clone this repository *including submodules* (`--recursive`) and enter the `gui` subdirectory: + +```bash +git clone --recursive https://github.com/nv-tlabs/GEN3C +cd GEN3C/gui +``` + +Then, build the GUI: + +```bash +cmake . -B build -DCMAKE_BUILD_TYPE=RelWithDebInfo +cmake --build build --config RelWithDebInfo -j +``` + +Then, install the GUI's Python dependencies. Note that the Cosmos dependencies do not need to be installed here, since inference is running in the separate environment that was setup above. + +```bash +# In GEN3C/gui +pip install -r ./requirements.txt +``` + +Finally, start the GUI. It should automatically connect to the inference server that you started above. + +```bash +python api/client.py +``` + +> [!WARNING] +> The server was not tested for concurrent usage by multiple users. Only connect one client to each server instance. + + +## Using the GEN3C GUI + +### Seeding the model + +The model can be easily seeded (initialized) with an image or one of our pre-processed dynamic video examples. +Simply drag & drop the image or pre-processed folder onto the GUI window to trigger seeding. Alternatively, you can specify the path in the "Seeding" section of the right-hand window, then click "Seed". + +The seeding data is uploaded to the server, which initializes the Gen3C 3D cache. +When seeding from a single image, depth is automatically estimated using an off-the-shelf model. The estimated depth is downloaded back to the client in order to display the image as a 3D point cloud in the viewport. + +### Authoring a camera path + +Once the model is seeded, the camera trajectory is initialized with the camera pose and intrinsics estimated from the seeding data. +Using the left-hand window, you can then tweak or replace the camera trajectory in order to generate the scene from novel viewpoints. 
+ +We explain the main camera editing features from top to bottom: +- "Record camera path": when enabled, the camera movement in the viewport will be saved as the camera trajectory in real time. +- "Clear": clear the current camera path, starting from scratch. +- "Init from views": re-initialize the camera path from the seeding data. +- "Load" / "Save": load or save the camera trajectory in JSON format from / to the specified path. +- Keyframe manipulation: + - "Add from cam": add the current viewport camera pose to the camera path. + - "Split": add a new keyframe at the current point along the camera path, as specified by the "Camera path time" slider below. + - `|<`: go to the first keyframe. + - `<` go to the previous keyframe. + - "Read": set the viewport camera to the current camera along the path. + - `>|`: go to the next keyframe. + - `>` go to the last keyframe. + - "Dup": duplicate the current keyframe. + - "Del": delete the current keyframe. + - "Set": set the current keyframe to the current viewport camera. +- Keyframe editing: + - Individual keyframes can be tweaked directly in 3D from the viewport using the red / green / blue gizmo. Select which keyframe to edit using the "Camera path time" slider. + - "Translation" / "Rotation": tweak the position or orientation of the camera at the current keyframe. + - "Local" / "World": edit in local or global coordinate space. + - "Loop path": when selected, the path is made looping. +- Camera path playback: + - "Start": seek to the start of the path. + - "Rev": start playing the path in reverse. + - "Play": start playing the path normally. + - "End": seed to the end of the path. + - "Playback speed": controls how fast the path will play. Note this is unrelated to the framerate or speed at which the video will eventually be generated. + - "Camera path time": seek to specific points along the camera path. The selected point determines which keyframe will be edited when using the buttons above. +- Intrinsics editing: + - "Field of view": field of view in degrees of the current keyframe. Changing this value over time can be used to "zoom in" or out. + - "Apply to all keyframes": apply the current FoV value to all keyframes. +- "Batch keyframe editing": use this section to edit multiple keyframes at once. This is useful when a trajectory has many keyframes, e.g. after seeding from a video or using "Record camera path". +- "Advanced camera path settings": used to control the path interpolation smoothness. + +A preview of the scene from the current camera viewpoint along the trajectory is shown in the middle of the left-hand window. + +> [!NOTE] +> For convenience, you can also drag & drop a path that was previously authored and saved as JSON using the "Save" button above directly onto the viewport to load it. + + +### Starting inference + +Once a camera trajectory has been authored, Gen3C can generate a video based on the seeding data and the camera trajectory. + +The video settings are located under "Video generation" in the left-hand window: +- "**Generate video**": start inference with Gen3C on the server! +- "Visualize rendered 3D cache": include a preview of the rendered Gen3C 3D cache at each frame in the output video. +- "Add Gen3C keyframes to viewport after inference": once inference is complete, add the last frame of each batch of 121 frames to the viewport for preview. +- "Video file": path to the output video file, where it will be saved once inference is complete. 
Supports Python's `strftime()` [format codes](https://docs.python.org/3.12/library/datetime.html#strftime-and-strptime-format-codes). +- "Duration": duration of the video to generate, in seconds. +- "FPS": framerate of the video to render. +- "Resolution": fixed based on the capability of the model. +- "Export cameras": export the rendered camera trajectory, using a more portable JSON-based format. + +> [!NOTE] +> Depending on the hardware used, inference can take a while. Once inference is complete, the resulting video will be automatically downloaded, written to disk, and opened with your default video player. + +**Video duration**: when setting the duration and framerate of the video to generate, please keep in mind that the model will always generate multiples of 121 frames. +The default duration and framerate are set to correspond to one batch of 121 frames. +Non-multiple durations will be automatically trimmed down to the requested duration, but the frames will be generated regardless. + +Note that when seeding from a dynamic video, the frame count of the generated video should match the frame count of the input video for best results. + +**Resetting the 3D cache**: after generating a video, Gen3C's 3D cache may have been updated with new keyframes on the inference server. If you would like to discard the generated results from the cache, simply click the 'Seed' button again to reset the cache. diff --git a/gui/api/api_types.py b/gui/api/api_types.py new file mode 100644 index 0000000000000000000000000000000000000000..7e753bd9ea60e8c7fc2f4b55b8e48e669d2b52b0 --- /dev/null +++ b/gui/api/api_types.py @@ -0,0 +1,474 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass, asdict +from enum import Enum +import os + +# Enable OpenEXR support in OpenCV (disabled by default, +# see https://github.com/opencv/opencv/issues/21326). +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + +import numpy as np + +from encoding import CompressionFormat, IMAGE_COMPRESSION_FORMATS, compress_images, decompress_buffer, pad_or_trim_array, pad_or_trim_encoded_buffers + + +@dataclass(kw_only=True) +class RequestBase: + request_id: str + + # Shape: [batch, 3, 4] + cameras_to_world: np.ndarray + # Absolute horizontal and vertical focal lengths in pixels. + # They are expressed w.r.t. the `resolutions` field (pixel count and aspect ratio). + # Shape: [batch, 2] + focal_lengths: np.ndarray + # Relative horizontal and vertical principal point (e.g. [0.5, 0.5]) + # Shape: [batch, 2] + principal_points: np.ndarray + # (width, height) in number of pixels. + # Automatically set from the `self.images` field, if any. + # Shape: [batch, 2] + resolutions: np.ndarray = None + + # Number of frames in the original request, without padding. + # Set automatically when calling `pad_to_frame_count()`. 
+ frame_count_without_padding: int | None = None + + + def __post_init__(self): + if hasattr(self, "images"): + if self.resolutions is None: + # Set resolution from the images + self.resolutions = np.tile([[self.images.shape[2], self.images.shape[1]]], + (len(self), 1)) + else: + # Check resolution against the images + assert np.all(self.resolutions == (self.images.shape[2], self.images.shape[1])) + elif self.resolutions is None: + raise ValueError("Missing value `resolutions`") + + n = len(self) + assert self.cameras_to_world.shape == (n, 3, 4) + assert self.focal_lengths.shape == (n, 2) + assert self.principal_points.shape == (n, 2) + assert self.resolutions.shape == (n, 2) + + def world_to_cameras(self) -> np.ndarray: + c2w_complete = np.zeros((self.cameras_to_world.shape[0], 4, 4), dtype=self.cameras_to_world.dtype) + c2w_complete[:, :3, :] = self.cameras_to_world + c2w_complete[:, 3, 3] = 1.0 + return np.linalg.inv(c2w_complete) + + def intrinsics_matrix(self, for_resolutions: np.ndarray | None) -> np.ndarray: + """Returns a batched intrinsics matrix [batch, 3, 3] following the + format used by the Gen3C codebase.""" + result = np.zeros((len(self), 3, 3)) + # Focal length is already absolute + result[:, 0, 0] = self.focal_lengths[:, 0] + result[:, 1, 1] = self.focal_lengths[:, 1] + # Note: convert from relative to absolute principal point + result[:, 0, 2] = self.principal_points[:, 0] * self.resolutions[:, 0] + result[:, 1, 2] = self.principal_points[:, 1] * self.resolutions[:, 1] + result[:, 2, 2] += 1 + + if for_resolutions is not None: + # Resize intrinsics to match the new given resolutions + assert for_resolutions.shape == self.resolutions.shape + result[:, 0, :] *= (for_resolutions[:, 0, None] / self.resolutions[:, 0, None]) + result[:, 1, :] *= (for_resolutions[:, 1, None] / self.resolutions[:, 1, None]) + + return result + + def resolution(self) -> tuple[int, int]: + """Resolution of the first image result in pixels as (width, height).""" + return self.resolutions[0, 0], self.resolutions[0, 1] + + def __len__(self): + return self.cameras_to_world.shape[0] + + + def trim_to_original_frame_count(self, override_frame_count: int | None = None) -> None: + """ + Drop padding entries in order to match the original frame count. + """ + frame_count = override_frame_count or self.frame_count_without_padding + print(f"Trimming {type(self).__name__} from {len(self)} back to original frame count {frame_count}.") + if frame_count is None: + return + self._adjust_frame_count(frame_count) + + def pad_to_frame_count(self, n_frames: int) -> None: + """ + Add padding entries in order to match the desired frame count. + Also records the current frame count as the original frame count. + """ + self.frame_count_without_padding = len(self) + print(f"Padding {type(self).__name__} from {self.frame_count_without_padding} to {n_frames}.") + self._adjust_frame_count(n_frames) + + + def _adjust_frame_count(self, n_frames: int) -> None: + """ + Updates all fields to match the desired frame count. + + If it is higher than the current frame count, the last entry of each field is repeated. + If it is lower, entries are dropped (from the end). 
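+ For example (illustrative): padding a 3-frame request to 5 frames repeats the
+ last camera/intrinsics entry twice; trimming back to 3 drops the two padded entries.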
+ """ + self.cameras_to_world = pad_or_trim_array(self.cameras_to_world, n_frames) + self.focal_lengths = pad_or_trim_array(self.focal_lengths, n_frames) + self.principal_points = pad_or_trim_array(self.principal_points, n_frames) + self.resolutions = pad_or_trim_array(self.resolutions, n_frames) + + +@dataclass(kw_only=True) +class SeedingRequest(RequestBase): + """ + Contains data required to seed the Gen3C model with initial data + to bootstrap generation. + + Note that the intrinsics (defined in `RequestBase`) are provided + as a suggestion and may be ignored by the model, if it is able to + estimate them from the images instead. + TODO: maybe provide a flag to indicate that the model should really + respect the provided values. + """ + # Values in [0, 1]. + # Shape: [batch, height, width, 3], float32 + images: np.ndarray + # Per-pixel depth for each of the given images. If not provided, it + # should be estimated automatically by the model. + # Shape: [batch, height, width], float32 + depths: np.ndarray | None + # Per-pixel mask for each of the given images. If not provided, it + # should be estimated automatically by the model. + # Shape: [batch, height, width], bool + masks: np.ndarray | None = None + + def __post_init__(self): + super().__post_init__() + n = len(self) + assert self.images.shape[0] == n and self.images.ndim == 4, self.images.shape + if self.depths is not None: + assert self.depths.shape[0] == n and self.depths.ndim == 3, self.depths.shape + if self.masks is not None: + assert self.masks.shape[0] == n and self.masks.ndim == 3, self.masks.shape + + + def _adjust_frame_count(self, n_frames: int) -> None: + raise RuntimeError("SeedingRequest: _adjust_frame_count() not supported") + + + def compress(self, + format_rgb: CompressionFormat = CompressionFormat.JPG, + format_depth: CompressionFormat | None = None, + format_mask: CompressionFormat | None = None) -> "CompressedSeedingRequest": + """Compress the images and depths as images and return a `CompressedSeedingRequest`.""" + images_compressed = compress_images(self.images, format_rgb, is_depth=False) + + format_depth = format_depth or CompressionFormat.EXR + depths_compressed = compress_images(self.depths, format_depth, + is_depth=True) + + format_mask = format_mask or CompressionFormat.NPZ + masks_compressed = compress_images(self.masks, format_mask, + is_bool=True) + + kwargs = asdict(self) + # Will be replaced automatically with placeholders of the right shape. + kwargs['images'] = None + kwargs['depths'] = None + kwargs['masks'] = None + return CompressedSeedingRequest( + images_compressed=images_compressed, + images_format=format_rgb, + depths_compressed=depths_compressed, + depths_format=format_depth, + masks_compressed=masks_compressed, + masks_format=format_mask, + **kwargs + ) + + +@dataclass(kw_only=True) +class CompressedSeedingRequest(SeedingRequest): + """ + Same as a `SeedingRequest`, but the image and depth buffers are + sent as compressed JPG / PNG streams instead of raw bytes. + They should be decompressed when received, before being used + as standard `SeedingRequest`s. 
+ """ + # List of compressed images (as raw bytes) + images_compressed: list[bytes] + images_format: CompressionFormat + # List of compressed images (as raw bytes) + depths_compressed: list[bytes] | None + depths_format: CompressionFormat | None + # List of compressed masks (as raw bytes) + masks_compressed: list[bytes] | None + masks_format: CompressionFormat | None + + def __post_init__(self): + # Note: not calling parent checks because our image and depth + # fields are not actually usable as-is. + # super().__post_init__() + + # For convenience, auto-fill placeholder image and depth fields + assert (self.resolutions is not None) or (self.images is not None), \ + "CompressedSeedingRequest: at least one of resolutions or images must be provided" + + w, h = self.resolution() + if self.images is None: + self.images = np.empty((0, h, w, 3), dtype=np.float32) + if (self.depths is None) and (self.depths_compressed is not None): + self.depths = np.empty((0, h, w), dtype=np.float32) + if (self.masks is None) and (self.masks_compressed is not None): + self.masks = np.empty((0, h, w), dtype=np.bool) + + assert self.images.shape[0] == 0, \ + "CompressedSeedingRequest should not have any raw image data"\ + " in `self.images` upon construction." + + def decompress(self) -> None: + """Decompress the images and fill them in-place.""" + self.images = decompress_buffer(self.images_compressed, self.images_format) + self.depths = decompress_buffer(self.depths_compressed, self.depths_format, is_depth=True) + self.masks = decompress_buffer(self.masks_compressed, self.masks_format, is_bool=True) + + +@dataclass(kw_only=True) +class SeedingResult(RequestBase): + """ + Contains the result of a seeding request, + e.g. the depth maps for the seeding images that were estimated by the model + if not provided in the original request. + + Note: since the `depths` field would need to remain relatively high-precision + and lossless when compressed, we don't bother overring a compressed version + of `SeedingResult` for now. + """ + # Per-pixel depth for each of the given images. + # Shape: [batch, height, width] + depths: np.ndarray | None = None + + def __post_init__(self): + super().__post_init__() + n = len(self) + if self.depths is not None: + if self.depths.ndim == 4 and self.depths.shape[1] == 1: + # [batch, 1, height, width] -> [batch, height, width] + self.depths = self.depths.squeeze(1) + assert self.depths.shape[0] == n and self.depths.ndim == 3 + + @staticmethod + def from_request(req: SeedingRequest, fallback_depths: np.ndarray | None) -> "SeedingResult": + resolutions = req.resolutions + if fallback_depths is not None: + resolutions[:, 0] = fallback_depths.shape[2] + resolutions[:, 1] = fallback_depths.shape[1] + + return SeedingResult( + request_id=req.request_id, + cameras_to_world=req.cameras_to_world, + focal_lengths=req.focal_lengths, + principal_points=req.principal_points, + resolutions=resolutions, + depths=None if (req.depths is not None) else fallback_depths, + ) + + def _adjust_frame_count(self, n_frames: int) -> None: + raise RuntimeError("SeedingRequest: _adjust_frame_count() not supported") + + +@dataclass(kw_only=True) +class InferenceRequest(RequestBase): + # Time points for each frame to generate (useful when there's scene dynamics). + # May be ignored by the model. + # Shape: [batch,] + timestamps: np.ndarray + + # Framerate of the generated video (frames per second). Only applicable + # when requesting multiple frames at once. 
+ # May be ignored by the model, or rounded to the nearest integer. + framerate: float = 30.0 + + # Whether to estimate and return depth for each frame in the result. + return_depths: bool = False + + # If inference results will be returned as a compressed video, use this + # encoding quality (0..10). + video_encoding_quality: int = 8 + + # Whether to include the rendered cache in the generated video (for debugging / visualization) + show_cache_renderings: bool = False + + def __post_init__(self): + super().__post_init__() + n = len(self) + assert self.timestamps.shape[0] == n and self.timestamps.ndim == 1, \ + f"Timestamps: expected shape ({n},), found: {self.timestamps.shape}" + assert len(self.focal_lengths) == n + assert len(self.principal_points) == n + assert len(self.resolutions) == n + + def _adjust_frame_count(self, n_frames: int) -> None: + super()._adjust_frame_count(n_frames) + self.timestamps = pad_or_trim_array(self.timestamps, n_frames) + + +@dataclass(kw_only=True) +class InferenceResult(RequestBase): + """ + Note that fields that are already included in the request are repeated here, + simply because the model may not have respected the request. + """ + # The model can use this field to indicate that multiple returned results + # are identical, so that the client can decide to skip adding them. + # It should be ignored if set to None. + result_ids: list[str | None] + + # Shape: [batch,] + timestamps: list[float] + # Shape: [batch, height, width, 3] + images: np.ndarray + # Shape: [batch, height, width] + depths: np.ndarray + + # Time it took to generate the whole batch of results, in milliseconds + runtime_ms: float + + def __post_init__(self): + super().__post_init__() + n = len(self) + assert self.timestamps.shape[0] == n and self.timestamps.ndim == 1, \ + f"Timestamps: expected shape ({n},), found: {self.timestamps.shape}" + assert self.images.ndim == 4 and self.images.shape[0] == n, self.images.shape + assert self.depths.ndim == 3 and self.depths.shape[0] == n, self.depths.shape + + + def _adjust_frame_count(self, n_frames: int) -> None: + super()._adjust_frame_count(n_frames) + self.timestamps = pad_or_trim_array(self.timestamps, n_frames) + + if self.images.shape[0] == 0: + # Fields are just placeholders (compressed request), leave them be. + return + + self.images = pad_or_trim_array(self.images, n_frames) + self.depths = pad_or_trim_array(self.depths, n_frames) + + + +@dataclass(kw_only=True) +class CompressedInferenceResult(InferenceResult): + """ + Same as a `InferenceResult`, but the image and depth buffers are + sent as compressed MP4 / EXR streams instead of raw bytes. + They should be decompressed when received, before being used + as standard `InferenceResult`s. + """ + # List of compressed images (as raw bytes) + images_compressed: list[bytes] + images_format: CompressionFormat + # List of compressed images (as raw bytes) + depths_compressed: list[bytes] | None + depths_format: CompressionFormat | None + + def __post_init__(self): + # Note: not calling parent checks because our image and depth + # fields are not actually usable as-is. 
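+ # (The placeholder arrays created below have a batch size of 0 and only carry the
+ # expected height, width and dtype until `decompress()` fills in real frames.)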
+ # super().__post_init__() + + # For convenience, auto-fill placeholder image and depth fields + assert (self.resolutions is not None) or (self.images is not None), \ + "CompressedInferenceResult: at least one of resolutions or images must be provided" + w, h = self.resolution() + if self.images is None: + self.images = np.empty((0, h, w, 3), dtype=np.float32) + if (self.depths is None) and (self.depths_compressed is not None): + self.depths = np.empty((0, h, w), dtype=np.float32) + + assert self.images.shape[0] == 0, \ + "CompressedInferenceResult should not have any raw image data" \ + " in `self.images` upon construction." + + if self.images_format == CompressionFormat.MP4: + assert len(self.images_compressed) == 1, \ + "CompressedInferenceResult: with an MP4 compressed result," \ + " there should be only one buffer (the compressed video)." + elif self.depths_compressed is not None: + assert len(self.depths_compressed) == len(self.images_compressed) + assert self.depths_format in IMAGE_COMPRESSION_FORMATS, \ + f"CompressedInferenceResult: depths_format should be an image format, found {self.depths_format}" + + + def _adjust_frame_count(self, n_frames: int) -> None: + super()._adjust_frame_count(n_frames) + self.images_compressed = pad_or_trim_encoded_buffers(self.images_compressed, self.images_format, n_frames) + self.depths_compressed = pad_or_trim_encoded_buffers(self.depths_compressed, self.depths_format, n_frames) + + + def decompress(self) -> None: + """Decompress the images and fill them in-place.""" + self.images = decompress_buffer(self.images_compressed, self.images_format, is_depth=False) + self.depths = decompress_buffer(self.depths_compressed, self.depths_format, is_depth=True) + + + def save_images(self, fname_or_directory: str) -> None: + """Save the compressed images to a file path or directory. + + If a full path is given, the path extension may be overriden based + on the compression format. + """ + fname_or_directory = os.path.realpath(fname_or_directory) + base, ext = os.path.splitext(fname_or_directory) + if not ext: + directory = fname_or_directory + base = "inference_result" + else: + directory = os.path.dirname(fname_or_directory) + base = os.path.splitext(os.path.basename(fname_or_directory))[0] + + os.makedirs(directory, exist_ok=True) + single = len(self.images_compressed) == 1 + for i, buf in enumerate(self.images_compressed): + image_path = os.path.join( + directory, + f"{base}.{self.images_format.value}" + if single else f"base_{i:05d}.{self.images_format.value}" + ) + with open(image_path, "wb") as f: + f.write(buf) + + +class RequestState(Enum): + """ + Note: by "request" we mean an inference request, not an HTTP request. + """ + REQUEST_PENDING = "Request pending" + REQUEST_SENT = "Request sent" + RESULT_PENDING = "Result pending" + COMPLETE = "Completed" + FAILED = "Created" + + +@dataclass(kw_only=True) +class PendingRequest: + request_id: str + state: RequestState + message: str = "" + task: asyncio.Task | None = None diff --git a/gui/api/client.py b/gui/api/client.py new file mode 100644 index 0000000000000000000000000000000000000000..b8650fe90ed60746ea47879a9615ebf212c11bc8 --- /dev/null +++ b/gui/api/client.py @@ -0,0 +1,890 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import code +from copy import deepcopy +from datetime import datetime +import glob +import os +from os.path import realpath, dirname, join +import pickle +import subprocess +import sys +import time + +import cv2 +import httpx +import numpy as np +import pyexr +import tqdm + +ROOT_DIR = realpath(dirname(dirname(__file__))) +DATA_DIR = join(ROOT_DIR, "data") + +sys.path.append(join(ROOT_DIR, "scripts")) + +# Search for pyngp in the build folder. +sys.path += [os.path.dirname(pyd) for pyd in glob.iglob(os.path.join(ROOT_DIR, "build*", "**/*.pyd"), recursive=True)] +sys.path += [os.path.dirname(pyd) for pyd in glob.iglob(os.path.join(ROOT_DIR, "build*", "**/*.so"), recursive=True)] + +import pyngp as ngp +from pyngp import tlog + +from api_types import SeedingRequest, CompressedSeedingRequest, SeedingResult, \ + InferenceRequest, InferenceResult, CompressedInferenceResult, \ + RequestState, PendingRequest +from httpx_utils import httpx_request +from v2v_utils import load_v2v_seeding_data, ensure_alpha_channel, srgb_to_linear + + + +def repl(testbed): + print("-------------------\npress Ctrl-Z to return to gui\n---------------------------") + code.InteractiveConsole(locals=locals()).interact() + print("------- returning to gui...") + + +def open_file_with_default_app(video_path: str) -> None: + """Open the saved video file with the default video player application.""" + try: + if sys.platform == "win32": + # Windows + os.startfile(video_path) + else: + # Avoid venv, etc interfering with the application that will open. + env = os.environ.copy() + for k in ("QT_QPA_PLATFORM_PLUGIN_PATH", "QT_QPA_FONTDIR", "LD_LIBRARY_PATH"): + if k in env: + del env[k] + if sys.platform == "darwin": + # macOS + subprocess.run(["open", video_path], check=True, env=env) + else: + # Linux, etc. 
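+ # xdg-open defers to the desktop environment's configured default application.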
+ subprocess.run(["xdg-open", video_path], check=True, env=env) + except Exception as e: + tlog.error(f"Failed to open video file: {e}") + + +class Gen3cClient(): + def __init__( + self, + files: list[str], + host: str, + port: int, + width: int = 1920, + height: int = 1080, + vr: bool = False, + request_latency_ms: int = 100, + inference_resolution: tuple[int, int] = (1920, 1080), + # max_pending_requests: int = 2, + max_pending_requests: int = 1, + request_timeous_seconds: float = 1000, + seed_max_frames: int | None = None, + seed_stride: int = 1, + output_dir: str | None = None, + ): + self.url = f"http://{host}:{port}" + self.client_id = f"gen3c{os.getpid()}" + self.request_latency_ms = request_latency_ms + self.inference_resolution = inference_resolution + self.max_pending_requests = max_pending_requests + self.req_timeout_s = request_timeous_seconds + self.seed_max_frames = seed_max_frames + self.seed_stride = seed_stride + + testbed = ngp.Testbed(ngp.TestbedMode.Gen3c) + testbed.root_dir = ROOT_DIR + testbed.set_gen3c_cb(self.gui_callback) + testbed.file_drop_callback = self.file_drop_callback + + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + else: + output_dir = join(ROOT_DIR, "outputs") + testbed.gen3c_output_dir = output_dir + testbed.video_path = join(output_dir, "gen3c_%Y-%m-%d_%H-%M-%S.mp4") + + # --- Check metadata from server to ensure compatibility + testbed.reproject_visualize_src_views = False + testbed.render_aabb.min = np.array([-16.0, -16.0, -16.0]) + testbed.render_aabb.max = np.array([16.0, 16.0, 16.0]) + try: + tlog.info(f"Requesting metadata from server {host}:{port}") + metadata = self.request_metadata_sync() + testbed.render_aabb.min = np.array(metadata["aabb_min"]).astype(np.float32) + testbed.render_aabb.max = np.array(metadata["aabb_max"]).astype(np.float32) + testbed.aabb = ngp.BoundingBox(testbed.render_aabb.min, testbed.render_aabb.max) + testbed.gen3c_info = f"Connected to server {host}:{port}, model name: {metadata.get('model_name')}" + + model_inference_res: list[tuple[int, int]] | None = metadata.get("inference_resolution") + if model_inference_res is not None: + for supported_res in model_inference_res: + if tuple(supported_res) == self.inference_resolution: + break + else: + r = tuple(model_inference_res[0]) + tlog.warning(f"Client inference resolution {self.inference_resolution} is not" + f" supported by the inference server, adopting resolution {r} instead.") + self.inference_resolution = r + testbed.camera_path.render_settings.resolution = self.inference_resolution + + testbed.gen3c_inference_is_connected = True + testbed.gen3c_render_with_gen3c = True + + except httpx.ConnectError as e: + # The metadata-based setup happens only once at startup. Since we failed to + # get the metadata from the server, it's easier to just raise and exit here. + raise RuntimeError( + f"Connection error! Make sure the server was started at: {host}:{port}\n{e}" + ) from e + + testbed.camera_path.render_settings.fps = metadata.get("default_framerate") or 24.0 + self.min_frames_per_request: int = metadata.get("min_frames_per_request", 1) + self.max_frames_per_request: int = metadata.get("max_frames_per_request", 1) + if self.min_frames_per_request > 1: + # Set default render settings such that the model can generate it + # in a single batch exactly. 
+ testbed.camera_path.default_duration_seconds = \ + self.min_frames_per_request / testbed.camera_path.render_settings.fps + testbed.camera_path.duration_seconds = testbed.camera_path.default_duration_seconds + + # Expected time that the model will take to generate each frame, in seconds + self.inference_time_per_frame: float = metadata.get("inference_time_per_frame", 0.0) + # Don't automatically request new frames all the time if inference is slow + testbed.gen3c_auto_inference &= (self.inference_time_per_frame < 1.0) + + self.seeding_pending: bool = False + self.model_requires_seeding: bool = metadata.get("requires_seeding", True) + if self.model_requires_seeding: + testbed.gen3c_info += "\nThis model requires seeding data." + + # Pick a sensible GUI resolution depending on arguments. + sw = width + sh = height + while sw * sh > 1920 * 1080 * 4: + sw = int(sw / 2) + sh = int(sh / 2) + + testbed.init_window(sw, sh) + if vr: + testbed.init_vr() + self.testbed: ngp.Testbed = testbed + + self.lens = ngp.Lens() + self.lens.mode = ngp.LensMode.Perspective + + self.client = httpx.AsyncClient() + + self.last_request_id: int = 0 + self.start_t: float = None + self.last_request_t: float = None + self.pending_requests: dict[str, PendingRequest] = {} + + # Handle files given as command-line arguments. + if files: + self.file_drop_callback(files) + + + + async def run(self): + testbed = self.testbed + + self.start_t = time.monotonic() + self.last_request_t = time.monotonic() + # TODO: any way to make the rendering itself async? (pybind11 support?) + while testbed.frame(): + # --- At each frame + if testbed.want_repl(): + repl(testbed) + + if self.model_requires_seeding and self.seeding_pending and self.testbed.gen3c_seed_path: + tlog.info(f"Loading seeding data with path: {self.testbed.gen3c_seed_path}") + # Load the seeding data. + seed_req = self.load_seeding_data(self.testbed.gen3c_seed_path) + if seed_req is not None: + self.adapt_view_to_cameras(seed_req.cameras_to_world) + # Send the seeding request over to the server (could be a slow upload). + self.send_seeding_request(seed_req) + self.seeding_pending = False + + # Give coroutines a chance to run (especially if there are pending HTTP requests). + # This is essentially a "yield". + # TODO: how can we sleep only for the minimum needed time? + # Probably we would need to request the `testbed`'s frame in an + # async way as well? Something like: + # await testbed.frame() + await asyncio.sleep(0.003 if self._transfer_in_progress() else 0.0001) + + # Check pending inference requests + self.get_request_results() + + # New inference request + # TODO: if there are too many pending requests, cancel the oldest one + # instead of continuing to wait. + now = time.monotonic() + if ((1000 * (now - self.last_request_t) > self.request_latency_ms) + and testbed.gen3c_auto_inference + and testbed.is_rendering + and len(self.pending_requests) < self.max_pending_requests): + + self.request_frames() + + def get_request_results(self): + to_remove = set() + for req_id, state in self.pending_requests.items(): + + if state.state in (RequestState.FAILED, RequestState.COMPLETE): + # Cleanup requests that are done one way or another + to_remove.add(req_id) + + elif state.state == RequestState.REQUEST_PENDING: + # Before checking the results, we wait for the inference request to have + # been received by the server at least. + self.testbed.gen3c_inference_info = f"Waiting for inference request {req_id} to be received by the server..." 
+ continue + + elif state.state == RequestState.REQUEST_SENT: + # Server has received the inference request, we should now start checking results + def on_result_received(result: InferenceResult | None, + response: httpx.Response, failed: bool = False): + if failed: + tlog.error(f"Results request for inference {req_id} failed!\n" + f"{response.content}") + self.testbed.gen3c_inference_info = f"Error: {response.content}" + state.state = RequestState.FAILED + state.task = None + return + + if result is None: + # Result not ready yet, check again soon + state.state = RequestState.REQUEST_SENT + state.task = None + return + + # Actual result received! + assert isinstance(result, InferenceResult) + state.state = RequestState.COMPLETE + self.testbed.gen3c_inference_info = "" + + tlog.success(f"Received results {req_id}: took {result.runtime_ms:.1f} ms to generate.") + + need_frames = self.testbed.gen3c_display_frames or self.testbed.gen3c_save_frames + result.trim_to_original_frame_count() + if isinstance(result, CompressedInferenceResult): + # Save the compressed video straight to disk + video_path = datetime.now().strftime(self.testbed.video_path) + result.save_images(video_path) + tlog.success(f"[+] Wrote generated video to: {video_path}") + + tlog.info(f"Opening file with default application: {video_path}") + open_file_with_default_app(video_path) + if need_frames: + result.decompress() + + # Add all received frames to the viewer. + if self.testbed.gen3c_save_frames: + os.makedirs(self.testbed.gen3c_output_dir, exist_ok=True) + + view_ids = set(self.testbed.src_view_ids()) + for res_i in range(len(result)): + if not need_frames: + continue + + # Only display the result if we don't already have it shown + res_id = result.result_ids[res_i] + + if (res_id is not None) and (res_id in view_ids): + tlog.debug(f"Skipping result since id {res_id} is already displayed") + continue + + # Allow alpha channel to be omitted for faster transfers + image = ensure_alpha_channel(result.images[res_i, ...]) + + if self.testbed.gen3c_save_frames: + safe_res_id = (res_id or f"{res_i:04d}").replace(":", "_") + fname = join(self.testbed.gen3c_output_dir, + f"rgb_{safe_res_id}.exr") + pyexr.write(fname, image) + + fname = join(self.testbed.gen3c_output_dir, + f"depth_{safe_res_id}.exr") + pyexr.write(fname, result.depths[res_i, ...].astype(np.float32)) + tlog.success(f"[+] Wrote inference result to: {fname}") + + if self.testbed.gen3c_display_frames: + has_valid_depth = np.any(np.isfinite(result.depths[res_i, ...])) + if has_valid_depth: + self.testbed.add_src_view( + result.cameras_to_world[res_i, ...], + result.focal_lengths[res_i][0], + result.focal_lengths[res_i][1], + result.principal_points[res_i][0], + result.principal_points[res_i][1], + self.lens, + image, + result.depths[res_i, ...], + result.timestamps[res_i], + is_srgb=True, + ) + self.testbed.reset_accumulation(reset_pip=True) + tlog.info(f"Added {res_id}[{res_i}] to viewer") + else: + tlog.debug(f"Not adding {res_id}[{res_i}] to viewer because it has no valid depth." + " Only keyframes (last frame of each batch) typically have valid depth.") + + # Don't display more than 8 views at once by default to avoid + # slowing down the rendering too much. 
+ self.set_max_number_of_displayed_views(8) + + tlog.debug(f"Checking results of request {req_id}...") + state.state = RequestState.RESULT_PENDING + state.task = self._get_inference_results(req_id, on_result_received) + + elif state.state == RequestState.RESULT_PENDING: + # We already sent a request to check on the results, let's wait until + # a response comes back (through the `on_result_received` cb). + if self.testbed.gen3c_inference_progress < 0: + # Only show the spinner if downloading the results hasn't started yet. + spinner = "|/-\\"[int(4 * time.time()) % 4] + self.testbed.gen3c_inference_info = f"[{spinner}] Waiting for server to complete inference..." + pass + + + for k in to_remove: + del self.pending_requests[k] + self.testbed.camera_path.rendering = len(self.pending_requests) > 0 + + # ---------- + + def request_metadata_sync(self) -> InferenceResult: + # Synchronous request (no need to `await`) + return httpx_request("get", self.url + "/metadata", timeout=self.req_timeout_s).json() + + + def request_frames(self, sync: bool = False) -> asyncio.Task | InferenceResult: + # The user wants a certain number of frames, but the model can only generate + # `self.min_frames_per_request` per request. Pad to get there. + n_desired_frames = int(np.ceil(self.testbed.camera_path.duration_seconds + * self.testbed.camera_path.render_settings.fps)) + n_frames_padded = max( + int(np.ceil(n_desired_frames / self.min_frames_per_request) * self.min_frames_per_request), + self.min_frames_per_request + ) + self.testbed.gen3c_inference_info = ( + f"Requesting {n_desired_frames} frames ({n_frames_padded} total with padding, " + f"model has min batch size {self.min_frames_per_request})." + ) + tlog.info(self.testbed.gen3c_inference_info) + # TODO: enforce `max_frames_per_request` from the server, too (with a clear error message) + now = time.monotonic() + + cameras_to_world = np.repeat(self.testbed.camera_matrix[None, ...], + repeats=n_desired_frames, axis=0) + + # By default, use the preview camera focal length. + # We assume square pixels, so horizontal and vertical focal lengths are equal. + default_focal_length = self.testbed.relative_focal_length * self.inference_resolution[self.testbed.fov_axis] + focal_lengths = np.array([default_focal_length] * n_desired_frames) + + match self.testbed.gen3c_camera_source: + case ngp.Gen3cCameraSource.Fake: + # --- Camera movement: fake based on fixed translation and rotation speeds + counter = np.arange(n_desired_frames)[..., None] + + if np.any(self.testbed.gen3c_rotation_speed != 0): + angles = counter * self.testbed.gen3c_rotation_speed[None, ...] 
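+ # Columns of `angles` are per-frame Euler angles about x, y and z; the matrix
+ # built below is the standard ZYX composition Rz(gamma) @ Ry(beta) @ Rx(alpha).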
+ alphas = angles[:, 0] + betas = angles[:, 1] + gammas = angles[:, 2] + + # TODO: nicer way to build the rotation matrix + fake_rotation = np.tile(np.eye(3, 3)[None, ...], (n_desired_frames, 1, 1)) + fake_rotation[:, 0, 0] = np.cos(betas) * np.cos(gammas) + fake_rotation[:, 0, 1] = ( + np.sin(alphas) * np.sin(betas) * np.cos(gammas) + - np.cos(alphas) * np.sin(gammas) + ) + fake_rotation[:, 0, 2] = ( + np.cos(alphas) * np.sin(betas) * np.cos(gammas) + + np.sin(alphas) * np.sin(gammas) + ) + + fake_rotation[:, 1, 0] = np.cos(betas) * np.sin(gammas) + fake_rotation[:, 1, 1] = ( + np.sin(alphas) * np.sin(betas) * np.sin(gammas) + + np.cos(alphas) * np.cos(gammas) + ) + fake_rotation[:, 1, 2] = ( + np.cos(alphas) * np.sin(betas) * np.sin(gammas) + - np.sin(alphas) * np.cos(gammas) + ) + + fake_rotation[:, 2, 0] = -np.sin(betas) + fake_rotation[:, 2, 1] = np.sin(alphas) * np.cos(betas) + fake_rotation[:, 2, 2] = np.cos(alphas) * np.cos(betas) + + cameras_to_world[:, :3, :3] @= fake_rotation + + if np.any(self.testbed.gen3c_translation_speed != 0): + fake_translation = counter * self.testbed.gen3c_translation_speed[None, ...] + cameras_to_world[:, :, 3] += fake_translation + + case ngp.Gen3cCameraSource.Viewpoint: + # --- Camera movement: based on the current viewpoint + predicted movement + tlog.error("Not implemented: Gen3C camera movement source: Viewpoint") + return + + case ngp.Gen3cCameraSource.Authored: + # --- Camera movement: based on the current authored camera path + keyframes = [ + self.testbed.camera_path.eval_camera_path(t) + for t in np.linspace(0, 1, n_desired_frames, endpoint=True) + ] + cameras_to_world = [ + keyframe.m()[None, ...] + for keyframe in keyframes + ] + cameras_to_world = np.concatenate(cameras_to_world, axis=0) + + focal_lengths = np.stack([ + [ + ngp.fov_to_focal_length(self.inference_resolution[self.testbed.fov_axis], keyframe.fov) + ] * 2 + for keyframe in keyframes + ], axis=0) + + case _: + raise ValueError("Unsupported Gen3C camera movement source:", + self.testbed.gen3c_camera_source) + t0 = now - self.start_t + timestamps = [t0 + i * self.inference_time_per_frame + for i in range(n_desired_frames)] + + request_id = f"{self.client_id}:{self.last_request_id + 1}" + + tlog.debug(f"Creating new request {request_id}") + req = InferenceRequest( + request_id=request_id, + timestamps=np.array(timestamps), + cameras_to_world=cameras_to_world, + focal_lengths=focal_lengths, + principal_points=np.array([self.testbed.screen_center] * n_desired_frames), + resolutions=np.array([self.inference_resolution] * n_desired_frames), + framerate=self.testbed.camera_path.render_settings.fps, + # If we don't need to display the generated frames, we can save time + # by not estimating & downloading depth maps. + return_depths=self.testbed.gen3c_display_frames, + video_encoding_quality=self.testbed.camera_path.render_settings.quality, + show_cache_renderings=self.testbed.gen3c_show_cache_renderings, + ) + # Add any necessary padding to the request to match the server's batch size. + req.pad_to_frame_count(n_frames_padded) + + # Send an inference request to the server and add it to the + # list of pending requests. 
+ self.request_frame(req, sync=sync) + + tlog.info("Waiting for inference results (this may take a while)...") + self.last_request_t = now + self.last_request_id += 1 + + + def request_frame(self, req: InferenceRequest, sync: bool = False) -> asyncio.Task | InferenceResult: + qp = "?sync=1" if sync else "" + url = self.url + "/request-inference" + qp + data = pickle.dumps(req) + + def req_done_cb(task_or_res: asyncio.Task | httpx.Response) -> None: + if sync: + res: httpx.Response = task_or_res + else: + try: + res: httpx.Response = task_or_res.result() + except RuntimeError as e: + tlog.error(f"Inference request task failed!\n{e}") + + if res.status_code != 202: + tlog.error(f"Inference request failed!\n{res.content}") + + if sync: + return pickle.loads(res.content) + else: + if req.request_id not in self.pending_requests: + tlog.error(f"Inference request {req.request_id} was created on the server," + f" but it is not part of our pending requests" + f" (pending: {list(self.pending_requests.keys())})") + + state = self.pending_requests[req.request_id] + state.state = RequestState.REQUEST_SENT + state.task = None + + task_or_res = httpx_request( + "post", url, data=data, timeout=self.req_timeout_s, + async_client=(None if sync else self.client), + callback=req_done_cb + ) + if not sync: + self.pending_requests[req.request_id] = PendingRequest( + request_id=req.request_id, + state=RequestState.REQUEST_PENDING, + task=task_or_res, + ) + return task_or_res + + + def _get_inference_results(self, request_id: str, on_result_received: callable) -> asyncio.Task: + def task_cb(task): + # Hide the progress bar (regardless of success or failure) + self.testbed.gen3c_inference_progress = -1.0 + + try: + res: httpx.Response = task.result() + except RuntimeError as e: + tlog.error(f"Results request task for inference {request_id} failed!\n{e}") + + if res.status_code == 503: + # Result not ready yet, wait some more + on_result_received(result=None, response=res) + return + elif res.status_code != 200: + # Result failed, we shouldn't retry further + on_result_received(result=None, response=res, failed=True) + return + + # Result ready + on_result_received(pickle.loads(res.content), response=res) + return + + def progress_cb(progress: float, bar: tqdm, **kwargs): + total_mb = bar.total / (1024 * 1024) + self.testbed.gen3c_inference_info = f"Downloading inference results ({total_mb:.1f} MB)" + self.testbed.gen3c_inference_progress = progress + + return httpx_request( + "get", + self.url + f"/inference-result?request_id={request_id}", + # Waiting for the model to finish inference can be very long, + # especially for single-GPU inference. 
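+ # Hence the generous 10x multiplier on the configured request timeout below.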
+ timeout=10 * self.req_timeout_s, + progress=True, + desc=f"Inference results for {request_id}", + async_client=self.client, + callback=task_cb, + progress_callback=progress_cb + ) + + # ---------- + + def load_seeding_data(self, seeding_data_path: str, display: bool = True, + normalize_cameras: bool = False) -> SeedingRequest: + + if not os.path.exists(seeding_data_path): + tlog.error(f"Cannot seed with invalid path: \"{seeding_data_path}\"") + return None + tlog.info(f"Seeding model from \"{seeding_data_path}\"") + + req = load_v2v_seeding_data(seeding_data_path, max_frames=self.seed_max_frames, + frames_stride=self.seed_stride) + + if normalize_cameras: + if isinstance(req, CompressedSeedingRequest): + raise NotImplementedError("Normalizing cameras not implemented for compressed seeding data") + + # Post-process the cameras so that they are centered at (0.5, 0.5, 0.5) + # and so that they fit within a reasonable scale. + current_origins = req.cameras_to_world[:, :3, 3] + current_center = np.mean(current_origins, axis=0) + current_scale = np.mean(np.linalg.norm(current_origins, axis=1)) + # TODO: robust scale estimation using the median depth as well + if req.depths is not None: + median_depth = np.nanmedian(req.depths) + current_scale = max(current_scale, median_depth) + + # tlog.debug(f"Current scale: {current_scale}") + + if current_scale != 0.0: + normalized_origins = (current_origins - current_center) / current_scale + + new_center = np.array([0.5, 0.5, 0.5], dtype=np.float32) + # aabb_scale = np.linalg.norm(self.testbed.render_aabb.max - self.testbed.render_aabb.min) + # new_scale = aabb_scale / 4 + new_scale = 1.0 + req.cameras_to_world[:, :3, 3] = (normalized_origins * new_scale) + new_center + + # Rescale the depth values by the same + req.depths *= new_scale / current_scale + # TODO: retain this information so that we can undo the transform when + # communicating with the server or saving stuff out. + + if display and (req.depths is not None): + # If there's not depth data available at this point, we'll download it from the server + # when seeding is done, and display the frames then. + self.display_seeding_data(req, save_frames=self.testbed.gen3c_save_frames) + + return req + + + def display_seeding_data(self, req: SeedingRequest, res: SeedingResult | None = None, + save_frames: bool = False) -> None: + self.testbed.clear_src_views() + + if isinstance(req, CompressedSeedingRequest): + # Since the de-compression is done inline, we make sure not to + # populate uncompressed data in the request before sending it over. + req = deepcopy(req) + req.decompress() + + images = req.images + depths = req.depths + if res is not None: + # Adopt extrinsics and intrinsics from the server, the model might + # have estimated them better than our hardcoded guess. + focal_lengths = res.focal_lengths.copy() + cameras_to_world = res.cameras_to_world + principal_points = res.principal_points + + if res.depths is not None: + # TODO: the depth estimated by the server may have a completely different scale. + depths = res.depths + if res.depths.shape[1:] != images.shape[1:3]: + # Depth prediction took place on the server at a different resolution, + # let's resize the RGB images to match. 
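+ # (cv2.resize takes its target size as (width, height), hence the
+ # (depths.shape[2], depths.shape[1]) argument order below.)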
+ tlog.debug(f"Resizing seeding images for display to match depth resolution {depths.shape[1:3]}") + resized = [] + for i in range(len(req)): + resized.append( + cv2.resize(images[i, ...], (depths.shape[2], depths.shape[1]), + interpolation=cv2.INTER_CUBIC) + ) + # Let's assume that the inference server already adjusted the intrinsics + # to match the requested inference resolution. + # focal_lengths[i, 0] *= depths.shape[2] / images.shape[2] + # focal_lengths[i, 1] *= depths.shape[1] / images.shape[1] + images = np.stack(resized, axis=0) + else: + focal_lengths = req.focal_lengths.copy() + cameras_to_world = req.cameras_to_world + principal_points = req.principal_points + + + if save_frames: + os.makedirs(self.testbed.gen3c_output_dir, exist_ok=True) + for seed_i in range(len(req)): + res_id = f"seeding_{seed_i:04d}" + image = ensure_alpha_channel(images[seed_i, ...]) + + if save_frames: + safe_res_id = res_id + fname = join(self.testbed.gen3c_output_dir, f"rgb_{safe_res_id}.exr") + pyexr.write(fname, image) + + if depths is not None: + fname = join(self.testbed.gen3c_output_dir, f"depth_{safe_res_id}.exr") + pyexr.write(fname, depths[seed_i, ...].astype(np.float32)) + tlog.success(f"[+] Wrote seeding frame to: {fname}") + + if depths is None: + # Still no depth values available, cannot display + continue + + self.testbed.add_src_view( + cameras_to_world[seed_i, ...], + focal_lengths[seed_i][0], + focal_lengths[seed_i][1], + principal_points[seed_i][0], + principal_points[seed_i][1], + self.lens, + image, + depths[seed_i, ...], + # TODO: seeding request could also have timestamps + seed_i * 1 / 30, + is_srgb=True, + ) + tlog.success(f"[+] Displaying seeding view: {res_id}") + + tlog.info(f"Setting camera path from seeding view.") + # First, initialize the camera path from all seeding view. + self.set_max_number_of_displayed_views(len(req)) + self.testbed.init_camera_path_from_reproject_src_cameras() + # Then, limit the number of displayed views so that rendering doesn't slow down too much. + self.set_max_number_of_displayed_views(8) + self.testbed.reset_accumulation(reset_pip=True) + + + def send_seeding_request(self, req: SeedingRequest, sync: bool = False) -> asyncio.Task | None: + """ + Note: we do seeding requests synchronously by default so that we don't have to implement + eager checking, etc. + """ + + qp = "?sync=1" if sync else "" + url = self.url + "/seed-model" + qp + depth_was_missing = (req.depths is None) + + def req_done_cb(task_or_res: asyncio.Task | httpx.Response) -> None: + # Hide the progress bar (regardless of success or failure) + self.testbed.gen3c_seeding_progress = -1.0 + if sync: + res: httpx.Response = task_or_res + else: + try: + res: httpx.Response = task_or_res.result() + except RuntimeError as e: + tlog.error(f"Seeding request task failed!\n{e}") + return + + if res.status_code >= 300: + tlog.error(f"Seeding request failed!\n{res.content}") + return None + + if depth_was_missing: + response: SeedingResult = pickle.loads(res.content) + self.display_seeding_data(req, res=response, save_frames=self.testbed.gen3c_save_frames) + + message = "Model seeded." 
+ self.testbed.gen3c_info = "\n".join([ + self.testbed.gen3c_info.split("\n")[0], + message + ]) + tlog.success(message) + + def progress_cb(progress: float, **kwargs): + self.testbed.gen3c_seeding_progress = progress + + if not isinstance(req, CompressedSeedingRequest): + req = req.compress() + + data = pickle.dumps(req) + try: + progress_direction = "both" if depth_was_missing else "auto" + return httpx_request("post", url, data=data, timeout=self.req_timeout_s, + progress=True, progress_direction=progress_direction, + desc="Seeding", + async_client=(None if sync else self.client), + callback=req_done_cb, + progress_callback=progress_cb) + except (httpx.TimeoutException, httpx.ConnectError) as e: + tlog.error(f"Seeding request failed (timeout or connection error)!\n{e}") + return None + + # ---------- + + def set_max_number_of_displayed_views(self, n_views: int) -> None: + tlog.info(f"Setting max number of displayed views to {n_views}") + # Jump to the last view. + self.testbed.reproject_max_src_view_index = min(self.testbed.reproject_src_views_count(), n_views) + + def _transfer_in_progress(self) -> bool: + return (self.testbed.gen3c_inference_progress >= 0.0) or (self.testbed.gen3c_seeding_progress >= 0.0) + + # ---------- + + def adapt_view_to_cameras(self, cameras_to_world: np.ndarray, + go_to_default_camera: bool = True) -> None: + """ + Analyzes the given set of cameras, and tries to adapt the current + up vector, default camera pose, etc to match. + + Note: this hasn't been tested very thoroughly yet and could easily + do the wrong thing depending on the inputs. + """ + assert cameras_to_world.shape[1:] == (3, 4) + + # --- Up vector + # Average of the cameras' individual up vectors, snapped to an axis. + mean_up = np.mean(cameras_to_world[:, :3, 1], axis=0) + up_axis = np.argmax(np.abs(mean_up)) + up = np.zeros((3,), dtype=np.float32) + up[up_axis] = -np.sign(mean_up[up_axis]) + self.testbed.up_dir = up + + # --- Default camera pose + default_c2w = cameras_to_world[0, :3, :] + + # Note: `default_camera` is technically a 4x3 camera, but the bindings + # expose it as a 3x4 matrix, so we can set it as normal here. + self.testbed.default_camera = default_c2w + tlog.debug(f"Based on the seeding data, setting up dir to {self.testbed.up_dir}" + f" and default camera to:\n{self.testbed.default_camera}") + + if go_to_default_camera: + self.testbed.reset_camera() + + + + def gui_callback(self, event: str) -> bool: + match event: + case "seed_model": + seed_req = self.load_seeding_data(self.testbed.gen3c_seed_path) + if seed_req is not None: + self.adapt_view_to_cameras(seed_req.cameras_to_world) + self.send_seeding_request(seed_req) + # "True" means we handled the event, not that seeding was successful. + return True + + case "request_inference": + self.request_frames(sync=False) + return True + + case "abort_inference": + tlog.info("Aborting inference request...") + tlog.error("Not implemented yet: aborting an ongoing inference request. 
Ignoring.") + return True + + return False + + def file_drop_callback(self, paths: list[str]) -> bool: + tlog.info(f"Received {len(paths)} file{'s' if len(paths) > 1 else ''} via drag & drop: {paths}") + for path in paths: + ext = os.path.splitext(path)[1].lower() + if os.path.isdir(path) or ext in (".jpg", ".png", ".exr"): + self.testbed.gen3c_seed_path = path + self.seeding_pending = True + elif ext == ".json": + try: + self.testbed.load_camera_path(path) + except RuntimeError as e: + tlog.error(f"Error loading camera path, perhaps the formata is incorrect?\n\t{e}") + else: + tlog.error(f"Don't know how to handle given file: {path}") + return True + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("client.py") + parser.add_argument("files", nargs="*", + help="Files to be loaded. Can be a camera path, scene name," + " seed image, or pre-processed video directory.") + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", default=8000) + parser.add_argument("--request-latency-ms", "--latency", default=250) + parser.add_argument("--inference-resolution", nargs=2, default=(576, 320)) + parser.add_argument("--vr", action="store_true") + parser.add_argument("--seed-max-frames", type=int, default=None, + help="If seeding from a video, maximum number of frames to use.") + parser.add_argument("--seed-stride", type=int, default=1, + help="If seeding from a video, number of frames to skip when reading (stride).") + parser.add_argument("--output-dir", "-o", type=str, default=None, + help="Directory in which to save the inference results.") + args = parser.parse_args() + + client = Gen3cClient(**vars(args)) + asyncio.run(client.run()) diff --git a/gui/api/encoding.py b/gui/api/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..5da8c68d259953aa849f08d5c58338a8903fd99c --- /dev/null +++ b/gui/api/encoding.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +import io +import tempfile + +import cv2 +import numpy as np + +class CompressionFormat(Enum): + JPG = "jpg" + PNG = "png" + EXR = "exr" + MP4 = "mp4" + NPZ = "npz" + +IMAGE_COMPRESSION_FORMATS = (CompressionFormat.JPG, CompressionFormat.PNG, CompressionFormat.EXR) + + +def compress_images(images: np.ndarray | None, format: CompressionFormat, + is_depth: bool = False, is_bool: bool = False) -> list[bytes] | None: + """ + Compress image(s) to the desired image format. + Depth images should be encoded as EXR to preserve the data. 
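+    Returns one encoded buffer per image, or a single buffer holding all images when the NPZ format is used.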
+    """
+    if images is None:
+        return None
+
+    if is_depth or is_bool:
+        assert images.ndim == 3, images.shape
+    else:
+        assert images.ndim == 4 and images.shape[-1] == 3, images.shape
+
+    flags = []
+    if format == CompressionFormat.JPG:
+        flags = [int(cv2.IMWRITE_JPEG_QUALITY), 100]
+
+    result = []
+    if is_depth:
+        # Note: leave as-is (floating point) to avoid quantization errors.
+        assert format in (CompressionFormat.EXR, CompressionFormat.NPZ), "Depth images must be encoded as EXR or NPZ"
+        images = images.astype(np.float32)
+    elif is_bool:
+        assert format == CompressionFormat.NPZ, "Bool images (e.g. masks) must be encoded as NPZ"
+        images = images.astype(bool)
+    else:
+        images = (images * 255.0).astype(np.uint8)
+
+    if format == CompressionFormat.NPZ:
+        with io.BytesIO() as f:
+            np.savez_compressed(f, images)
+            result.append(f.getvalue())
+
+    else:
+        assert format in IMAGE_COMPRESSION_FORMATS, f"Unsupported image compression format: {format}"
+        for i in range(images.shape[0]):
+            _, encoded = cv2.imencode(f".{format.value}", images[i], flags)
+            result.append(encoded.tobytes())
+
+    return result
+
+
+def decompress_buffer(buffers: list[bytes] | None, format: CompressionFormat,
+                      is_depth: bool = False, is_bool: bool = False) -> np.ndarray | None:
+    """
+    Returns the decoded image as 0..1 float values (or 0..inf for depth).
+    """
+    if buffers is None:
+        return None
+    assert not (is_depth and is_bool), "Cannot be both a depth and a bool buffer."
+
+    images = []
+    for buf in buffers:
+
+        if format == CompressionFormat.MP4:
+            assert not is_bool and not is_depth, "Cannot decode a mask or depth from a video."
+
+            # TODO: not sure why, but reading directly from the buffer leads to a segfault.
+            # cap = cv2.VideoCapture(io.BytesIO(buf), apiPreference=cv2.CAP_FFMPEG, params=[])
+            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as f:
+                f.write(buf)
+                # Make sure the data is on disk before OpenCV opens the file by name.
+                f.flush()
+                cap = cv2.VideoCapture(f.name)
+
+                while True:
+                    ret, image = cap.read()
+                    if not ret:
+                        break
+                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                    # Note: the conversion from 0..1 to -1..1 will be done by the model.
+                    image = image.astype(np.float32) / 255.0
+                    images.append(image[None, ...])
+                cap.release()
+
+        else:
+            if format == CompressionFormat.NPZ:
+                image = np.load(io.BytesIO(buf), allow_pickle=False)
+                if hasattr(image, "files"):
+                    assert len(image.files) == 1, image.files
+                    image = image[image.files[0]]
+                # We assume it was saved with the right value range, shape and dtype.
+                images.append(image)
+
+            else:
+                buf_np = np.frombuffer(buf, dtype=np.uint8)
+
+                # OpenCV will automatically guess the image format.
+                flags = cv2.IMREAD_ANYDEPTH if is_depth else cv2.IMREAD_ANYCOLOR
+                image = np.array(cv2.imdecode(buf_np, flags))
+
+                if is_bool:
+                    image = image.astype(bool)
+                elif image.dtype == np.uint8:
+                    image = image.astype(np.float32) / 255.0
+
+                images.append(image[None, ...])
+
+    return np.concatenate(images, axis=0)
+
+
+
+def pad_or_trim_array(arr: np.ndarray | None, target_size: int) -> np.ndarray | None:
+    """
+    Pad or trim the array to the target size.
+    """
+    if arr is None:
+        return None
+
+    n = arr.shape[0]
+    if n == target_size:
+        return arr
+    elif n > target_size:
+        return arr[:target_size]
+    else:
+        reps = (target_size - n, *([1] * (arr.ndim - 1)))
+        return np.concatenate([
+            arr,
+            np.tile(arr[-1:], reps)
+        ], axis=0)
+
+
+
+def pad_or_trim_encoded_buffers(buffers: list[bytes] | None, format: CompressionFormat,
+                                target_size: int) -> list[bytes] | None:
+    """
+    Pad or trim the encoded buffers to the target size.
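+    Trimming drops entries from the end; padding repeats the last entry (buffer or frame).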
+ """ + if buffers is None: + return None + + if format in (CompressionFormat.JPG, CompressionFormat.PNG, CompressionFormat.EXR): + # We just assume that there is one buffer per entry + n = len(buffers) + + if n == target_size: + return buffers + elif n > target_size: + return buffers[:target_size] + else: + return buffers + [buffers[-1]] * (target_size - n) + + elif format == CompressionFormat.NPZ: + assert len(buffers) == 1, "NPZ buffers should be a single buffer" + arr = np.load(io.BytesIO(buffers[0]), allow_pickle=False) + if hasattr(arr, "files"): + assert len(arr.files) == 1, arr.files + arr = arr[arr.files[0]] + + arr = pad_or_trim_array(arr, target_size) + with io.BytesIO() as f: + np.savez_compressed(f, arr) + return [f.getvalue()] + + + elif format == CompressionFormat.MP4: + # We assume there is one buffer per video + assert len(buffers) == 1, "MP4 buffers should be a single buffer" + buf = buffers[0] + result = [] + + # TODO: do all this with in-memory buffers instead of temporary files + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as f: + f.write(buf) + + # Read back the video frame by frame + cap = cv2.VideoCapture(f.name) + fps = cap.get(cv2.CAP_PROP_FPS) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as f_out: + out = cv2.VideoWriter(f_out.name, fourcc, fps, (width, height)) + + n_written = 0 + last_frame = None + for _ in range(target_size): + ret, frame = cap.read() + if not ret: + break + out.write(frame) + last_frame = frame + n_written += 1 + + # If target size is longer than the original video, repeat the last valid frame + for i in range(n_written, target_size): + out.write(last_frame) + + out.release() + + f_out.seek(0) + result.append(f_out.read()) + cap.release() + + return result + + else: + raise ValueError(f"Unsupported compression format: {format}") diff --git a/gui/api/httpx_utils.py b/gui/api/httpx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76349557648ab9fb7cb0a7d0fd0dbcb5c5f5b286 --- /dev/null +++ b/gui/api/httpx_utils.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
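+
+# Example usage (illustrative sketch; `url`, `data`, `client` (an httpx.AsyncClient) and `on_done`
+# are placeholder names):
+#   httpx_request("get", url, progress=True, desc="Download")         # blocking call
+#   httpx_request("post", url, content=data, async_client=client,
+#                 callback=on_done)                                   # returns an asyncio.Task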
+ +import asyncio +from io import BytesIO +import pickle +import time +from typing import Callable + +import httpx +from tqdm import tqdm + + +def content_with_progress(content, chunk_size=1024, desc="Upload", + progress_callback: Callable[[str, float, tqdm], None] | None = None): + total = len(content) + with tqdm(total=total, unit_scale=True, unit_divisor=1024, unit="B", desc=desc) as progress: + for i in range(0, total, chunk_size): + chunk = content[i:i + chunk_size] + yield chunk + report_progress("upload", len(chunk), progress, callback=progress_callback) + +async def async_content_with_progress(*args, **kwargs): + for chunk in content_with_progress(*args, **kwargs): + yield chunk + + +def streaming_response_to_response(response: httpx.Response, content_bytes: BytesIO) -> httpx.Response: + """ + Convert a streaming response to a non-streaming response. + """ + # TODO: is there a nicer way to get a non-streaming-style Response object, despite + # having used the streaming API above? (for uniform consumption by the caller). + to_remove = set(["is_stream_consumed", "next_request", "is_closed", "content", "stream"] + [ + k for k in response.__dict__ if k.startswith("_") + ]) + kwargs = { k: v for k, v in response.__dict__.items() if k not in to_remove } + + content_bytes.seek(0) + kwargs["content"] = content_bytes.read() + return httpx.Response(**kwargs) + + +def report_progress(direction: str, progress_absolute: int | float, + bar: tqdm, callback: Callable[[str, float, tqdm], None] | None = None): + bar.update(progress_absolute) + if callback is not None: + progress_percent = bar.n / bar.total + callback(direction=direction, progress=progress_percent, bar=bar) + + + +def httpx_request(method: str, + *args, + progress: bool = False, + progress_direction: str = "auto", + desc: str | None = None, + async_client: httpx.AsyncClient | None = None, + callback: Callable | None = None, + progress_callback: Callable[[str, float, tqdm], None] | None = None, + **kwargs) -> httpx.Response | asyncio.Task[httpx.Response]: + is_async = async_client is not None + + progress_download = progress and ( + progress_direction in ("both", "download") + or (progress_direction == "auto" and method.lower() == "get") + ) + progress_upload = progress and ( + progress_direction in ("both", "upload") + or (progress_direction == "auto" and method.lower() == "post") + ) + + if progress_upload: + for key in ("content", "data"): + if key in kwargs: + upload_desc = f"{desc} (upload)" if desc else "Upload" + wrapper = async_content_with_progress if is_async else content_with_progress + kwargs[key] = wrapper(kwargs[key], desc=upload_desc, progress_callback=progress_callback) + + if progress_download: + # Progress bar requested for download, need to use streaming API + + if async_client is None: + content_bytes = BytesIO() + with httpx.stream(method, *args, **kwargs) as response: + total = int(response.headers["Content-Length"]) + with tqdm(total=total, unit_scale=True, unit_divisor=1024, unit="B", desc=desc) as progress: + num_bytes_downloaded = response.num_bytes_downloaded + for chunk in response.iter_bytes(): + report_progress("download", response.num_bytes_downloaded - num_bytes_downloaded, + progress, callback=progress_callback) + + num_bytes_downloaded = response.num_bytes_downloaded + content_bytes.write(chunk) + response = streaming_response_to_response(response, content_bytes) + if callback is not None: + callback(response) + return response + + else: + async def inner(): + content_bytes = BytesIO() + async with 
async_client.stream(method, *args, **kwargs) as response: + total = int(response.headers["Content-Length"]) + with tqdm(total=total, unit_scale=True, unit_divisor=1024, unit="B", desc=desc) as progress: + num_bytes_downloaded = response.num_bytes_downloaded + async for chunk in response.aiter_bytes(): + report_progress("download", response.num_bytes_downloaded - num_bytes_downloaded, + progress, callback=progress_callback) + num_bytes_downloaded = response.num_bytes_downloaded + content_bytes.write(chunk) + response = streaming_response_to_response(response, content_bytes) + return response + + task = asyncio.create_task(inner()) + if callback is not None: + task.add_done_callback(callback) + return task + + else: + # No download progress bar needed, use standard httpx methods + if is_async: + task = asyncio.create_task( + async_client.request(method, *args, **kwargs) + ) + if callback is not None: + task.add_done_callback(callback) + return task + else: + res = httpx.request(method, *args, **kwargs) + if callback is not None: + callback(res) + return res + + +def benchmark_requests(host, port, n=100): + url = f"http://{host}:{port}/image" + + t0 = time.time() + for i in range(n): + res = httpx.get(url) + loaded = pickle.loads(res.content) + assert "image" in loaded + + elapsed = time.time() - t0 + print(f"Took {elapsed} s = {1000 * elapsed / n} ms/it") diff --git a/gui/api/multi_gpu.py b/gui/api/multi_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..4a44fad801abc02c4523a1010dcc0ad65cff490e --- /dev/null +++ b/gui/api/multi_gpu.py @@ -0,0 +1,354 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import os +import signal + +from loguru import logger as log + +from v2v_utils import move_to_device, clone_tensors + + +TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost' +TORCHRUN_DEFAULT_MASTER_PORT = 12355 + + +def _get_inference_class(cosmos_variant: str): + if cosmos_variant == 'predict1': + from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel + from cosmos_predict1.utils.distributed import is_rank0 + return Gen3cPersistentModel, is_rank0 + else: + raise ValueError(f"Unsupported cosmos variant: {cosmos_variant}") + + +def _inference_worker(rank: int, args: argparse.Namespace, + gpu_count: int, + cosmos_variant: str, + input_queues: 'list[torch.multiprocessing.Queue]', + result_queue: 'torch.multiprocessing.Queue', + attrs_queue: 'torch.multiprocessing.Queue'): + """ + One such function will run, in a separate process, for each GPU. + Each process loads the model and keeps it in memory. 
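+    Tasks arrive on `input_queues[rank]` as (task_type, args, kwargs) tuples; results are pushed
+    to `result_queue` (or to `attrs_queue` for attribute and cached-depth lookups).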
+ """ + log.debug(f'inference_worker for rank {rank} starting, doing imports now') + import torch + import torch.distributed as dist + + InferenceAR, is_tp_cp_pp_rank0 = _get_inference_class(cosmos_variant) + log.debug(f'inference_worker for rank {rank} done with imports.') + + # The FQDN of the host that is running worker with rank 0; used to initialize the Torch Distributed backend. + os.environ.setdefault("MASTER_ADDR", TORCHRUN_DEFAULT_MASTER_ADDR) + # The port on the MASTER_ADDR that can be used to host the C10d TCP store. + os.environ.setdefault("MASTER_PORT", str(TORCHRUN_DEFAULT_MASTER_PORT)) + # The local rank. + os.environ["LOCAL_RANK"] = str(rank) + # The global rank. + os.environ["RANK"] = str(rank) + # The rank of the worker group. A number between 0 and max_nnodes. When running a single worker group per node, this is the rank of the node. + os.environ["GROUP_RANK"] = str(rank) + # The rank of the worker across all the workers that have the same role. The role of the worker is specified in the WorkerSpec. + os.environ["ROLE_RANK"] = str(rank) + # The local world size (e.g. number of workers running locally); equals to --nproc-per-node specified on torchrun. + os.environ["LOCAL_WORLD_SIZE"] = str(gpu_count) + # The world size (total number of workers in the job). + os.environ["WORLD_SIZE"] = str(gpu_count) + # The total number of workers that was launched with the same role specified in WorkerSpec. + os.environ["ROLE_WORLD_SIZE"] = str(gpu_count) + # # The number of worker group restarts so far. + # os.environ["TORCHELASTIC_RESTART_COUNT"] = TODO + # # The configured maximum number of restarts. + # os.environ["TORCHELASTIC_MAX_RESTARTS"] = TODO + # # Equal to the rendezvous run_id (e.g. unique job id). + # os.environ["TORCHELASTIC_RUN_ID"] = TODO + # # System executable override. If provided, the python user script will use the value of PYTHON_EXEC as executable. The sys.executable is used by default. + # os.environ["PYTHON_EXEC"] = TODO + + # We're already parallelizing over the context, so we can't also parallelize inside the tokenizers (?) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + device = f"cuda:{rank}" + torch.cuda.set_device(rank) + + input_queue = input_queues[rank] + del input_queues + + # Load model once + log.debug(f'inference_worker for rank {rank} creating the model object now') + local_model = InferenceAR(args) + del args + + log.debug(f'inference_worker for rank {rank} ready, pushing a "ready" message to the queue') + result_queue.put((rank, "ready")) + + # Install interrupt signal handler so that we can shut down gracefully. + should_quit = False + def signal_handler(signum, frame): + nonlocal should_quit + log.info(f"[RANK{rank}] Received signal {signum}, shutting down") + should_quit = True + try: + input_queue.put(None) + except ValueError: + pass + + signal.signal(signal.SIGINT, signal_handler) + + while not should_quit: + try: + inputs_task = input_queue.get() + except ValueError: + # Queue was closed, we can exit. + log.debug(f"[RANK{rank}] Input queue was closed, exiting.") + break + if inputs_task is None: + # Special sentinel value to indicate that we are done and can exit. + log.debug(f"[RANK{rank}] Got input {inputs_task}, exiting.") + break + + # Note: we don't need to chunk the inputs for this rank / process, this is done + # automatically in the model. + # Note: we don't need to move the inputs to a specific device either since the + # Gen3C API expects NumPy arrays. 
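+            # The device-transfer block below is intentionally disabled (gated off with `if False:`).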
+ if False: + log.debug(f"[RANK{rank}] Moving task to {device=}") + inputs_task = move_to_device(inputs_task, device) + + # Run the requested task + with torch.no_grad(): + task_type, args, kwargs = inputs_task + log.debug(f"[RANK{rank}] Got task: {task_type=}") + + if task_type == 'inference': + log.debug(f"[RANK{rank}] Running `inference_on_cameras()`...") + output = local_model.inference_on_cameras(*args, **kwargs) + log.debug(f"[RANK{rank}] Done `inference_on_cameras()`!") + + if is_tp_cp_pp_rank0(): + log.debug(f"[RANK{rank}] Moving outputs of `inference_on_cameras()` to the CPU") + output = move_to_device(output, device='cpu') + log.debug(f"[RANK{rank}] Pushing outputs of `inference_on_cameras()` to the results queue") + result_queue.put(output) + + elif task_type == 'seeding': + log.debug(f"[RANK{rank}] Calling `seed_model_from_values()...`") + if cosmos_variant == 'predict1': + output = local_model.seed_model_from_values(*args, **kwargs) + else: + raise NotImplementedError(f"Unsupported cosmos variant: {cosmos_variant}") + output = move_to_device(output, device='cpu') + result_queue.put((rank, "seed_model_from_values_done", output)) + log.debug(f"[RANK{rank}] Done with `seed_model_from_values()`") + + elif task_type == 'clear_cache': + log.debug(f"[RANK{rank}] Calling `clear_cache()...`") + local_model.clear_cache() + result_queue.put((rank, "clear_cache_done")) + log.debug(f"[RANK{rank}] Done with `clear_cache()`") + + elif task_type == 'get_cache_input_depths': + log.debug(f"[RANK{rank}] Calling `get_cache_input_depths()...`") + input_depths = local_model.get_cache_input_depths() + attrs_queue.put(('cache_input_depths', input_depths.cpu(), True)) + log.debug(f"[RANK{rank}] Done with `get_cache_input_depths()`") + + elif task_type == 'getattr': + assert kwargs is None + assert len(args) == 1 + attr_name = args[0] + assert isinstance(attr_name, str) + has_attr = hasattr(local_model, attr_name) + attr_value_or_none = getattr(local_model, attr_name) + + if has_attr and (attr_value_or_none is not None) and torch.is_tensor(attr_value_or_none): + log.debug(f"[RANK{rank}] Attribute {attr_name=} is a torch tensor on " + f"device {attr_value_or_none.device}, cloning it before sending it through the queue") + attr_value_or_none = attr_value_or_none.clone() + + log.debug(f"[RANK{rank}] Pushing attribute value for {attr_name=}") + attrs_queue.put((attr_name, attr_value_or_none, has_attr)) + + else: + raise NotImplementedError(f"Unsupported task type for Cosmos inference worker: {task_type}") + + # Cleanup before exiting + local_model.cleanup() + del local_model + + +def inference_worker(*args, **kwargs): + try: + _inference_worker(*args, **kwargs) + except Exception as e: + import traceback + rank = os.environ.get("LOCAL_RANK", "(unknown)") + log.error(f"[RANK{rank}] encountered exception: {e}. Will re-raise after cleanup." + f" Stack trace:\n{traceback.format_exc()}") + + try: + import torch.distributed as dist + dist.destroy_process_group() + log.info(f"[RANK{rank}] Destroyed model parallel group after catching exception." + " Will re-raise now.") + except Exception as _: + pass + + raise e + + +class MultiGPUInferenceAR(): + """ + Adapter class to run multi-GPU Cosmos inference in the context of the FastAPI inference server. + This class implements the same interface as `InferenceAR`, but spawns one process per GPU and + forwards inference requests to the multiple processes via a work queue. 
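+    Tasks are submitted to the workers as (task_type, args, kwargs) tuples.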
+ + The worker processes wait for work from the queue, perform inference, and gather all results + on the rank 0 process. That process then pushes results to the result queue. + """ + def __init__(self, gpu_count: int, cosmos_variant: str, args: argparse.Namespace): + import torch + import torch.multiprocessing as mp + + self.gpu_count = gpu_count + assert self.gpu_count <= torch.cuda.device_count(), \ + f"Requested {self.gpu_count} GPUs, but only {torch.cuda.device_count()} are available." + + ctx = mp.get_context('spawn') + manager = ctx.Manager() + self.input_queues: list[mp.Queue] = [ctx.Queue() for _ in range(self.gpu_count)] + self.result_queue = manager.Queue() + self.attrs_queue = manager.Queue() + + log.info(f"Spawning {self.gpu_count} processes (one per GPU)") + self.ctx = mp.spawn( + inference_worker, + args=(args, self.gpu_count, cosmos_variant, + self.input_queues, self.result_queue, self.attrs_queue), + nprocs=self.gpu_count, + join=False + ) + + log.info(f"Waiting for {self.gpu_count} processes to load the model...") + for _ in range(self.gpu_count): + v = self.result_queue.get() + if not isinstance(v, tuple) or len(v) != 2 or v[1] != "ready": + raise ValueError(f"Expected a 'ready' message from each process, but received: {v}") + log.info(f"Process {v[0]} is ready.") + + + def inference_on_cameras(self, *args, **kwargs): + log.debug(f"inference_on_cameras(): submitting request to {len(self.input_queues)} inference processes.") + for iq in self.input_queues: + # Send the same input to each process + task = ('inference', args, kwargs) + iq.put(task) + + # Wait on the result queue to produce the result (this could take a while). + log.debug(f"inference_on_cameras(): waiting for result...") + outputs = self.result_queue.get() + log.debug(f"inference_on_cameras(): got inference results! 
Cloning and returning.") + return clone_tensors(outputs) + + + def seed_model_from_values(self, *args, **kwargs): + log.debug(f"seed_model_from_values(): submitting request to {len(self.input_queues)} inference processes.") + for iq in self.input_queues: + task = ('seeding', args, kwargs) + iq.put(task) + + # TODO: refactor this, and maybe use some events or another primitive + log.info(f"Waiting for {self.gpu_count} processes to be done with seeding...") + for i in range(self.gpu_count): + v = self.result_queue.get() + if not isinstance(v, tuple) or len(v) != 3 or v[1] != "seed_model_from_values_done": + raise ValueError(f"Expected a 'seed_model_from_values_done' message from each process, but received: {v}") + log.info(f"Process {v[0]} is done with `seed_model_from_values()`.") + + # Arbitrarily pick the output from the first process + if i == 0: + outputs = v[2] + + return clone_tensors(outputs) + + + def clear_cache(self): + for iq in self.input_queues: + task = ('clear_cache', None, None) + iq.put(task) + + # TODO: refactor this, and maybe use some events or another primitive + log.info(f"Waiting for {self.gpu_count} processes to be done with clear_cache...") + for _ in range(self.gpu_count): + v = self.result_queue.get() + if not isinstance(v, tuple) or len(v) != 2 or v[1] != "clear_cache_done": + raise ValueError(f"Expected a 'clear_cache_done' message from each process, but received: {v}") + log.info(f"Process {v[0]} is done with `clear_cache()`.") + + + def get_cache_input_depths(self): + name = 'cache_input_depths' + task = ('get_cache_input_depths', None, None) + self.input_queues[0].put(task) + + # TODO: refactor this, and maybe use some events or another primitive + looked_up_name, value, exists = self.attrs_queue.get() + if looked_up_name != name: + # TODO: this could be handled better (retry or enforce some ordering maybe). + raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}'," + " there was likely a race condition.") + log.debug(f"Got a valid response, returning value for `get_cache_input_depths()`") + return value + + + def __getattr__(self, name: str): + log.debug(f"__getattr__({name=}) called") + # Note: this will not be called for methods we implement here, or attributes + # that actually exist in this object. + # Query the attribute from rank 0 (arbitrarily) + task = ('getattr', (name,), None) + self.input_queues[0].put(task) + + # Get result (blocking) + log.debug(f"Waiting for response on `attrs_queue`...") + looked_up_name, value, exists = self.attrs_queue.get() + if looked_up_name != name: + # TODO: this could be handled better (retry or enforce some ordering maybe). + raise ValueError(f"Queried model for attribute '{name}' but got attribute '{looked_up_name}'," + " there was likely a race condition.") + if not exists: + raise AttributeError(f"Model has no attribute named '{name}'") + log.debug(f"Got a valid response, returning {name} == {value}") + return value + + + def cleanup(self): + """ + Clean up resources before shutting down. 
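+        Sends a sentinel value (None) to each worker's input queue, then joins the processes.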
+ """ + log.info(f"MultiGPUInferenceAR winding down, asking {len(self.input_queues)} processes to clean up.") + + # "Close" all queues (there's no actual `close` method in PyTorch MP queues) + for iq in self.input_queues: + iq.put(None) + + # Wait for all processes to finish + log.info(f"Waiting for {len(self.input_queues)} processes to finish (join).") + self.ctx.join() + log.info(f"{len(self.input_queues)} processes have finished.") diff --git a/gui/api/pyproject.toml b/gui/api/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..55ec8d784c3cfc688c31cb32a4225e5f551937b8 --- /dev/null +++ b/gui/api/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 120 diff --git a/gui/api/server.py b/gui/api/server.py new file mode 100644 index 0000000000000000000000000000000000000000..be680316f888ed105f2c352a6b8eb0750cb67125 --- /dev/null +++ b/gui/api/server.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import asynccontextmanager +from dataclasses import dataclass +import logging +import os +import pickle +import traceback + +from fastapi import FastAPI +from fastapi.requests import Request +from fastapi.responses import Response +import imageio.v3 as imageio +import numpy as np + +from api_types import CompressedSeedingRequest +from server_base import InferenceModel +from server_cosmos import CosmosModel + + +# ------------------------------ + +@dataclass +class ServerSettings(): + """ + Note: we use a dataclass + env variables because we can't + easily pass command line arguments through the `fastapi` launcher. + """ + model: str = os.environ.get("GEN3C_MODEL", "cosmos-predict1") + checkpoint_path: str | None = os.environ.get("GEN3C_CKPT_PATH") + data_path: str | None = os.environ.get("GEN3C_DATA_PATH") + + #: Additional latency to add to any inference request, in milliseconds. + inference_latency: int = int(os.environ.get("GEN3C_INFERENCE_LATENCY", 0)) + + #: Number of inference results to keep in cache. + #: This may be useful when multiple requests are in flight and the user hasn't + #: retrieved the results yet. + inference_cache_size: int = int(os.environ.get("GEN3C_INFERENCE_CACHE_SIZE", 15)) + + #: Number of GPUs to use for inference. Leave at 0 to automatically select + #: based on available hardware. 
+    gpu_count: int = int(os.environ.get("GEN3C_GPU_COUNT", 0))
+
+
+settings = ServerSettings()
+model: InferenceModel | None = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global model
+
+    model_name = settings.model.lower()
+    if model_name in ("cosmos", "cosmos-predict1"):
+        cls = CosmosModel
+    else:
+        raise ValueError(f"Unsupported model type: '{settings.model}'")
+
+    model = cls(checkpoint_path=settings.checkpoint_path,
+                data_path=settings.data_path,
+                fake_delay_ms=settings.inference_latency,
+                inference_cache_size=settings.inference_cache_size,
+                gpu_count=settings.gpu_count)
+
+    # --- Startup code
+    # Pre-render at least one image to make sure everything is running
+    if not model.requires_seeding():
+        await model.make_test_image()
+
+    yield
+
+    # --- Shutdown code
+    model.cleanup()
+    del model
+
+app = FastAPI(lifespan=lifespan)
+logger = logging.getLogger('uvicorn.error')
+
+
+# ------------------------------
+
+def get_bool_query_param(request: Request, name: str, default: bool) -> bool:
+    b_str = request.query_params.get(name, "1" if default else "0")
+    return b_str.lower() in ("1", "true", "yes", "")
+
+
+@app.post("/request-inference", response_class=Response, response_model=None)
+async def request_inference(request: Request):
+    """
+    Start a new asynchronous inference job.
+    """
+    sync = get_bool_query_param(request, "sync", default=False)
+    req: bytes = await request.body()
+    req = pickle.loads(req)
+
+    try:
+        if sync:
+            result = await model.request_inference_sync(req)
+            return Response(content=pickle.dumps(result),
+                            media_type="application/octet-stream")
+        else:
+            model.request_inference(req)
+    except Exception as e:
+        logging.error("Inference request failed with exception:"
+                      f"\n{e}\n{traceback.format_exc()}")
+        return Response(str(e), status_code=400)
+
+    return Response("Request accepted.", status_code=202)
+
+
+@app.post("/seed-model", response_class=Response, response_model=None)
+async def seed_model(request: Request):
+    """
+    Seed the model with the provided frames and camera parameters.
+    """
+    sync = get_bool_query_param(request, "sync", default=False)
+    req: bytes = await request.body()
+    req = pickle.loads(req)
+
+    if isinstance(req, CompressedSeedingRequest):
+        req.decompress()
+
+    try:
+        # There isn't really anything async about the seeding request being done on the server
+        # so far, so we just await. This could be changed in the future.
+ result = await model.seed_model(req) + except Exception as e: + logging.error(f"Seeding request failed with exception:" + f"\n{e}\n{traceback.format_exc()}") + return Response(str(e), status_code=400) + + # return Response("Seeding request accepted.", status_code=(200 if sync else 202)) + return Response(content=pickle.dumps(result), + media_type="application/octet-stream") + + +@app.get("/inference-result", response_class=Response, response_model=None) +async def inference_results_or_none(request_id: str): + try: + result = model.inference_result_or_none(request_id) + except Exception as e: + # TODO: try to differentiate the status codes (doesn't exist, inference failed, etc) + logging.error(f"Inference results request failed with exception:" + f"\n{e}\n{traceback.format_exc()}") + return Response(str(e), status_code=500) + + if result is None: + return Response(content="Result not ready", + status_code=503) + else: + return Response(content=pickle.dumps(result), + media_type="application/octet-stream") + + +@app.get("/image", response_class=Response) +def latest_rgb(format: str = "jpg"): + # We return the data as pickled bytes to avoid the JSON serialization / deserialization overhead. + image = model.get_latest_rgb() + if image is None: + return Response(content="No image available yet.", status_code=404) + + if format == "pickle": + content = pickle.dumps( + { + "image": image, + } + ) + return Response(content=content, media_type="application/octet-stream") + + elif format in ("jpg", "png"): + image = image.copy() + # Allow alpha channel to be omitted for faster transfers + if image.shape[-1] == 3: + image = np.concatenate([ + image, + np.ones((*image.shape[:2], 1)) + ], axis=-1) + + if image.dtype != np.uint8: + # TODO: proper handling of gamma compression, etc + image[:, :, :3] = np.power(image[:, :, :3], 1 / 2.2) * 255 + image[:, :, 3] = image[:, :, 3] * 255 + if format != "png": + image = image[:, :, :3] + + content = imageio.imwrite(uri="", image=image.astype(np.uint8), extension="." + format) + return Response(content=content, media_type=f"image/{format}") + + else: + return Response(f"Unsupported image format: {format}", status_code=400) + + +@app.get("/metadata") +def metadata(): + return model.metadata() diff --git a/gui/api/server_base.py b/gui/api/server_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a992cb2c32b337a2d28fa67db411b4fa3cf3668f --- /dev/null +++ b/gui/api/server_base.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import abstractmethod +import asyncio +from os.path import realpath, dirname, join + +from loguru import logger as log +import numpy as np + +from api_types import InferenceRequest, InferenceResult, SeedingRequest + + +ROOT_DIR = realpath(dirname(dirname(dirname(__file__)))) +DATA_DIR = join(ROOT_DIR, "data") + + +class InferenceModel(): + """ + Base class for models that can be served by the inference server + defined in `server.py`. + """ + + def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None, + fake_delay_ms: float = 0, inference_cache_size: int = 15, + compress_inference_results: bool = True) -> None: + + # These paths may be unused by certain inference server types. + self.data_path = data_path + self.checkpoint_path = checkpoint_path + + self.fake_delay_ms = fake_delay_ms + self.inference_cache_size = inference_cache_size + + self.inference_tasks: dict[str, asyncio.Task] = {} + self.inference_results: dict[str, InferenceResult] = {} + self.request_history: set[str] = set() + + # If supported by the model and relevant, compress inference results, + # e.g. as MP4 video, before returning from the server. + self.compress_inference_results: bool = compress_inference_results + + # Can be acquired before starting inference + # if the model can only handle one request at a time + self.inference_lock = asyncio.Lock() + + # The generative model may need to be seeded with one or more initial frames. + self.model_seeded = False + + + # ----------- Inference model interface + + @abstractmethod + async def make_test_image(self): + """Evaluate one default inference request, if possible. + Helps ensuring that the model has been loaded correctly.""" + raise NotImplementedError("make_test_image") + + async def seed_model(self, req: SeedingRequest) -> None: + """By default, no seeding is required so the default implementation just returns.""" + self.model_seeded = True + + @abstractmethod + async def run_inference(self, req: InferenceRequest) -> InferenceResult: + """Evaluate the actual inference model to produce an inference result.""" + raise NotImplementedError("run_inference") + + @abstractmethod + def metadata(self) -> dict: + """Returns metadata about this inference server.""" + raise NotImplementedError("metadata") + + @abstractmethod + def min_frames_per_request(self) -> int: + """Minimum number of frames that can be produced in one inference batch.""" + raise NotImplementedError("min_frames_per_request") + + @abstractmethod + def max_frames_per_request(self) -> int: + """Maximum number of frames that can be produced in one inference batch.""" + raise NotImplementedError("max_frames_per_request") + + @abstractmethod + def inference_time_per_frame(self) -> int: + """Estimated average inference time per frame (not per batch!) in seconds.""" + raise NotImplementedError("inference_time_per_frame") + + def inference_resolution(self) -> list[tuple[int, int]] | None: + """ + The supported inference resolutions (width, height) in pixels, + or None if any resolution is supported. + """ + return None + + def default_framerate(self) -> float | None: + """ + The model's preferred framerate when generating video. + Returns None when not applicable. 
+ """ + return None + + @abstractmethod + def requires_seeding(self) -> int: + """Whether or not this model requires to be seeded with images before inference.""" + return False + + # ----------- Requests handling + + def request_inference(self, req: InferenceRequest) -> asyncio.Task: + if not self.model_seeded: + raise ValueError(f"Received request id '{req.request_id}', but the model was not seeded.") + if (req.request_id in self.inference_tasks) or (req.request_id in self.inference_results): + raise ValueError(f"Invalid request id '{req.request_id}': request already exists.") + self.check_valid_request(req) + + task = asyncio.create_task(self.run_inference(req)) + self.inference_tasks[req.request_id] = task + self.request_history.add(req.request_id) + return task + + + async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult: + await self.request_inference(req) + result = self.inference_result_or_none(req.request_id) + assert isinstance(result, InferenceResult) + return result + + + def inference_result_or_none(self, request_id: str) -> InferenceResult | None: + if request_id in self.inference_tasks: + task = self.inference_tasks[request_id] + if task.done(): + try: + # Inference result ready, cache it and return it + result = task.result() + self.inference_results[request_id] = result + del self.inference_tasks[request_id] + self.evict_results() + return result + except Exception as e: + # Inference failed + log.error(f"Task for request '{request_id}' failed with exception {e}") + raise e + else: + # Inference result not ready yet + return None + + elif request_id in self.inference_results: + # Inference result was ready and cached, return it directly + return self.inference_results[request_id] + + elif request_id in self.request_history: + raise KeyError(f"Request with id '{request_id}' was known, but does not have any result. Perhaps it was evicted from the cache or failed.") + + else: + raise KeyError(f"Invalid request id '{request_id}': request not known.") + + + def evict_results(self, keep_max: int | None = None): + """ + Evict all results that were added before the last `keep_max` entries. + """ + keep_max = keep_max if (keep_max is not None) else self.inference_cache_size + + to_evict = [] + for i, k in enumerate(reversed(self.inference_results)): + if i < keep_max: + continue + to_evict.append(k) + for k in to_evict: + del self.inference_results[k] + + + def get_latest_rgb(self) -> np.ndarray | None: + """Returns the latest generated RGB image, if any. Useful for debugging.""" + if not self.inference_results: + return None + last_key = next(reversed(self.inference_results.keys())) + return self.inference_results[last_key].images[-1, ...] + + def check_valid_request(self, req: InferenceRequest): + if len(req) not in range(self.min_frames_per_request(), self.max_frames_per_request() + 1): + raise ValueError(f"This model can produce between {self.min_frames_per_request()} and" + f" {self.max_frames_per_request()} frames per request, but the request" + f" specified {len(req)} camera poses.") + + return True + + + # ----------- Resource management + def cleanup(self): + pass diff --git a/gui/api/server_cosmos.py b/gui/api/server_cosmos.py new file mode 100644 index 0000000000000000000000000000000000000000..114fdd5cb9e57d7fa2a419f5b5d1adb53030a7af --- /dev/null +++ b/gui/api/server_cosmos.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from os.path import join, realpath +import sys +try: + from typing import override +except ImportError: + def override(f): + return f + +from loguru import logger as log +import numpy as np + +from multi_gpu import MultiGPUInferenceAR +from server_base import ROOT_DIR +from server_cosmos_base import CosmosBaseModel + +COSMOS_PREDICT1_ROOT = ROOT_DIR + +TORCHRUN_DEFAULT_MASTER_ADDR = 'localhost' +TORCHRUN_DEFAULT_MASTER_PORT = 12355 + + +def add_cosmos_venv_to_path(): + version_string = f"python{sys.version_info.major}.{sys.version_info.minor}" + extras = [ + COSMOS_PREDICT1_ROOT, + join(COSMOS_PREDICT1_ROOT, "cosmos_predict1"), + ] + for e in extras: + if e not in sys.path: + sys.path.append(e) + + +class CosmosModel(CosmosBaseModel): + """ + Serves frames generated on-the-fly by the Cosmos generative model. + Intended for use with the Cosmos-Predict-1 based Gen3C model. + """ + + def __init__(self, gpu_count: int = 0, **kwargs): + add_cosmos_venv_to_path() + if not os.environ.get("HF_HOME"): + os.environ["HF_HOME"] = join(COSMOS_PREDICT1_ROOT, "huggingface_home") + + super().__init__(**kwargs) + + assert os.path.isdir(join(COSMOS_PREDICT1_ROOT, "cosmos_predict1")), \ + f"Could not find Cosmos (cosmos_predict1) directory at: {COSMOS_PREDICT1_ROOT}" + + + from cosmos_predict1.diffusion.inference.gen3c_persistent import Gen3cPersistentModel, create_parser + import torch + + if gpu_count == 0: + # Use as many GPUs for inference as are available on this machine. + gpu_count = torch.cuda.device_count() + + # Note: we use the argparse-based interface so that all defaults are preserved. + parser = create_parser() + common_args = [ + "--checkpoint_dir", self.checkpoint_path or join(COSMOS_PREDICT1_ROOT, "checkpoints"), + "--video_save_name=", # Empty string + "--video_save_folder", join(COSMOS_PREDICT1_ROOT, "outputs"), + "--trajectory", "none", + "--prompt=", # Empty string + "--negative_prompt=", # Empty string + "--offload_prompt_upsampler", + "--disable_prompt_upsampler", + "--disable_guardrail", + "--num_gpus", str(gpu_count), + "--guidance", "1.0", + "--num_video_frames", "121", + "--foreground_masking", + ] + args = parser.parse_args(common_args) + + if gpu_count == 1: + self.model = Gen3cPersistentModel(args) + else: + log.info(f"Loading Cosmos-Predict1 inference model on {gpu_count} GPUs.") + self.model = MultiGPUInferenceAR(gpu_count, cosmos_variant="predict1", args=args) + + # Since the model may require overlap of inference batches, + # we save previous inference poses so that we can provide any number of + # previous camera poses when starting the next inference batch. + # TODO: ensure some kind of ordering? 
+ self.pose_history_w2c: list[np.array] = [] + self.intrinsics_history: list[np.array] = [] + + self.default_focal_length = (338.29, 338.29) + self.default_principal_point = (0.5, 0.5) + self.aabb_min = np.array([-16, -16, -16]) + self.aabb_max = np.array([16, 16, 16]) + + + def inference_resolution(self) -> list[tuple[int, int]] | None: + """The supported inference resolutions, or None if any resolution is supported.""" + return [(1280, 704),] + + + @override + def max_frames_per_request(self) -> int: + # Not actually tested, but anyway we can expect autoregressive + # generation to go wrong earlier than this. + return self.model.frames_per_batch * 100 + + @override + def default_framerate(self) -> float: + return 24.0 + + def cleanup(self): + if isinstance(self.model, MultiGPUInferenceAR): + self.model.cleanup() + + @override + def metadata(self) -> dict: + result = super().metadata() + result["model_name"] = "CosmosModel" + return result + + +if __name__ == "__main__": + model = CosmosModel() diff --git a/gui/api/server_cosmos_base.py b/gui/api/server_cosmos_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a03db092ce521861a2e986c55688b390c8010c87 --- /dev/null +++ b/gui/api/server_cosmos_base.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time +try: + from typing import override +except ImportError: + def override(f): + return f + +from loguru import logger as log +import numpy as np + +from api_types import InferenceRequest, InferenceResult, CompressedInferenceResult, SeedingRequest, SeedingResult +from encoding import compress_images, CompressionFormat +from server_base import InferenceModel + + +class CosmosBaseModel(InferenceModel): + """ + Wraps a video generative model. 
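+    Subclasses are expected to provide `self.model`, the pose / intrinsics history lists,
+    and the AABB bounds used by the methods below.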
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + + @override + async def make_test_image(self) -> InferenceResult: + raise NotImplementedError("Not implemented: make_test_image()") + + + async def seed_model(self, req: SeedingRequest) -> None: + import torch + + log.info(f"Seeding the model with request ID '{req.request_id}' ({len(req)} frames)") + # TODO: option to seed without clearing the existing cache + if self.pose_history_w2c: + log.info("[i] Clearing existing 3D cache and history due to seeding.") + self.model.clear_cache() + self.pose_history_w2c.clear() + self.intrinsics_history.clear() + + if hasattr(self.model, 'seed_model_from_values'): + seeding_method = self.model.seed_model_from_values + else: + raise RuntimeError(f"Could not locate seeding method in model.") + + model_result = seeding_method( + images_np=req.images, + depths_np=req.depths, + masks_np=req.masks, + world_to_cameras_np=req.world_to_cameras(), + focal_lengths_np=req.focal_lengths, + principal_point_rel_np=req.principal_points, + resolutions=req.resolutions, + ) + self.model_seeded = True + log.info("[+] Model seeded.") + + out_depths = None if (req.depths is not None) else self.model.get_cache_input_depths().cpu().numpy() + if model_result is None: + return SeedingResult.from_request(req, fallback_depths=out_depths) + else: + model_result = list(model_result) + for i, r in enumerate(model_result): + if isinstance(r, torch.Tensor): + model_result[i] = r.cpu().numpy() + + (estimated_w2c_b44, estimated_focal_lengths_b2, + estimated_principal_point_abs_b2, working_resolutions_b2) = model_result + + # Principal point is expected to be relative to the resolution + estimated_principal_point_rel_b2 = estimated_principal_point_abs_b2 / working_resolutions_b2 + + return SeedingResult( + request_id=req.request_id, + cameras_to_world=estimated_w2c_b44[:, :3, :], + focal_lengths=estimated_focal_lengths_b2, + principal_points=estimated_principal_point_rel_b2, + resolutions=working_resolutions_b2, + depths=out_depths + ) + + @override + async def run_inference(self, req: InferenceRequest) -> InferenceResult: + import torch + + async with self.inference_lock: + log.info(f"[+] Running inference for request \"{req.request_id}\"...") + start_time = time.time() + + w2c = req.world_to_cameras() + # Tricky: we receive intrinsics as in absolute units, assuming the + # resolution requested by the user. But the V2V codebase expects + # intrinsics in absolute units w.r.t. the *original seeding resolution*. + original_res = req.resolutions.copy() + original_res[:, 0] = self.model.W + original_res[:, 1] = self.model.H + intrinsics = req.intrinsics_matrix(for_resolutions=original_res) + + # We allow some overlaps on the cameras here during the inference. + if len(self.pose_history_w2c) == 0: + # First request: no frames to overlap + overlap_frames = 0 # from which frame the model starts prediction + else: + # Subsequent requests: reuse `overlap_frames` poses from the most + # recent completed request. + overlap_frames = self.model.inference_overlap_frames + assert overlap_frames < self.min_frames_per_request() + + w2c = np.concatenate([ + self.pose_history_w2c[-1][-overlap_frames:, ...], + w2c[:-overlap_frames, ...] + ], axis=0) + intrinsics = np.concatenate([ + self.intrinsics_history[-1][-overlap_frames:, ...], + intrinsics[:-overlap_frames, ...] 
+ ], axis=0) + + self.pose_history_w2c.append(w2c) + self.intrinsics_history.append(intrinsics) + + # Run inference given the cameras + inference_results = self.model.inference_on_cameras( + w2c, + intrinsics, + fps=req.framerate, + overlap_frames=overlap_frames, + return_estimated_depths=req.return_depths, + video_save_quality=req.video_encoding_quality, + save_buffer=req.show_cache_renderings, + ) + if isinstance(inference_results, dict): + pred_no_overlap = inference_results['video_no_overlap'] + predicted_depth = inference_results['predicted_depth'] + video_save_path = inference_results.get('video_save_path') + else: + # Assume tuple or list + _, _, _, pred_no_overlap, predicted_depth = inference_results + video_save_path = None + + # Instead of synchronizing, which will block this thread and never yield to the + # asyncio event loop, we record a CUDA event and yield until it is reached + # by the GPU (= inference is complete). + cuda_event = torch.cuda.Event() + cuda_event.record() + while not cuda_event.query(): + await asyncio.sleep(0.0005) + + if self.fake_delay_ms > 0: + await asyncio.sleep(self.fake_delay_ms / 1000.0) + + # Note: we remove the overlap frame(s), if any, before returning the result. + if isinstance(pred_no_overlap, torch.Tensor): + pred_no_overlap = pred_no_overlap.cpu().numpy() + if pred_no_overlap.ndim == 5: + assert pred_no_overlap.shape[0] == 1, pred_no_overlap.shape + pred_no_overlap = pred_no_overlap.squeeze() + n_frames = pred_no_overlap.shape[0] + # Reorder [n_frames, channels, height, width] to [n_frames, height, width, channels] + images = pred_no_overlap.transpose(0, 2, 3, 1) + + if req.return_depths: + if isinstance(predicted_depth, torch.Tensor): + predicted_depth = predicted_depth.cpu().numpy() + # Desired shape: n_frames, height, width + if predicted_depth.ndim == 4: + assert predicted_depth.shape[1] == 1, predicted_depth.shape + predicted_depth = predicted_depth[:, 0, ...] + depths = predicted_depth + else: + depths = None + + # TODO: for dynamic scenes, get actual timestamps for each frame? + timestamps = np.zeros((n_frames,)) + + upper = (-overlap_frames) if (overlap_frames > 0) else None # For easier slicing + kwargs = { + 'request_id': req.request_id, + 'result_ids': [f"{req.request_id}__frame_{k}" for k in range(n_frames)], + 'timestamps': timestamps, + 'cameras_to_world': req.cameras_to_world[:upper, ...], + 'focal_lengths': req.focal_lengths[:upper, ...], + 'principal_points': req.principal_points[:upper, ...], + 'frame_count_without_padding': req.frame_count_without_padding, + 'runtime_ms': 1000 * (time.time() - start_time), + } + if self.compress_inference_results and (video_save_path is not None): + video_bytes = open(video_save_path, "rb").read() + depths_compressed = compress_images(depths, CompressionFormat.NPZ, is_depth=True) + + result = CompressedInferenceResult( + images=None, + depths=None, + resolutions=np.tile([[images.shape[2], images.shape[1]]], (images.shape[0], 1)), + images_compressed=[video_bytes], + images_format=CompressionFormat.MP4, + depths_compressed=depths_compressed, # May be None + depths_format=CompressionFormat.NPZ, + **kwargs + ) + else: + result = InferenceResult( + images=images, + depths=depths, + **kwargs + ) + + return result + + + @override + def min_frames_per_request(self) -> int: + # Note this might not be strictly respected due to overlap frames, + # starting at the second inference batch. 
+ return self.model.frames_per_batch + + @override + def max_frames_per_request(self) -> int: + return self.model.frames_per_batch + + def inference_resolution(self) -> list[tuple[int, int]] | None: + """The supported inference resolutions (width, height), + or None if any resolution is supported.""" + try: + r = self.model.cfg.train_data.shared_params.crop_size + except AttributeError: + r = (self.model.H, self.model.W) + return [(r[1], r[0]),] + + @override + def inference_time_per_frame(self) -> int: + # TODO: actual mean inference time + return 4.0 + + @override + def requires_seeding(self) -> bool: + return True + + @override + def metadata(self) -> dict: + return { + "model_name": "CosmosBaseModel", + "model_version": (1, 0, 0), + "aabb_min": self.aabb_min.tolist(), + "aabb_max": self.aabb_max.tolist(), + "min_frames_per_request": self.min_frames_per_request(), + "max_frames_per_request": self.max_frames_per_request(), + "inference_resolution": self.inference_resolution(), + "inference_time_per_frame": self.inference_time_per_frame(), + "default_framerate": self.default_framerate(), + "requires_seeding": self.requires_seeding(), + } diff --git a/gui/api/v2v_utils.py b/gui/api/v2v_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16069726304e9969105442868394064f9cc3450a --- /dev/null +++ b/gui/api/v2v_utils.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import io +import os +from os.path import join, isdir, isfile +import zipfile + +import imageio.v3 as imageio +import json +import numpy as np +import pyexr +from tqdm import tqdm + +from api_types import CompressedSeedingRequest, SeedingRequest +from encoding import CompressionFormat + + +def srgb_to_linear(img): + limit = 0.04045 + mask = img > limit + + # Process the two cases in parallel using NumPy's vectorized operations + result = np.empty_like(img) + result[mask] = np.power((img[mask] + 0.055) / 1.055, 2.4) + result[~mask] = img[~mask] / 12.92 + + return result + + +def load_gen3c_seeding_data(data_directory: str, max_frames: int | None = None, + frames_stride: int = 1) -> CompressedSeedingRequest: + """ + Example directory structure: + ├── camera.npz + ├── depth.npz + ├── mask.npz + ├── metadata.json + └── rgb.mp4 + + We will keep the data compressed as much as possible so that it can + be uploaded faster to the inference server. 
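+    Note: `rgb.mp4` is forwarded as-is; the depth and mask arrays are re-compressed in memory as NPZ.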
+ """ + bar = tqdm(range(6), desc="Seeding data loading") + + # [n_frames, height, width], float16 + depths = np.load(join(data_directory, "depth.npz"))['depth'] + assert depths.ndim == 3, depths.shape + n_img = depths.shape[0] + resolutions = np.tile([depths.shape[2], depths.shape[1]], reps=(n_img, 1)) + assert resolutions.shape == (n_img, 2) + + with io.BytesIO() as f: + np.savez_compressed(f, depths) + depths_compressed = f.getvalue() + bar.update(1) + + # Intrinsics: [n_frames, 3, 3], float32 + # Organized as: + # [[fx, 0, cx], + # [ 0, fy, cy], + # [ 0, 0, 1]] + camera_data = np.load(join(data_directory, "camera.npz")) + intrinsics = camera_data['intrinsics'] + # Absolute focal lengths + focal_lengths = np.stack([intrinsics[:, 0, 0], intrinsics[:, 1, 1]], axis=1) + assert focal_lengths.shape == (n_img, 2) + # Relative principal points + principal_points = (intrinsics[:, :2, 2] / resolutions).astype(np.float32) + assert principal_points.shape == (n_img, 2) + bar.update(1) + + # [n_frames, height, width], bool + masks = np.load(join(data_directory, "mask.npz"))['mask'] + with io.BytesIO() as f: + np.savez_compressed(f, masks) + masks_compressed = f.getvalue() + bar.update(1) + + # TODO: set the frontend's FPS slider based on `metadata["fps"]` + # metadata = json.load(open(join(data_directory, "metadata.json"))) + bar.update(1) + + images_compressed = open(join(data_directory, "rgb.mp4"), "rb").read() + bar.update(1) + + # [n_frames, 4, 4], float32 + w2c = camera_data['w2c'] + cameras_to_world = np.linalg.inv(w2c)[:, :3, :] + assert cameras_to_world.shape == (n_img, 3, 4) + bar.update(1) + + + return CompressedSeedingRequest( + request_id="__seeding_from_files", + images=None, # Will be auto-filled with placeholders + depths=None, # Will be auto-filled with placeholders + masks=None, # Will be auto-filled with placeholders + cameras_to_world=cameras_to_world, + focal_lengths=focal_lengths, + principal_points=principal_points, + resolutions=resolutions, + images_compressed=[images_compressed], + images_format=CompressionFormat.MP4, + depths_compressed=[depths_compressed], + depths_format=CompressionFormat.NPZ, + masks_compressed=[masks_compressed], + masks_format=CompressionFormat.NPZ, + ) + + + +def load_v2v_seeding_data(data_directory: str, max_frames: int | None = None, + frames_stride: int = 1) -> SeedingRequest: + """ + The seeding data would typically come from the client. + For convenience during debugging, we allow loading it here. + """ + + if isdir(data_directory): + # --- Load seeding data from a directory. + if isfile(join(data_directory, "rgb.mp4")) and isfile(join(data_directory, "metadata.json")): + return load_gen3c_seeding_data(data_directory, max_frames=max_frames, + frames_stride=frames_stride) + + # Gen3C / INGP pre-processed format. + # We assume depths, camera poses, etc are included. 
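The seeding requests built here carry camera intrinsics as absolute focal lengths plus resolution-relative principal points rather than full 3x3 matrices. A minimal sketch of the reverse mapping, rebuilding per-frame pinhole matrices from those fields (the helper name is illustrative, not part of the original code):

import numpy as np

def intrinsics_from_fields(focal_lengths: np.ndarray,
                           principal_points: np.ndarray,
                           resolutions: np.ndarray) -> np.ndarray:
    # focal_lengths:    [n, 2] absolute (fx, fy) in pixels
    # principal_points: [n, 2] (cx, cy) relative to (width, height)
    # resolutions:      [n, 2] (width, height) in pixels
    n = focal_lengths.shape[0]
    K = np.tile(np.eye(3, dtype=np.float32), (n, 1, 1))
    K[:, 0, 0] = focal_lengths[:, 0]
    K[:, 1, 1] = focal_lengths[:, 1]
    K[:, :2, 2] = principal_points * resolutions  # back to pixel coordinates
    return K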
+ # Load the seeding frames + n_img = len([img for img in sorted(os.listdir(join(data_directory, 'rgb'))) + if img.endswith('.jpg')]) + images = [] + depths = [] + for i_frame in range(n_img): + # Load image data + image = imageio.imread(join(data_directory, 'rgb', f'{i_frame:05d}.jpg')) + image_np = image.astype(np.float32) / 255.0 + + # Load depth data + depth_np = np.load(join(data_directory, 'depth', f'{i_frame:05d}.npz'))['depth'] + images.append(image_np) + depths.append(depth_np) + del image_np, depth_np + + # Load camera trajectory + with open(join(data_directory, 'cameras.json'), 'r') as f: + cameras = json.load(f) + cameras_to_world = np.asarray(cameras)[:n_img] + + if (max_frames is not None) and (max_frames < len(images)): + images = images[::frames_stride][:max_frames] + depths = depths[::frames_stride][:max_frames] + cameras_to_world = cameras_to_world[::frames_stride][:max_frames] + + + else: + # --- Load a single image. + # We will have to assume camera poses, etc and let depth be auto-estimated. + n_img = 1 + image = imageio.imread(data_directory) + images = [image.astype(np.float32) / 255.0] + depths = None + cameras_to_world = np.eye(4)[None, :3, :] + + # Shape: [batch, height, width, 3] + images = np.stack(images, axis=0) + if depths is not None: + # Shape: [batch, height, width] + depths = np.stack(depths, axis=0) + + # Note: assumed based on how this data was generated + resolutions = np.tile([images.shape[2], images.shape[1]], reps=(n_img, 1)) + fov_y_rad = np.pi * (50.625 / 180.0) + f = 0.5 / (np.tan(fov_y_rad / 2.0)) * resolutions[:, 1] + focal_lengths = np.stack([f, f], axis=-1) + principal_points = np.full((n_img, 2), 0.5) + + return SeedingRequest( + request_id="__seeding_from_files", + images=images, + depths=depths, + cameras_to_world=cameras_to_world, + focal_lengths=focal_lengths, + principal_points=principal_points, + resolutions=resolutions, + ) + + +def ensure_alpha_channel(image: np.ndarray): + # Allow alpha channel to be omitted for faster transfers + assert image.shape[-1] in (3, 4) + if image.shape[-1] == 3: + image = np.concatenate([image, np.ones((*image.shape[:2], 1))], + axis=-1) + image = image.astype(np.float32) + return image + + +def apply_to_pytree(pytree, cb): + tp = type(pytree) + if pytree is None: + return None + elif isinstance(pytree, (tuple, list)): + return tp([apply_to_pytree(v, cb) for v in pytree]) + elif isinstance(pytree, dict): + return { k: apply_to_pytree(v, cb) for k, v in pytree.items() } + else: + return cb(pytree) + + +def move_to_device(pytree, device): + import torch + + def move(pytree): + if torch.is_tensor(pytree): + return pytree.to(device) + elif isinstance(pytree, np.ndarray): + return torch.from_numpy(pytree).to(device) + else: + # Let's assume it's a not something we need to move + return pytree + # raise NotImplementedError(f"move_to_device(): unsupported type {type(pytree)}") + + return apply_to_pytree(pytree, move) + + +def clone_tensors(pytree): + import torch + + def clone(pytree): + if torch.is_tensor(pytree): + return pytree.clone() + elif isinstance(pytree, np.ndarray): + return pytree.copy() + else: + # Let's assume it's a not something we need to copy + return pytree + # raise NotImplementedError(f"clone_tensors(): unsupported type {type(pytree)}") + + return apply_to_pytree(pytree, clone) diff --git a/gui/api/video_stream.py b/gui/api/video_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..812ab2095c1b2ee7ff49673258c35ebaa92cfa77 --- /dev/null +++ b/gui/api/video_stream.py 
@@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cv2 +import numpy as np + +class RawVideoStream(): + """ + A video stream from a raw mp4 file, using opencv. + This does not support nested iterations. + """ + + def __init__( + self, path: str, seek_range: range | None = None + ) -> None: + super().__init__() + if seek_range is None: + seek_range = range(-1) + + self.path = path + + # Read metadata + vcap = cv2.VideoCapture(self.path) + self._width = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self._height = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + _fps = vcap.get(cv2.CAP_PROP_FPS) + _n_frames = int(vcap.get(cv2.CAP_PROP_FRAME_COUNT)) + vcap.release() + + self.start = seek_range.start + self.end = seek_range.stop if seek_range.stop != -1 else _n_frames + self.end = min(self.end, _n_frames) + self.step = seek_range.step + self._fps = _fps / self.step + + def frame_size(self) -> tuple[int, int]: + """Returns (height, width).""" + return (self._height, self._width) + + def fps(self) -> float: + return self._fps + + def __len__(self) -> int: + return len(range(self.start, self.end, self.step)) + + def __iter__(self): + self.vcap = cv2.VideoCapture(self.path) + self.current_frame_idx = -1 + return self + + def __next__(self) -> tuple[int, np.ndarray]: + while True: + ret, frame = self.vcap.read() + self.current_frame_idx += 1 + + if not ret: + self.vcap.release() + raise StopIteration + + if self.current_frame_idx >= self.end: + self.vcap.release() + raise StopIteration + + if self.current_frame_idx < self.start: + continue + + if (self.current_frame_idx - self.start) % self.step == 0: + break + + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + return self.current_frame_idx, frame diff --git a/gui/assets/.gitignore b/gui/assets/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..eead191c69dadaf4e3f1b768ee7c8b31269f3963 --- /dev/null +++ b/gui/assets/.gitignore @@ -0,0 +1 @@ +gui_preview_large.png diff --git a/gui/cmake/bin2c_wrapper.cmake b/gui/cmake/bin2c_wrapper.cmake new file mode 100644 index 0000000000000000000000000000000000000000..42c06fdfd8dd32d87e77f8178f0bf019475e0459 --- /dev/null +++ b/gui/cmake/bin2c_wrapper.cmake @@ -0,0 +1,51 @@ +# +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +set(file_contents) +foreach(obj ${OBJECTS}) + get_filename_component(obj_fullname ${obj} NAME) + get_filename_component(obj_ext ${obj} EXT) + get_filename_component(obj_name ${obj} NAME_WE) + get_filename_component(obj_dir ${obj} DIRECTORY) + + STRING(REPLACE "." "_" FILENAME_FIXED ${obj_fullname}) + + if(obj_ext MATCHES ".ptx" OR obj_ext MATCHES ".bin" OR obj_ext MATCHES ".mdl") + set(args --name ${FILENAME_FIXED} ${obj}) + execute_process(COMMAND "${BIN_TO_C_COMMAND}" ${args} + WORKING_DIRECTORY ${obj_dir} + RESULT_VARIABLE result + OUTPUT_VARIABLE output + ERROR_VARIABLE error_var + ) + set(file_contents "${file_contents} \n${output}") + else() + message(WARNING "Unhandled extension in bin2c wrapper: " ${obj_ext}) + endif() +endforeach() +file(WRITE "${OUTPUT}" "${file_contents}") diff --git a/gui/dependencies/cuda-cmake-github-actions/LICENSE b/gui/dependencies/cuda-cmake-github-actions/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a7f2e9fb1c15c09aaa175b77f2af84e254939250 --- /dev/null +++ b/gui/dependencies/cuda-cmake-github-actions/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Peter Heywood + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_ubuntu.sh b/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_ubuntu.sh new file mode 100644 index 0000000000000000000000000000000000000000..610717dca94b1244f0ddf47db9122e080cce80c1 --- /dev/null +++ b/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_ubuntu.sh @@ -0,0 +1,181 @@ +# @todo - better / more robust parsing of inputs from env vars. +## ------------------- +## Constants +## ------------------- + +# @todo - apt repos/known supported versions? + +# @todo - GCC support matrix? + +# List of sub-packages to install. +# @todo - pass this in from outside the script? +# @todo - check the specified subpackages exist via apt pre-install? apt-rdepends cuda-9-0 | grep "^cuda-"? + +# Ideally choose from the list of meta-packages to minimise variance between cuda versions (although it does change too) +CUDA_PACKAGES_IN=( + "command-line-tools" + "libraries-dev" +) + +## ------------------- +## Bash functions +## ------------------- +# returns 0 (true) if a >= b +function version_ge() { + [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 + [ "$(printf '%s\n' "$@" | sort -V | head -n 1)" == "$2" ] +} +# returns 0 (true) if a > b +function version_gt() { + [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 + [ "$1" = "$2" ] && return 1 || version_ge $1 $2 +} +# returns 0 (true) if a <= b +function version_le() { + [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 + [ "$(printf '%s\n' "$@" | sort -V | head -n 1)" == "$1" ] +} +# returns 0 (true) if a < b +function version_lt() { + [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 + [ "$1" = "$2" ] && return 1 || version_le $1 $2 +} + +## ------------------- +## Select CUDA version +## ------------------- + +# Get the cuda version from the environment as $cuda. +CUDA_VERSION_MAJOR_MINOR=${cuda} + +# Split the version. +# We (might/probably) don't know PATCH at this point - it depends which version gets installed. +CUDA_MAJOR=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f1) +CUDA_MINOR=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f2) +CUDA_PATCH=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f3) +# use lsb_release to find the OS. +UBUNTU_VERSION=$(lsb_release -sr) +UBUNTU_VERSION="${UBUNTU_VERSION//.}" + +echo "CUDA_MAJOR: ${CUDA_MAJOR}" +echo "CUDA_MINOR: ${CUDA_MINOR}" +echo "CUDA_PATCH: ${CUDA_PATCH}" +# echo "UBUNTU_NAME: ${UBUNTU_NAME}" +echo "UBUNTU_VERSION: ${UBUNTU_VERSION}" + +# If we don't know the CUDA_MAJOR or MINOR, error. +if [ -z "${CUDA_MAJOR}" ] ; then + echo "Error: Unknown CUDA Major version. Aborting." + exit 1 +fi +if [ -z "${CUDA_MINOR}" ] ; then + echo "Error: Unknown CUDA Minor version. Aborting." + exit 1 +fi +# If we don't know the Ubuntu version, error. +if [ -z ${UBUNTU_VERSION} ]; then + echo "Error: Unknown Ubuntu version. Aborting." + exit 1 +fi + + +## --------------------------- +## GCC studio support check? +## --------------------------- + +# @todo + +## ------------------------------- +## Select CUDA packages to install +## ------------------------------- +CUDA_PACKAGES="" +for package in "${CUDA_PACKAGES_IN[@]}" +do : + # @todo This is not perfect. 
Should probably provide a separate list for diff versions + # cuda-compiler-X-Y if CUDA >= 9.1 else cuda-nvcc-X-Y + if [[ "${package}" == "nvcc" ]] && version_ge "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then + package="compiler" + elif [[ "${package}" == "compiler" ]] && version_lt "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then + package="nvcc" + fi + # Build the full package name and append to the string. + CUDA_PACKAGES+=" cuda-${package}-${CUDA_MAJOR}-${CUDA_MINOR}" +done +echo "CUDA_PACKAGES ${CUDA_PACKAGES}" + +## ----------------- +## Prepare to install +## ----------------- + +PIN_FILENAME="cuda-ubuntu${UBUNTU_VERSION}.pin" +PIN_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/${PIN_FILENAME}" +APT_KEY_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/3bf863cc.pub" +REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/" + +echo "PIN_FILENAME ${PIN_FILENAME}" +echo "PIN_URL ${PIN_URL}" +echo "APT_KEY_URL ${APT_KEY_URL}" + +## ----------------- +## Check for root/sudo +## ----------------- + +# Detect if the script is being run as root, storing true/false in is_root. +is_root=false +if (( $EUID == 0)); then + is_root=true +fi +# Find if sudo is available +has_sudo=false +if command -v sudo &> /dev/null ; then + has_sudo=true +fi +# Decide if we can proceed or not (root or sudo is required) and if so store whether sudo should be used or not. +if [ "$is_root" = false ] && [ "$has_sudo" = false ]; then + echo "Root or sudo is required. Aborting." + exit 1 +elif [ "$is_root" = false ] ; then + USE_SUDO=sudo +else + USE_SUDO= +fi + +## ----------------- +## Install +## ----------------- +echo "Adding CUDA Repository" +wget ${PIN_URL} +$USE_SUDO mv ${PIN_FILENAME} /etc/apt/preferences.d/cuda-repository-pin-600 +$USE_SUDO apt-key adv --fetch-keys ${APT_KEY_URL} +$USE_SUDO add-apt-repository "deb ${REPO_URL} /" +$USE_SUDO apt-get update + +echo "Installing CUDA packages ${CUDA_PACKAGES}" +$USE_SUDO apt-get -y install ${CUDA_PACKAGES} + +if [[ $? -ne 0 ]]; then + echo "CUDA Installation Error." + exit 1 +fi +## ----------------- +## Set environment vars / vars to be propagated +## ----------------- + +CUDA_PATH=/usr/local/cuda-${CUDA_MAJOR}.${CUDA_MINOR} +echo "CUDA_PATH=${CUDA_PATH}" +export CUDA_PATH=${CUDA_PATH} + + +# Quick test. 
@temp +export PATH="$CUDA_PATH/bin:$PATH" +export LD_LIBRARY_PATH="$CUDA_PATH/lib:$LD_LIBRARY_PATH" +nvcc -V + +# If executed on github actions, make the appropriate echo statements to update the environment +if [[ $GITHUB_ACTIONS ]]; then + # Set paths for subsequent steps, using ${CUDA_PATH} + echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH" + echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV + echo "${CUDA_PATH}/bin" >> $GITHUB_PATH + echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> $GITHUB_ENV +fi diff --git a/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1 b/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..32948ddb0f0913dd500abd69818f114ac8f0bc78 --- /dev/null +++ b/gui/dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1 @@ -0,0 +1,171 @@ +## ------------------- +## Constants +## ------------------- + +# Dictionary of known cuda versions and thier download URLS, which do not follow a consistent pattern :( +$CUDA_KNOWN_URLS = @{ + "8.0.44" = "http://developer.nvidia.com/compute/cuda/8.0/Prod/network_installers/cuda_8.0.44_win10_network-exe"; + "8.0.61" = "http://developer.nvidia.com/compute/cuda/8.0/Prod2/network_installers/cuda_8.0.61_win10_network-exe"; + "9.0.176" = "http://developer.nvidia.com/compute/cuda/9.0/Prod/network_installers/cuda_9.0.176_win10_network-exe"; + "9.1.85" = "http://developer.nvidia.com/compute/cuda/9.1/Prod/network_installers/cuda_9.1.85_win10_network"; + "9.2.148" = "http://developer.nvidia.com/compute/cuda/9.2/Prod2/network_installers2/cuda_9.2.148_win10_network"; + "10.0.130" = "http://developer.nvidia.com/compute/cuda/10.0/Prod/network_installers/cuda_10.0.130_win10_network"; + "10.1.105" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.105_win10_network.exe"; + "10.1.168" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.168_win10_network.exe"; + "10.1.243" = "http://developer.download.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.243_win10_network.exe"; + "10.2.89" = "http://developer.download.nvidia.com/compute/cuda/10.2/Prod/network_installers/cuda_10.2.89_win10_network.exe"; + "11.0.1" = "http://developer.download.nvidia.com/compute/cuda/11.0.1/network_installers/cuda_11.0.1_win10_network.exe"; + "11.0.2" = "http://developer.download.nvidia.com/compute/cuda/11.0.2/network_installers/cuda_11.0.2_win10_network.exe"; + "11.0.3" = "http://developer.download.nvidia.com/compute/cuda/11.0.3/network_installers/cuda_11.0.3_win10_network.exe"; + "11.1.0" = "https://developer.download.nvidia.com/compute/cuda/11.1.0/network_installers/cuda_11.1.0_win10_network.exe"; + "11.1.1" = "https://developer.download.nvidia.com/compute/cuda/11.1.1/network_installers/cuda_11.1.1_win10_network.exe"; + "11.2.0" = "https://developer.download.nvidia.com/compute/cuda/11.2.0/network_installers/cuda_11.2.0_win10_network.exe"; + "11.2.1" = "https://developer.download.nvidia.com/compute/cuda/11.2.1/network_installers/cuda_11.2.1_win10_network.exe"; + "11.2.2" = "https://developer.download.nvidia.com/compute/cuda/11.2.2/network_installers/cuda_11.2.2_win10_network.exe"; + "11.3.0" = "https://developer.download.nvidia.com/compute/cuda/11.3.0/network_installers/cuda_11.3.0_win10_network.exe"; + "11.3.1" = "https://developer.download.nvidia.com/compute/cuda/11.3.1/network_installers/cuda_11.3.1_win10_network.exe"; + 
"11.5.0" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe"; + "11.5.1" = "https://developer.download.nvidia.com/compute/cuda/11.5.1/network_installers/cuda_11.5.1_windows_network.exe" +} + +# @todo - change this to be based on _MSC_VER intead, or invert it to be CUDA keyed instead? +$VISUAL_STUDIO_MIN_CUDA = @{ + "2019" = "10.1"; + "2017" = "10.0"; # Depends on which version of 2017! 9.0 to 10.0 depending on version + "2015" = "8.0"; # might support older, unsure. +} + +# cuda_runtime.h is in nvcc <= 10.2, but cudart >= 11.0 +# @todo - make this easier to vary per CUDA version. +$CUDA_PACKAGES_IN = @( + "nvcc"; + "visual_studio_integration"; + "cublas_dev"; + "curand_dev"; + "nvrtc_dev"; + "cudart"; +) + + +## ------------------- +## Select CUDA version +## ------------------- + +# Get the cuda version from the environment as env:cuda. +$CUDA_VERSION_FULL = $env:cuda +# Make sure CUDA_VERSION_FULL is set and valid, otherwise error. + +# Validate CUDA version, extracting components via regex +$cuda_ver_matched = $CUDA_VERSION_FULL -match "^(?[1-9][0-9]*)\.(?[0-9]+)\.(?[0-9]+)$" +if(-not $cuda_ver_matched){ + Write-Output "Invalid CUDA version specified, .. required. '$CUDA_VERSION_FULL'." + exit 1 +} +$CUDA_MAJOR=$Matches.major +$CUDA_MINOR=$Matches.minor +$CUDA_PATCH=$Matches.patch + +## --------------------------- +## Visual studio support check +## --------------------------- +# Exit if visual studio is too new for the cuda version. +$VISUAL_STUDIO = $env:visual_studio.trim() +if ($VISUAL_STUDIO.length -ge 4) { +$VISUAL_STUDIO_YEAR = $VISUAL_STUDIO.Substring($VISUAL_STUDIO.Length-4) + if ($VISUAL_STUDIO_YEAR.length -eq 4 -and $VISUAL_STUDIO_MIN_CUDA.containsKey($VISUAL_STUDIO_YEAR)){ + $MINIMUM_CUDA_VERSION = $VISUAL_STUDIO_MIN_CUDA[$VISUAL_STUDIO_YEAR] + if ([version]$CUDA_VERSION_FULL -lt [version]$MINIMUM_CUDA_VERSION) { + Write-Output "Error: Visual Studio $($VISUAL_STUDIO_YEAR) requires CUDA >= $($MINIMUM_CUDA_VERSION)" + exit 1 + } + } +} else { + Write-Output "Warning: Unknown Visual Studio Version. CUDA version may be insufficient." +} + +## ------------------------------------------------ +## Select CUDA packages to install from environment +## ------------------------------------------------ + +$CUDA_PACKAGES = "" + +# for CUDA >= 11 cudart is a required package. +# if([version]$CUDA_VERSION_FULL -ge [version]"11.0") { +# if(-not $CUDA_PACKAGES_IN -contains "cudart") { +# $CUDA_PACKAGES_IN += 'cudart' +# } +# } + +Foreach ($package in $CUDA_PACKAGES_IN) { + # Make sure the correct package name is used for nvcc. + if($package -eq "nvcc" -and [version]$CUDA_VERSION_FULL -lt [version]"9.1"){ + $package="compiler" + } elseif($package -eq "compiler" -and [version]$CUDA_VERSION_FULL -ge [version]"9.1") { + $package="nvcc" + } + $CUDA_PACKAGES += " $($package)_$($CUDA_MAJOR).$($CUDA_MINOR)" + +} +echo "$($CUDA_PACKAGES)" +## ----------------- +## Prepare download +## ----------------- + +# Select the download link if known, otherwise have a guess. +$CUDA_REPO_PKG_REMOTE="" +if($CUDA_KNOWN_URLS.containsKey($CUDA_VERSION_FULL)){ + $CUDA_REPO_PKG_REMOTE=$CUDA_KNOWN_URLS[$CUDA_VERSION_FULL] +} else{ + # Guess what the url is given the most recent pattern (at the time of writing, 10.1) + Write-Output "note: URL for CUDA ${$CUDA_VERSION_FULL} not known, estimating." 
+ $CUDA_REPO_PKG_REMOTE="http://developer.download.nvidia.com/compute/cuda/$($CUDA_MAJOR).$($CUDA_MINOR)/Prod/network_installers/cuda_$($CUDA_VERSION_FULL)_win10_network.exe" +} +$CUDA_REPO_PKG_LOCAL="cuda_$($CUDA_VERSION_FULL)_win10_network.exe" + + +## ------------ +## Install CUDA +## ------------ + +# Get CUDA network installer +Write-Output "Downloading CUDA Network Installer for $($CUDA_VERSION_FULL) from: $($CUDA_REPO_PKG_REMOTE)" +Invoke-WebRequest $CUDA_REPO_PKG_REMOTE -OutFile $CUDA_REPO_PKG_LOCAL | Out-Null +if(Test-Path -Path $CUDA_REPO_PKG_LOCAL){ + Write-Output "Downloading Complete" +} else { + Write-Output "Error: Failed to download $($CUDA_REPO_PKG_LOCAL) from $($CUDA_REPO_PKG_REMOTE)" + exit 1 +} + +# Invoke silent install of CUDA (via network installer) +Write-Output "Installing CUDA $($CUDA_VERSION_FULL). Subpackages $($CUDA_PACKAGES)" +Start-Process -Wait -FilePath .\"$($CUDA_REPO_PKG_LOCAL)" -ArgumentList "-s $($CUDA_PACKAGES)" + +# Check the return status of the CUDA installer. +if (!$?) { + Write-Output "Error: CUDA installer reported error. $($LASTEXITCODE)" + exit 1 +} + +# Store the CUDA_PATH in the environment for the current session, to be forwarded in the action. +$CUDA_PATH = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$($CUDA_MAJOR).$($CUDA_MINOR)" +$CUDA_PATH_VX_Y = "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" +# Set environmental variables in this session +$env:CUDA_PATH = "$($CUDA_PATH)" +$env:CUDA_PATH_VX_Y = "$($CUDA_PATH_VX_Y)" +Write-Output "CUDA_PATH $($CUDA_PATH)" +Write-Output "CUDA_PATH_VX_Y $($CUDA_PATH_VX_Y)" + +# PATH needs updating elsewhere, anything in here won't persist. +# Append $CUDA_PATH/bin to path. +# Set CUDA_PATH as an environmental variable + + +# If executing on github actions, emit the appropriate echo statements to update environment variables +if (Test-Path "env:GITHUB_ACTIONS") { + # Set paths for subsequent steps, using $env:CUDA_PATH + echo "Adding CUDA to CUDA_PATH, CUDA_PATH_X_Y and PATH" + echo "CUDA_PATH=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "$env:CUDA_PATH_VX_Y=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "$env:CUDA_PATH/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append +} diff --git a/gui/dependencies/filesystem/LICENSE b/gui/dependencies/filesystem/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..057f600707aba168685b71abf9b798277526a9f3 --- /dev/null +++ b/gui/dependencies/filesystem/LICENSE @@ -0,0 +1,39 @@ +Copyright (c) 2016 Wenzel Jakob , + 2021-2023 Thomas Müller + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +("Enhancements") to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to the author of this software, without +imposing a separate written license agreement for such Enhancements, then you +hereby grant the following license: a non-exclusive, royalty-free perpetual +license to install, use, modify, prepare derivative works, incorporate into +other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. diff --git a/gui/dependencies/filesystem/README.md b/gui/dependencies/filesystem/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2de29e6d2b6b839617441d6da1fed0ccac7f63ad --- /dev/null +++ b/gui/dependencies/filesystem/README.md @@ -0,0 +1,5 @@ +#### filesystem/path.h: A simple class for manipulating paths on Linux/Windows/Mac OS + +This class is just a temporary workaround to avoid the heavy boost dependency +until `boost::filesystem` is integrated into the standard template library at +some point in the future. diff --git a/gui/dependencies/filesystem/filesystem/directory.h b/gui/dependencies/filesystem/filesystem/directory.h new file mode 100644 index 0000000000000000000000000000000000000000..f20d85bebb49f1606210f67d69649c02a34df6a1 --- /dev/null +++ b/gui/dependencies/filesystem/filesystem/directory.h @@ -0,0 +1,151 @@ +/* + * reference https://docs.microsoft.com/zh-cn/cpp/c-runtime-library/reference/findfirst-functions?f1url=https%3A%2F%2Fmsdn.microsoft.com%2Fquery%2Fdev15.query%3FappId%3DDev15IDEF1%26l%3DZH-CN%26k%3Dk(CORECRT_IO%2F_findfirst);k(_findfirst);k(DevLang-C%2B%2B);k(TargetOS-Windows)%26rd%3Dtrue&view=vs-2019 + * reference http://www.man7.org/linux/man-pages/man3/opendir.3.html + * + * Copyright (c) 2019 tangm421 + * + * All rights reserved. Use of this source code is governed by a + * BSD-style license that can be found in the LICENSE file. 
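+ *
+ * Declares a small `directory` class (deriving from `path`) whose nested
+ * iterator walks directory entries, using _wfindfirst/_wfindnext on Windows
+ * and opendir/readdir on POSIX systems.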
+ */ + +#pragma once + +#include "path.h" + +#if defined(_WIN32) +#include +#else +#include +#include +#endif + +NAMESPACE_BEGIN(filesystem) + +class directory : public path +{ +public: + directory(const path& dir) : path(dir), m_dir(dir) {} + + class iterator + { + public: + iterator() /* default ctor indicates the end iterator */ +#if defined(_WIN32) + : m_handle(-1) {} +#else + : m_handle(NULL), m_data(NULL) {} +#endif + + iterator(const directory& dir) { + m_dir = dir; +#if defined(_WIN32) + std::wstring search_path(dir.make_absolute().wstr() + L"/*.*"); + m_handle = _wfindfirst(search_path.c_str(), &m_data); + if (is_valid_handler()) + { + m_result = m_dir / m_data.name; + } + else /* an error occurs or reaching the end */ + { + /* do nothing */ + } +#else + m_handle = opendir(dir.make_absolute().str().c_str()); + ++*this; /* here we do find the first directory entry like the begin iterator does */ +#endif + } + ~iterator() { + if (is_valid_handler()) + { +#if defined(_WIN32) + _findclose(m_handle); + m_handle = -1; +#else + closedir(m_handle); + m_handle = NULL; + m_data = NULL; +#endif + } + } + + iterator& operator++() { + if (is_valid_handler()) + { +#if defined(_WIN32) + if (_wfindnext(m_handle, &m_data)) + { + if (ENOENT == errno) /* reaching the end */ + { + m_result = path(); + } + else /* an error occurs */ + { + /* do nothing because the next call of this function will not do anything */ + } + } + else + { + m_result = m_dir / m_data.name; + } +#else + errno = 0; + m_data = readdir(m_handle); + if (0 != errno) /* an error occurs */ + { + /* do nothing because the next call of this function will not do anything */ + } + if (!m_data) /* reaching the end */ + { + m_result = path(); + } + else + { + m_result = m_dir / m_data->d_name; + } +#endif + } + return *this; + } + bool operator!=(const iterator& rhs) const { + return !(*this == rhs); + } + bool operator==(const iterator& rhs) const { + return **this == *rhs; + } + const path& operator*() const { + return m_result; + } + const path* operator->() const { + return &m_result; + } + + protected: + bool is_valid_handler() const { +#if defined(_WIN32) + return -1 != m_handle; +#else + return NULL != m_handle; +#endif + } + + private: + path m_dir; + path m_result; +#if defined(_WIN32) + intptr_t m_handle; + _wfinddata_t m_data; +#else + DIR* m_handle; + struct dirent* m_data; +#endif + }; + + iterator begin() const { return iterator(*this); } + iterator end() const { static iterator static_end; return static_end; }; + +private: + path m_dir; +}; + + +NAMESPACE_END(filesystem) diff --git a/gui/dependencies/filesystem/filesystem/fwd.h b/gui/dependencies/filesystem/filesystem/fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..3552199e2e60ec437ac0bb308888785e277e38bd --- /dev/null +++ b/gui/dependencies/filesystem/filesystem/fwd.h @@ -0,0 +1,24 @@ +/* + fwd.h -- Forward declarations for path.h and resolver.h + + Copyright (c) 2015 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/ + +#pragma once + +#if !defined(NAMESPACE_BEGIN) +#define NAMESPACE_BEGIN(name) namespace name { +#endif +#if !defined(NAMESPACE_END) +#define NAMESPACE_END(name) } +#endif + +NAMESPACE_BEGIN(filesystem) + +class path; +class resolver; + +NAMESPACE_END(filesystem) diff --git a/gui/dependencies/filesystem/filesystem/path.h b/gui/dependencies/filesystem/filesystem/path.h new file mode 100644 index 0000000000000000000000000000000000000000..3094db6a92d0d5756a7160ed8fce1a1df2b07557 --- /dev/null +++ b/gui/dependencies/filesystem/filesystem/path.h @@ -0,0 +1,458 @@ +/* + path.h -- A simple class for manipulating paths on Linux/Windows/Mac OS + + Copyright (c) 2015 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "fwd.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +# include +# include +#else +# include +#endif +#include + +#if defined(__linux) +# include +#endif + +NAMESPACE_BEGIN(filesystem) + +/** + * \brief Simple class for manipulating paths on Linux/Windows/Mac OS + * + * This class is just a temporary workaround to avoid the heavy boost + * dependency until boost::filesystem is integrated into the standard template + * library at some point in the future. + */ +class path { +public: + enum path_type { + windows_path = 0, + posix_path = 1, +#if defined(_WIN32) + native_path = windows_path +#else + native_path = posix_path +#endif + }; + + path() : m_type(native_path), m_absolute(false), m_smb(false) { } + + path(const path &path) + : m_type(path.m_type), m_path(path.m_path), m_absolute(path.m_absolute), m_smb(path.m_smb) {} + + path(path &&path) + : m_type(path.m_type), m_path(std::move(path.m_path)), + m_absolute(path.m_absolute), m_smb(path.m_smb) {} + + path(const char *string) : m_smb(false) { set(string); } + + path(const std::string &string) : m_smb(false) { set(string); } + +#if defined(_WIN32) + path(const std::wstring &wstring) { set(wstring); } + path(const wchar_t *wstring) { set(wstring); } +#endif + + size_t length() const { return m_path.size(); } + + bool empty() const { return m_path.empty(); } + + bool is_absolute() const { return m_absolute; } + + path make_absolute() const { +#if !defined(_WIN32) + char temp[PATH_MAX]; + if (realpath(str().c_str(), temp) == NULL) + throw std::runtime_error("Internal error in realpath(): " + std::string(strerror(errno))); + return path(temp); +#else + std::wstring value = wstr(), out(MAX_PATH_WINDOWS, '\0'); + DWORD length = GetFullPathNameW(value.c_str(), MAX_PATH_WINDOWS, &out[0], NULL); + if (length == 0) + throw std::runtime_error("Internal error in realpath(): " + std::to_string(GetLastError())); + return path(out.substr(0, length)); +#endif + } + + bool exists() const { +#if defined(_WIN32) + return GetFileAttributesW(wstr().c_str()) != INVALID_FILE_ATTRIBUTES; +#else + struct stat sb; + return stat(str().c_str(), &sb) == 0; +#endif + } + + size_t file_size() const { +#if defined(_WIN32) + struct _stati64 sb; + if (_wstati64(wstr().c_str(), &sb) != 0) + throw std::runtime_error("path::file_size(): cannot stat file \"" + str() + "\"!"); +#else + struct stat sb; + if (stat(str().c_str(), &sb) != 0) + throw std::runtime_error("path::file_size(): cannot stat file \"" + str() + "\"!"); +#endif + return (size_t) sb.st_size; + } + + bool is_directory() const { +#if defined(_WIN32) + DWORD result = GetFileAttributesW(wstr().c_str()); + if (result == 
INVALID_FILE_ATTRIBUTES) + return false; + return (result & FILE_ATTRIBUTE_DIRECTORY) != 0; +#else + struct stat sb; + if (stat(str().c_str(), &sb)) + return false; + return S_ISDIR(sb.st_mode); +#endif + } + + bool is_file() const { +#if defined(_WIN32) + DWORD attr = GetFileAttributesW(wstr().c_str()); + return (attr != INVALID_FILE_ATTRIBUTES && (attr & FILE_ATTRIBUTE_DIRECTORY) == 0); +#else + struct stat sb; + if (stat(str().c_str(), &sb)) + return false; + return S_ISREG(sb.st_mode); +#endif + } + + std::string filename() const { + if (empty()) + return ""; + const std::string &last = m_path[m_path.size()-1]; + return last; + } + + std::string extension() const { + const std::string &name = filename(); + size_t pos = name.find_last_of("."); + if (pos == std::string::npos) + return ""; + return name.substr(pos+1); + } + + std::string basename() const { + const std::string &name = filename(); + size_t pos = name.find_last_of("."); + if (pos == std::string::npos) + return name; + return name.substr(0, pos); + } + + path parent_path() const { + path result; + result.m_absolute = m_absolute; + result.m_smb = m_smb; + + if (m_path.empty()) { + if (!m_absolute) + result.m_path.push_back(".."); + } else { + size_t until = m_path.size() - 1; + for (size_t i = 0; i < until; ++i) + result.m_path.push_back(m_path[i]); + } + return result; + } + + path stem() const { + return parent_path()/basename(); + } + + path with_extension(const std::string& ext) const { + return parent_path()/(basename()+"."+ext); + } + + path operator/(const path &other) const { + if (other.m_absolute) + throw std::runtime_error("path::operator/(): expected a relative path!"); + if (m_type != other.m_type) + throw std::runtime_error("path::operator/(): expected a path of the same type!"); + + path result(*this); + + for (size_t i=0; iMAX_PATH are + // not supported at all in Windows. + if (length > MAX_PATH_WINDOWS_LEGACY) { + if (m_smb) + oss << "\\\\?\\UNC\\"; + else + oss << "\\\\?\\"; + } else if (m_smb) + oss << "\\\\"; + } + } + + for (size_t i=0; iMAX_PATH characters long, so we remove it + // for convenience and add it back (if necessary) in str()/wstr(). + static const std::string LONG_PATH_PREFIX = "\\\\?\\"; + if (tmp.length() >= LONG_PATH_PREFIX.length() + && std::mismatch(std::begin(LONG_PATH_PREFIX), std::end(LONG_PATH_PREFIX), std::begin(tmp)).first == std::end(LONG_PATH_PREFIX)) { + tmp.erase(0, LONG_PATH_PREFIX.length()); + } + + // Special-case handling of absolute SMB paths, which start with the prefix "\\". + if (tmp.length() >= 2 && tmp[0] == '\\' && tmp[1] == '\\') { + m_path = {}; + tmp.erase(0, 2); + + // Interestingly, there is a special-special case where relative paths may be specified as beginning with a "\\" + // when a non-SMB file with a more-than-260-characters-long absolute _local_ path is double-clicked. This seems to + // only happen with single-segment relative paths, so we can check for this condition by making sure no further + // path separators are present. 
+ if (tmp.find_first_of("/\\") != std::string::npos) + m_absolute = m_smb = true; + else + m_absolute = m_smb = false; + + // Special-case handling of absolute SMB paths, which start with the prefix "UNC\" + } else if (tmp.length() >= 4 && tmp[0] == 'U' && tmp[1] == 'N' && tmp[2] == 'C' && tmp[3] == '\\') { + m_path = {}; + tmp.erase(0, 4); + m_absolute = true; + m_smb = true; + // Special-case handling of absolute local paths, which start with the drive letter and a colon "X:\" + // So that UTF-8 works, do not call std::isalpha if the high bit is set, as that causes an assert on Windows. + } else if (tmp.length() >= 3 && ((unsigned char)tmp[0] < 0x80) && std::isalpha(tmp[0]) && + tmp[1] == ':' && (tmp[2] == '\\' || tmp[2] == '/')) { + m_path = {tmp.substr(0, 2)}; + tmp.erase(0, 3); + m_absolute = true; + m_smb = false; + // Relative path + } else { + m_path = {}; + m_absolute = false; + m_smb = false; + } + + std::vector tokenized = tokenize(tmp, "/\\"); + m_path.insert(std::end(m_path), std::begin(tokenized), std::end(tokenized)); + } else { + m_path = tokenize(str, "/"); + m_absolute = !str.empty() && str[0] == '/'; + } + + m_path.erase(std::remove(std::begin(m_path), std::end(m_path), ""), std::end(m_path)); + } + + path &operator=(const path &path) { + m_type = path.m_type; + m_path = path.m_path; + m_absolute = path.m_absolute; + m_smb = path.m_smb; + return *this; + } + + path &operator=(path &&path) { + if (this != &path) { + m_type = path.m_type; + m_path = std::move(path.m_path); + m_absolute = path.m_absolute; + m_smb = path.m_smb; + } + return *this; + } + + friend std::ostream &operator<<(std::ostream &os, const path &path) { + os << path.str(); + return os; + } + + bool remove_file() const { +#if !defined(_WIN32) + return std::remove(str().c_str()) == 0; +#else + if (is_directory()) { + return RemoveDirectoryW(wstr().c_str()) != 0; + } else { + return DeleteFileW(wstr().c_str()) != 0; + } +#endif + } + + bool resize_file(size_t target_length) { +#if !defined(_WIN32) + return ::truncate(str().c_str(), (off_t) target_length) == 0; +#else + HANDLE handle = CreateFileW(wstr().c_str(), GENERIC_WRITE, 0, nullptr, 0, FILE_ATTRIBUTE_NORMAL, nullptr); + if (handle == INVALID_HANDLE_VALUE) + return false; + LARGE_INTEGER size; + size.QuadPart = (LONGLONG) target_length; + if (SetFilePointerEx(handle, size, NULL, FILE_BEGIN) == 0) { + CloseHandle(handle); + return false; + } + if (SetEndOfFile(handle) == 0) { + CloseHandle(handle); + return false; + } + CloseHandle(handle); + return true; +#endif + } + + static path getcwd() { +#if !defined(_WIN32) + char temp[PATH_MAX]; + if (::getcwd(temp, PATH_MAX) == NULL) + throw std::runtime_error("Internal error in getcwd(): " + std::string(strerror(errno))); + return path(temp); +#else + std::wstring temp(MAX_PATH_WINDOWS, '\0'); + if (!_wgetcwd(&temp[0], MAX_PATH_WINDOWS)) + throw std::runtime_error("Internal error in getcwd(): " + std::to_string(GetLastError())); + return path(temp.c_str()); +#endif + } + +#if defined(_WIN32) + std::wstring wstr(path_type type = native_path) const { + std::string temp = str(type); + int size = MultiByteToWideChar(CP_UTF8, 0, &temp[0], (int)temp.size(), NULL, 0); + std::wstring result(size, 0); + MultiByteToWideChar(CP_UTF8, 0, &temp[0], (int)temp.size(), &result[0], size); + return result; + } + + + void set(const std::wstring &wstring, path_type type = native_path) { + std::string string; + if (!wstring.empty()) { + int size = WideCharToMultiByte(CP_UTF8, 0, &wstring[0], (int)wstring.size(), + NULL, 0, NULL, 
NULL); + string.resize(size, 0); + WideCharToMultiByte(CP_UTF8, 0, &wstring[0], (int)wstring.size(), + &string[0], size, NULL, NULL); + } + set(string, type); + } + + path &operator=(const std::wstring &str) { set(str); return *this; } +#endif + + bool operator==(const path &p) const { return p.m_path == m_path; } + bool operator!=(const path &p) const { return p.m_path != m_path; } + +protected: + static std::vector tokenize(const std::string &string, const std::string &delim) { + std::string::size_type lastPos = 0, pos = string.find_first_of(delim, lastPos); + std::vector tokens; + + while (lastPos != std::string::npos) { + if (pos != lastPos) + tokens.push_back(string.substr(lastPos, pos - lastPos)); + lastPos = pos; + if (lastPos == std::string::npos || lastPos + 1 == string.length()) + break; + pos = string.find_first_of(delim, ++lastPos); + } + + return tokens; + } + +protected: +#if defined(_WIN32) + static const size_t MAX_PATH_WINDOWS = 32767; +#endif + static const size_t MAX_PATH_WINDOWS_LEGACY = 260; + path_type m_type; + std::vector m_path; + bool m_absolute; + bool m_smb; // Unused, except for on Windows +}; + +inline bool create_directory(const path& p) { +#if defined(_WIN32) + return CreateDirectoryW(p.wstr().c_str(), NULL) != 0; +#else + return mkdir(p.str().c_str(), S_IRWXU) == 0; +#endif +} + +inline bool create_directories(const path& p) { +#if defined(_WIN32) + return SHCreateDirectory(nullptr, p.make_absolute().wstr().c_str()) == ERROR_SUCCESS; +#else + if (create_directory(p.str().c_str())) + return true; + + if (p.empty()) + return false; + + if (errno == ENOENT) { + if (create_directory(p.parent_path())) + return mkdir(p.str().c_str(), S_IRWXU) == 0; + else + return false; + } + return false; +#endif +} + +NAMESPACE_END(filesystem) diff --git a/gui/dependencies/filesystem/filesystem/resolver.h b/gui/dependencies/filesystem/filesystem/resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..0c7576841d5ee0934967ebb1d6dd0c312461da46 --- /dev/null +++ b/gui/dependencies/filesystem/filesystem/resolver.h @@ -0,0 +1,72 @@ +/* + resolver.h -- A simple class for cross-platform path resolution + + Copyright (c) 2015 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "path.h" + +NAMESPACE_BEGIN(filesystem) + +/** + * \brief Simple class for resolving paths on Linux/Windows/Mac OS + * + * This convenience class looks for a file or directory given its name + * and a set of search paths. The implementation walks through the + * search paths in order and stops once the file is found. 
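+ * If none of the search paths contains the file, resolve() returns the
+ * input path unchanged.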
+ */ +class resolver { +public: + typedef std::vector::iterator iterator; + typedef std::vector::const_iterator const_iterator; + + resolver() { + m_paths.push_back(path::getcwd()); + } + + size_t size() const { return m_paths.size(); } + + iterator begin() { return m_paths.begin(); } + iterator end() { return m_paths.end(); } + + const_iterator begin() const { return m_paths.begin(); } + const_iterator end() const { return m_paths.end(); } + + void erase(iterator it) { m_paths.erase(it); } + + void prepend(const path &path) { m_paths.insert(m_paths.begin(), path); } + void append(const path &path) { m_paths.push_back(path); } + const path &operator[](size_t index) const { return m_paths[index]; } + path &operator[](size_t index) { return m_paths[index]; } + + path resolve(const path &value) const { + for (const_iterator it = m_paths.begin(); it != m_paths.end(); ++it) { + path combined = *it / value; + if (combined.exists()) + return combined; + } + return value; + } + + friend std::ostream &operator<<(std::ostream &os, const resolver &r) { + os << "resolver[" << std::endl; + for (size_t i = 0; i < r.m_paths.size(); ++i) { + os << " \"" << r.m_paths[i] << "\""; + if (i + 1 < r.m_paths.size()) + os << ","; + os << std::endl; + } + os << "]"; + return os; + } + +private: + std::vector m_paths; +}; + +NAMESPACE_END(filesystem) diff --git a/gui/dependencies/filesystem/path_demo.cpp b/gui/dependencies/filesystem/path_demo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..266c661ad54f8a78947edcf5e6e7a377ef39b2c6 --- /dev/null +++ b/gui/dependencies/filesystem/path_demo.cpp @@ -0,0 +1,48 @@ +#include +#include "filesystem/path.h" +#include "filesystem/resolver.h" + +using namespace std; +using namespace filesystem; + +int main(int argc, char **argv) { +#if !defined(WIN32) + path path1("/dir 1/dir 2/"); +#else + path path1("C:\\dir 1\\dir 2\\"); +#endif + path path2("dir 3"); + + cout << path1.exists() << endl; + cout << path1 << endl; + cout << (path1/path2) << endl; + cout << (path1/path2).parent_path() << endl; + cout << (path1/path2).parent_path().parent_path() << endl; + cout << (path1/path2).parent_path().parent_path().parent_path() << endl; + cout << (path1/path2).parent_path().parent_path().parent_path().parent_path() << endl; + cout << path().parent_path() << endl; + cout << "some/path.ext:operator==() = " << (path("some/path.ext") == path("some/path.ext")) << endl; + cout << "some/path.ext:operator==() (unequal) = " << (path("some/path.ext") == path("another/path.ext")) << endl; + + cout << "nonexistant:exists = " << path("nonexistant").exists() << endl; + cout << "nonexistant:is_file = " << path("nonexistant").is_file() << endl; + cout << "nonexistant:is_directory = " << path("nonexistant").is_directory() << endl; + cout << "nonexistant:filename = " << path("nonexistant").filename() << endl; + cout << "nonexistant:extension = " << path("nonexistant").extension() << endl; + cout << "filesystem/path.h:exists = " << path("filesystem/path.h").exists() << endl; + cout << "filesystem/path.h:is_file = " << path("filesystem/path.h").is_file() << endl; + cout << "filesystem/path.h:is_directory = " << path("filesystem/path.h").is_directory() << endl; + cout << "filesystem/path.h:filename = " << path("filesystem/path.h").filename() << endl; + cout << "filesystem/path.h:extension = " << path("filesystem/path.h").extension() << endl; + cout << "filesystem/path.h:make_absolute = " << path("filesystem/path.h").make_absolute() << endl; + cout << "../filesystem:exists = " << 
path("../filesystem").exists() << endl; + cout << "../filesystem:is_file = " << path("../filesystem").is_file() << endl; + cout << "../filesystem:is_directory = " << path("../filesystem").is_directory() << endl; + cout << "../filesystem:extension = " << path("../filesystem").extension() << endl; + cout << "../filesystem:filename = " << path("../filesystem").filename() << endl; + cout << "../filesystem:make_absolute = " << path("../filesystem").make_absolute() << endl; + + cout << "resolve(filesystem/path.h) = " << resolver().resolve("filesystem/path.h") << endl; + cout << "resolve(nonexistant) = " << resolver().resolve("nonexistant") << endl; + return 0; +} diff --git a/gui/dependencies/gl3w/GL/gl3w.c b/gui/dependencies/gl3w/GL/gl3w.c new file mode 100644 index 0000000000000000000000000000000000000000..464e017788bd0e79ef705cb2bf95777826901fea --- /dev/null +++ b/gui/dependencies/gl3w/GL/gl3w.c @@ -0,0 +1,1344 @@ +#include + +#ifdef _MSC_VER +#pragma warning (disable: 4055) // warning C4055: 'type cast' : from data pointer 'void *' to function pointer +#pragma warning (disable: 4152) // warning C4152: nonstandard extension, function/data pointer conversion in expression +#endif + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN 1 +#include + +static HMODULE libgl; + +static void open_libgl(void) +{ + libgl = LoadLibraryA("opengl32.dll"); +} + +static void close_libgl(void) +{ + FreeLibrary(libgl); +} + +static void *get_proc(const char *proc) +{ + void *res; + + res = wglGetProcAddress(proc); + if (!res) + res = GetProcAddress(libgl, proc); + return res; +} +#elif defined(__APPLE__) || defined(__APPLE_CC__) +#include + +CFBundleRef bundle; +CFURLRef bundleURL; + +static void open_libgl(void) +{ + bundleURL = CFURLCreateWithFileSystemPath(kCFAllocatorDefault, + CFSTR("/System/Library/Frameworks/OpenGL.framework"), + kCFURLPOSIXPathStyle, true); + + bundle = CFBundleCreate(kCFAllocatorDefault, bundleURL); + assert(bundle != NULL); +} + +static void close_libgl(void) +{ + CFRelease(bundle); + CFRelease(bundleURL); +} + +static void *get_proc(const char *proc) +{ + void *res; + + CFStringRef procname = CFStringCreateWithCString(kCFAllocatorDefault, proc, + kCFStringEncodingASCII); + res = CFBundleGetFunctionPointerForName(bundle, procname); + CFRelease(procname); + return res; +} +#else +#include +#include + +static void *libgl; + +static void open_libgl(void) +{ + libgl = dlopen("libGL.so.1", RTLD_LAZY | RTLD_GLOBAL); +} + +static void close_libgl(void) +{ + dlclose(libgl); +} + +static void *get_proc(const char *proc) +{ + void *res; + + res = (void*)glXGetProcAddress((const GLubyte *) proc); + if (!res) + res = dlsym(libgl, proc); + return res; +} +#endif + +static struct { + int major, minor; +} version; + +static int parse_version(void) +{ + if (!glGetIntegerv) + return -1; + + glGetIntegerv(GL_MAJOR_VERSION, &version.major); + glGetIntegerv(GL_MINOR_VERSION, &version.minor); + + if (version.major < 3) + return -1; + return 0; +} + +static void load_procs(void); + +int gl3wInit(void) +{ + open_libgl(); + load_procs(); + close_libgl(); + return parse_version(); +} + +int gl3wIsSupported(int major, int minor) +{ + if (major < 3) + return 0; + if (version.major == major) + return version.minor >= minor; + return version.major >= major; +} + +void *gl3wGetProcAddress(const char *proc) +{ + return get_proc(proc); +} + +PFNGLCULLFACEPROC gl3wCullFace; +PFNGLFRONTFACEPROC gl3wFrontFace; +PFNGLHINTPROC gl3wHint; +PFNGLLINEWIDTHPROC gl3wLineWidth; +PFNGLPOINTSIZEPROC gl3wPointSize; +PFNGLPOLYGONMODEPROC 
gl3wPolygonMode; +PFNGLSCISSORPROC gl3wScissor; +PFNGLTEXPARAMETERFPROC gl3wTexParameterf; +PFNGLTEXPARAMETERFVPROC gl3wTexParameterfv; +PFNGLTEXPARAMETERIPROC gl3wTexParameteri; +PFNGLTEXPARAMETERIVPROC gl3wTexParameteriv; +PFNGLTEXIMAGE1DPROC gl3wTexImage1D; +PFNGLTEXIMAGE2DPROC gl3wTexImage2D; +PFNGLDRAWBUFFERPROC gl3wDrawBuffer; +PFNGLCLEARPROC gl3wClear; +PFNGLCLEARCOLORPROC gl3wClearColor; +PFNGLCLEARSTENCILPROC gl3wClearStencil; +PFNGLCLEARDEPTHPROC gl3wClearDepth; +PFNGLSTENCILMASKPROC gl3wStencilMask; +PFNGLCOLORMASKPROC gl3wColorMask; +PFNGLDEPTHMASKPROC gl3wDepthMask; +PFNGLDISABLEPROC gl3wDisable; +PFNGLENABLEPROC gl3wEnable; +PFNGLFINISHPROC gl3wFinish; +PFNGLFLUSHPROC gl3wFlush; +PFNGLBLENDFUNCPROC gl3wBlendFunc; +PFNGLLOGICOPPROC gl3wLogicOp; +PFNGLSTENCILFUNCPROC gl3wStencilFunc; +PFNGLSTENCILOPPROC gl3wStencilOp; +PFNGLDEPTHFUNCPROC gl3wDepthFunc; +PFNGLPIXELSTOREFPROC gl3wPixelStoref; +PFNGLPIXELSTOREIPROC gl3wPixelStorei; +PFNGLREADBUFFERPROC gl3wReadBuffer; +PFNGLREADPIXELSPROC gl3wReadPixels; +PFNGLGETBOOLEANVPROC gl3wGetBooleanv; +PFNGLGETDOUBLEVPROC gl3wGetDoublev; +PFNGLGETERRORPROC gl3wGetError; +PFNGLGETFLOATVPROC gl3wGetFloatv; +PFNGLGETINTEGERVPROC gl3wGetIntegerv; +PFNGLGETSTRINGPROC gl3wGetString; +PFNGLGETTEXIMAGEPROC gl3wGetTexImage; +PFNGLGETTEXPARAMETERFVPROC gl3wGetTexParameterfv; +PFNGLGETTEXPARAMETERIVPROC gl3wGetTexParameteriv; +PFNGLGETTEXLEVELPARAMETERFVPROC gl3wGetTexLevelParameterfv; +PFNGLGETTEXLEVELPARAMETERIVPROC gl3wGetTexLevelParameteriv; +PFNGLISENABLEDPROC gl3wIsEnabled; +PFNGLDEPTHRANGEPROC gl3wDepthRange; +PFNGLVIEWPORTPROC gl3wViewport; +PFNGLDRAWARRAYSPROC gl3wDrawArrays; +PFNGLDRAWELEMENTSPROC gl3wDrawElements; +PFNGLGETPOINTERVPROC gl3wGetPointerv; +PFNGLPOLYGONOFFSETPROC gl3wPolygonOffset; +PFNGLCOPYTEXIMAGE1DPROC gl3wCopyTexImage1D; +PFNGLCOPYTEXIMAGE2DPROC gl3wCopyTexImage2D; +PFNGLCOPYTEXSUBIMAGE1DPROC gl3wCopyTexSubImage1D; +PFNGLCOPYTEXSUBIMAGE2DPROC gl3wCopyTexSubImage2D; +PFNGLTEXSUBIMAGE1DPROC gl3wTexSubImage1D; +PFNGLTEXSUBIMAGE2DPROC gl3wTexSubImage2D; +PFNGLBINDTEXTUREPROC gl3wBindTexture; +PFNGLDELETETEXTURESPROC gl3wDeleteTextures; +PFNGLGENTEXTURESPROC gl3wGenTextures; +PFNGLISTEXTUREPROC gl3wIsTexture; +PFNGLBLENDCOLORPROC gl3wBlendColor; +PFNGLBLENDEQUATIONPROC gl3wBlendEquation; +PFNGLDRAWRANGEELEMENTSPROC gl3wDrawRangeElements; +PFNGLTEXIMAGE3DPROC gl3wTexImage3D; +PFNGLTEXSUBIMAGE3DPROC gl3wTexSubImage3D; +PFNGLCOPYTEXSUBIMAGE3DPROC gl3wCopyTexSubImage3D; +PFNGLACTIVETEXTUREPROC gl3wActiveTexture; +PFNGLSAMPLECOVERAGEPROC gl3wSampleCoverage; +PFNGLCOMPRESSEDTEXIMAGE3DPROC gl3wCompressedTexImage3D; +PFNGLCOMPRESSEDTEXIMAGE2DPROC gl3wCompressedTexImage2D; +PFNGLCOMPRESSEDTEXIMAGE1DPROC gl3wCompressedTexImage1D; +PFNGLCOMPRESSEDTEXSUBIMAGE3DPROC gl3wCompressedTexSubImage3D; +PFNGLCOMPRESSEDTEXSUBIMAGE2DPROC gl3wCompressedTexSubImage2D; +PFNGLCOMPRESSEDTEXSUBIMAGE1DPROC gl3wCompressedTexSubImage1D; +PFNGLGETCOMPRESSEDTEXIMAGEPROC gl3wGetCompressedTexImage; +PFNGLBLENDFUNCSEPARATEPROC gl3wBlendFuncSeparate; +PFNGLMULTIDRAWARRAYSPROC gl3wMultiDrawArrays; +PFNGLMULTIDRAWELEMENTSPROC gl3wMultiDrawElements; +PFNGLPOINTPARAMETERFPROC gl3wPointParameterf; +PFNGLPOINTPARAMETERFVPROC gl3wPointParameterfv; +PFNGLPOINTPARAMETERIPROC gl3wPointParameteri; +PFNGLPOINTPARAMETERIVPROC gl3wPointParameteriv; +PFNGLGENQUERIESPROC gl3wGenQueries; +PFNGLDELETEQUERIESPROC gl3wDeleteQueries; +PFNGLISQUERYPROC gl3wIsQuery; +PFNGLBEGINQUERYPROC gl3wBeginQuery; +PFNGLENDQUERYPROC gl3wEndQuery; +PFNGLGETQUERYIVPROC gl3wGetQueryiv; 
+PFNGLGETQUERYOBJECTIVPROC gl3wGetQueryObjectiv; +PFNGLGETQUERYOBJECTUIVPROC gl3wGetQueryObjectuiv; +PFNGLBINDBUFFERPROC gl3wBindBuffer; +PFNGLDELETEBUFFERSPROC gl3wDeleteBuffers; +PFNGLGENBUFFERSPROC gl3wGenBuffers; +PFNGLISBUFFERPROC gl3wIsBuffer; +PFNGLBUFFERDATAPROC gl3wBufferData; +PFNGLBUFFERSUBDATAPROC gl3wBufferSubData; +PFNGLGETBUFFERSUBDATAPROC gl3wGetBufferSubData; +PFNGLMAPBUFFERPROC gl3wMapBuffer; +PFNGLUNMAPBUFFERPROC gl3wUnmapBuffer; +PFNGLGETBUFFERPARAMETERIVPROC gl3wGetBufferParameteriv; +PFNGLGETBUFFERPOINTERVPROC gl3wGetBufferPointerv; +PFNGLBLENDEQUATIONSEPARATEPROC gl3wBlendEquationSeparate; +PFNGLDRAWBUFFERSPROC gl3wDrawBuffers; +PFNGLSTENCILOPSEPARATEPROC gl3wStencilOpSeparate; +PFNGLSTENCILFUNCSEPARATEPROC gl3wStencilFuncSeparate; +PFNGLSTENCILMASKSEPARATEPROC gl3wStencilMaskSeparate; +PFNGLATTACHSHADERPROC gl3wAttachShader; +PFNGLBINDATTRIBLOCATIONPROC gl3wBindAttribLocation; +PFNGLCOMPILESHADERPROC gl3wCompileShader; +PFNGLCREATEPROGRAMPROC gl3wCreateProgram; +PFNGLCREATESHADERPROC gl3wCreateShader; +PFNGLDELETEPROGRAMPROC gl3wDeleteProgram; +PFNGLDELETESHADERPROC gl3wDeleteShader; +PFNGLDETACHSHADERPROC gl3wDetachShader; +PFNGLDISABLEVERTEXATTRIBARRAYPROC gl3wDisableVertexAttribArray; +PFNGLENABLEVERTEXATTRIBARRAYPROC gl3wEnableVertexAttribArray; +PFNGLGETACTIVEATTRIBPROC gl3wGetActiveAttrib; +PFNGLGETACTIVEUNIFORMPROC gl3wGetActiveUniform; +PFNGLGETATTACHEDSHADERSPROC gl3wGetAttachedShaders; +PFNGLGETATTRIBLOCATIONPROC gl3wGetAttribLocation; +PFNGLGETPROGRAMIVPROC gl3wGetProgramiv; +PFNGLGETPROGRAMINFOLOGPROC gl3wGetProgramInfoLog; +PFNGLGETSHADERIVPROC gl3wGetShaderiv; +PFNGLGETSHADERINFOLOGPROC gl3wGetShaderInfoLog; +PFNGLGETSHADERSOURCEPROC gl3wGetShaderSource; +PFNGLGETUNIFORMLOCATIONPROC gl3wGetUniformLocation; +PFNGLGETUNIFORMFVPROC gl3wGetUniformfv; +PFNGLGETUNIFORMIVPROC gl3wGetUniformiv; +PFNGLGETVERTEXATTRIBDVPROC gl3wGetVertexAttribdv; +PFNGLGETVERTEXATTRIBFVPROC gl3wGetVertexAttribfv; +PFNGLGETVERTEXATTRIBIVPROC gl3wGetVertexAttribiv; +PFNGLGETVERTEXATTRIBPOINTERVPROC gl3wGetVertexAttribPointerv; +PFNGLISPROGRAMPROC gl3wIsProgram; +PFNGLISSHADERPROC gl3wIsShader; +PFNGLLINKPROGRAMPROC gl3wLinkProgram; +PFNGLSHADERSOURCEPROC gl3wShaderSource; +PFNGLUSEPROGRAMPROC gl3wUseProgram; +PFNGLUNIFORM1FPROC gl3wUniform1f; +PFNGLUNIFORM2FPROC gl3wUniform2f; +PFNGLUNIFORM3FPROC gl3wUniform3f; +PFNGLUNIFORM4FPROC gl3wUniform4f; +PFNGLUNIFORM1IPROC gl3wUniform1i; +PFNGLUNIFORM2IPROC gl3wUniform2i; +PFNGLUNIFORM3IPROC gl3wUniform3i; +PFNGLUNIFORM4IPROC gl3wUniform4i; +PFNGLUNIFORM1FVPROC gl3wUniform1fv; +PFNGLUNIFORM2FVPROC gl3wUniform2fv; +PFNGLUNIFORM3FVPROC gl3wUniform3fv; +PFNGLUNIFORM4FVPROC gl3wUniform4fv; +PFNGLUNIFORM1IVPROC gl3wUniform1iv; +PFNGLUNIFORM2IVPROC gl3wUniform2iv; +PFNGLUNIFORM3IVPROC gl3wUniform3iv; +PFNGLUNIFORM4IVPROC gl3wUniform4iv; +PFNGLUNIFORMMATRIX2FVPROC gl3wUniformMatrix2fv; +PFNGLUNIFORMMATRIX3FVPROC gl3wUniformMatrix3fv; +PFNGLUNIFORMMATRIX4FVPROC gl3wUniformMatrix4fv; +PFNGLVALIDATEPROGRAMPROC gl3wValidateProgram; +PFNGLVERTEXATTRIB1DPROC gl3wVertexAttrib1d; +PFNGLVERTEXATTRIB1DVPROC gl3wVertexAttrib1dv; +PFNGLVERTEXATTRIB1FPROC gl3wVertexAttrib1f; +PFNGLVERTEXATTRIB1FVPROC gl3wVertexAttrib1fv; +PFNGLVERTEXATTRIB1SPROC gl3wVertexAttrib1s; +PFNGLVERTEXATTRIB1SVPROC gl3wVertexAttrib1sv; +PFNGLVERTEXATTRIB2DPROC gl3wVertexAttrib2d; +PFNGLVERTEXATTRIB2DVPROC gl3wVertexAttrib2dv; +PFNGLVERTEXATTRIB2FPROC gl3wVertexAttrib2f; +PFNGLVERTEXATTRIB2FVPROC gl3wVertexAttrib2fv; +PFNGLVERTEXATTRIB2SPROC gl3wVertexAttrib2s; 
+PFNGLVERTEXATTRIB2SVPROC gl3wVertexAttrib2sv; +PFNGLVERTEXATTRIB3DPROC gl3wVertexAttrib3d; +PFNGLVERTEXATTRIB3DVPROC gl3wVertexAttrib3dv; +PFNGLVERTEXATTRIB3FPROC gl3wVertexAttrib3f; +PFNGLVERTEXATTRIB3FVPROC gl3wVertexAttrib3fv; +PFNGLVERTEXATTRIB3SPROC gl3wVertexAttrib3s; +PFNGLVERTEXATTRIB3SVPROC gl3wVertexAttrib3sv; +PFNGLVERTEXATTRIB4NBVPROC gl3wVertexAttrib4Nbv; +PFNGLVERTEXATTRIB4NIVPROC gl3wVertexAttrib4Niv; +PFNGLVERTEXATTRIB4NSVPROC gl3wVertexAttrib4Nsv; +PFNGLVERTEXATTRIB4NUBPROC gl3wVertexAttrib4Nub; +PFNGLVERTEXATTRIB4NUBVPROC gl3wVertexAttrib4Nubv; +PFNGLVERTEXATTRIB4NUIVPROC gl3wVertexAttrib4Nuiv; +PFNGLVERTEXATTRIB4NUSVPROC gl3wVertexAttrib4Nusv; +PFNGLVERTEXATTRIB4BVPROC gl3wVertexAttrib4bv; +PFNGLVERTEXATTRIB4DPROC gl3wVertexAttrib4d; +PFNGLVERTEXATTRIB4DVPROC gl3wVertexAttrib4dv; +PFNGLVERTEXATTRIB4FPROC gl3wVertexAttrib4f; +PFNGLVERTEXATTRIB4FVPROC gl3wVertexAttrib4fv; +PFNGLVERTEXATTRIB4IVPROC gl3wVertexAttrib4iv; +PFNGLVERTEXATTRIB4SPROC gl3wVertexAttrib4s; +PFNGLVERTEXATTRIB4SVPROC gl3wVertexAttrib4sv; +PFNGLVERTEXATTRIB4UBVPROC gl3wVertexAttrib4ubv; +PFNGLVERTEXATTRIB4UIVPROC gl3wVertexAttrib4uiv; +PFNGLVERTEXATTRIB4USVPROC gl3wVertexAttrib4usv; +PFNGLVERTEXATTRIBPOINTERPROC gl3wVertexAttribPointer; +PFNGLUNIFORMMATRIX2X3FVPROC gl3wUniformMatrix2x3fv; +PFNGLUNIFORMMATRIX3X2FVPROC gl3wUniformMatrix3x2fv; +PFNGLUNIFORMMATRIX2X4FVPROC gl3wUniformMatrix2x4fv; +PFNGLUNIFORMMATRIX4X2FVPROC gl3wUniformMatrix4x2fv; +PFNGLUNIFORMMATRIX3X4FVPROC gl3wUniformMatrix3x4fv; +PFNGLUNIFORMMATRIX4X3FVPROC gl3wUniformMatrix4x3fv; +PFNGLCOLORMASKIPROC gl3wColorMaski; +PFNGLGETBOOLEANI_VPROC gl3wGetBooleani_v; +PFNGLGETINTEGERI_VPROC gl3wGetIntegeri_v; +PFNGLENABLEIPROC gl3wEnablei; +PFNGLDISABLEIPROC gl3wDisablei; +PFNGLISENABLEDIPROC gl3wIsEnabledi; +PFNGLBEGINTRANSFORMFEEDBACKPROC gl3wBeginTransformFeedback; +PFNGLENDTRANSFORMFEEDBACKPROC gl3wEndTransformFeedback; +PFNGLBINDBUFFERRANGEPROC gl3wBindBufferRange; +PFNGLBINDBUFFERBASEPROC gl3wBindBufferBase; +PFNGLTRANSFORMFEEDBACKVARYINGSPROC gl3wTransformFeedbackVaryings; +PFNGLGETTRANSFORMFEEDBACKVARYINGPROC gl3wGetTransformFeedbackVarying; +PFNGLCLAMPCOLORPROC gl3wClampColor; +PFNGLBEGINCONDITIONALRENDERPROC gl3wBeginConditionalRender; +PFNGLENDCONDITIONALRENDERPROC gl3wEndConditionalRender; +PFNGLVERTEXATTRIBIPOINTERPROC gl3wVertexAttribIPointer; +PFNGLGETVERTEXATTRIBIIVPROC gl3wGetVertexAttribIiv; +PFNGLGETVERTEXATTRIBIUIVPROC gl3wGetVertexAttribIuiv; +PFNGLVERTEXATTRIBI1IPROC gl3wVertexAttribI1i; +PFNGLVERTEXATTRIBI2IPROC gl3wVertexAttribI2i; +PFNGLVERTEXATTRIBI3IPROC gl3wVertexAttribI3i; +PFNGLVERTEXATTRIBI4IPROC gl3wVertexAttribI4i; +PFNGLVERTEXATTRIBI1UIPROC gl3wVertexAttribI1ui; +PFNGLVERTEXATTRIBI2UIPROC gl3wVertexAttribI2ui; +PFNGLVERTEXATTRIBI3UIPROC gl3wVertexAttribI3ui; +PFNGLVERTEXATTRIBI4UIPROC gl3wVertexAttribI4ui; +PFNGLVERTEXATTRIBI1IVPROC gl3wVertexAttribI1iv; +PFNGLVERTEXATTRIBI2IVPROC gl3wVertexAttribI2iv; +PFNGLVERTEXATTRIBI3IVPROC gl3wVertexAttribI3iv; +PFNGLVERTEXATTRIBI4IVPROC gl3wVertexAttribI4iv; +PFNGLVERTEXATTRIBI1UIVPROC gl3wVertexAttribI1uiv; +PFNGLVERTEXATTRIBI2UIVPROC gl3wVertexAttribI2uiv; +PFNGLVERTEXATTRIBI3UIVPROC gl3wVertexAttribI3uiv; +PFNGLVERTEXATTRIBI4UIVPROC gl3wVertexAttribI4uiv; +PFNGLVERTEXATTRIBI4BVPROC gl3wVertexAttribI4bv; +PFNGLVERTEXATTRIBI4SVPROC gl3wVertexAttribI4sv; +PFNGLVERTEXATTRIBI4UBVPROC gl3wVertexAttribI4ubv; +PFNGLVERTEXATTRIBI4USVPROC gl3wVertexAttribI4usv; +PFNGLGETUNIFORMUIVPROC gl3wGetUniformuiv; +PFNGLBINDFRAGDATALOCATIONPROC gl3wBindFragDataLocation; 
+PFNGLGETFRAGDATALOCATIONPROC gl3wGetFragDataLocation; +PFNGLUNIFORM1UIPROC gl3wUniform1ui; +PFNGLUNIFORM2UIPROC gl3wUniform2ui; +PFNGLUNIFORM3UIPROC gl3wUniform3ui; +PFNGLUNIFORM4UIPROC gl3wUniform4ui; +PFNGLUNIFORM1UIVPROC gl3wUniform1uiv; +PFNGLUNIFORM2UIVPROC gl3wUniform2uiv; +PFNGLUNIFORM3UIVPROC gl3wUniform3uiv; +PFNGLUNIFORM4UIVPROC gl3wUniform4uiv; +PFNGLTEXPARAMETERIIVPROC gl3wTexParameterIiv; +PFNGLTEXPARAMETERIUIVPROC gl3wTexParameterIuiv; +PFNGLGETTEXPARAMETERIIVPROC gl3wGetTexParameterIiv; +PFNGLGETTEXPARAMETERIUIVPROC gl3wGetTexParameterIuiv; +PFNGLCLEARBUFFERIVPROC gl3wClearBufferiv; +PFNGLCLEARBUFFERUIVPROC gl3wClearBufferuiv; +PFNGLCLEARBUFFERFVPROC gl3wClearBufferfv; +PFNGLCLEARBUFFERFIPROC gl3wClearBufferfi; +PFNGLGETSTRINGIPROC gl3wGetStringi; +PFNGLDRAWARRAYSINSTANCEDPROC gl3wDrawArraysInstanced; +PFNGLDRAWELEMENTSINSTANCEDPROC gl3wDrawElementsInstanced; +PFNGLTEXBUFFERPROC gl3wTexBuffer; +PFNGLPRIMITIVERESTARTINDEXPROC gl3wPrimitiveRestartIndex; +PFNGLGETINTEGER64I_VPROC gl3wGetInteger64i_v; +PFNGLGETBUFFERPARAMETERI64VPROC gl3wGetBufferParameteri64v; +PFNGLFRAMEBUFFERTEXTUREPROC gl3wFramebufferTexture; +PFNGLVERTEXATTRIBDIVISORPROC gl3wVertexAttribDivisor; +PFNGLMINSAMPLESHADINGPROC gl3wMinSampleShading; +PFNGLBLENDEQUATIONIPROC gl3wBlendEquationi; +PFNGLBLENDEQUATIONSEPARATEIPROC gl3wBlendEquationSeparatei; +PFNGLBLENDFUNCIPROC gl3wBlendFunci; +PFNGLBLENDFUNCSEPARATEIPROC gl3wBlendFuncSeparatei; +PFNGLISRENDERBUFFERPROC gl3wIsRenderbuffer; +PFNGLBINDRENDERBUFFERPROC gl3wBindRenderbuffer; +PFNGLDELETERENDERBUFFERSPROC gl3wDeleteRenderbuffers; +PFNGLGENRENDERBUFFERSPROC gl3wGenRenderbuffers; +PFNGLRENDERBUFFERSTORAGEPROC gl3wRenderbufferStorage; +PFNGLGETRENDERBUFFERPARAMETERIVPROC gl3wGetRenderbufferParameteriv; +PFNGLISFRAMEBUFFERPROC gl3wIsFramebuffer; +PFNGLBINDFRAMEBUFFERPROC gl3wBindFramebuffer; +PFNGLDELETEFRAMEBUFFERSPROC gl3wDeleteFramebuffers; +PFNGLGENFRAMEBUFFERSPROC gl3wGenFramebuffers; +PFNGLCHECKFRAMEBUFFERSTATUSPROC gl3wCheckFramebufferStatus; +PFNGLFRAMEBUFFERTEXTURE1DPROC gl3wFramebufferTexture1D; +PFNGLFRAMEBUFFERTEXTURE2DPROC gl3wFramebufferTexture2D; +PFNGLFRAMEBUFFERTEXTURE3DPROC gl3wFramebufferTexture3D; +PFNGLFRAMEBUFFERRENDERBUFFERPROC gl3wFramebufferRenderbuffer; +PFNGLGETFRAMEBUFFERATTACHMENTPARAMETERIVPROC gl3wGetFramebufferAttachmentParameteriv; +PFNGLGENERATEMIPMAPPROC gl3wGenerateMipmap; +PFNGLBLITFRAMEBUFFERPROC gl3wBlitFramebuffer; +PFNGLRENDERBUFFERSTORAGEMULTISAMPLEPROC gl3wRenderbufferStorageMultisample; +PFNGLFRAMEBUFFERTEXTURELAYERPROC gl3wFramebufferTextureLayer; +PFNGLMAPBUFFERRANGEPROC gl3wMapBufferRange; +PFNGLFLUSHMAPPEDBUFFERRANGEPROC gl3wFlushMappedBufferRange; +PFNGLBINDVERTEXARRAYPROC gl3wBindVertexArray; +PFNGLDELETEVERTEXARRAYSPROC gl3wDeleteVertexArrays; +PFNGLGENVERTEXARRAYSPROC gl3wGenVertexArrays; +PFNGLISVERTEXARRAYPROC gl3wIsVertexArray; +PFNGLGETUNIFORMINDICESPROC gl3wGetUniformIndices; +PFNGLGETACTIVEUNIFORMSIVPROC gl3wGetActiveUniformsiv; +PFNGLGETACTIVEUNIFORMNAMEPROC gl3wGetActiveUniformName; +PFNGLGETUNIFORMBLOCKINDEXPROC gl3wGetUniformBlockIndex; +PFNGLGETACTIVEUNIFORMBLOCKIVPROC gl3wGetActiveUniformBlockiv; +PFNGLGETACTIVEUNIFORMBLOCKNAMEPROC gl3wGetActiveUniformBlockName; +PFNGLUNIFORMBLOCKBINDINGPROC gl3wUniformBlockBinding; +PFNGLCOPYBUFFERSUBDATAPROC gl3wCopyBufferSubData; +PFNGLDRAWELEMENTSBASEVERTEXPROC gl3wDrawElementsBaseVertex; +PFNGLDRAWRANGEELEMENTSBASEVERTEXPROC gl3wDrawRangeElementsBaseVertex; +PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXPROC gl3wDrawElementsInstancedBaseVertex; 
+PFNGLMULTIDRAWELEMENTSBASEVERTEXPROC gl3wMultiDrawElementsBaseVertex; +PFNGLPROVOKINGVERTEXPROC gl3wProvokingVertex; +PFNGLFENCESYNCPROC gl3wFenceSync; +PFNGLISSYNCPROC gl3wIsSync; +PFNGLDELETESYNCPROC gl3wDeleteSync; +PFNGLCLIENTWAITSYNCPROC gl3wClientWaitSync; +PFNGLWAITSYNCPROC gl3wWaitSync; +PFNGLGETINTEGER64VPROC gl3wGetInteger64v; +PFNGLGETSYNCIVPROC gl3wGetSynciv; +PFNGLTEXIMAGE2DMULTISAMPLEPROC gl3wTexImage2DMultisample; +PFNGLTEXIMAGE3DMULTISAMPLEPROC gl3wTexImage3DMultisample; +PFNGLGETMULTISAMPLEFVPROC gl3wGetMultisamplefv; +PFNGLSAMPLEMASKIPROC gl3wSampleMaski; +PFNGLBLENDEQUATIONIARBPROC gl3wBlendEquationiARB; +PFNGLBLENDEQUATIONSEPARATEIARBPROC gl3wBlendEquationSeparateiARB; +PFNGLBLENDFUNCIARBPROC gl3wBlendFunciARB; +PFNGLBLENDFUNCSEPARATEIARBPROC gl3wBlendFuncSeparateiARB; +PFNGLMINSAMPLESHADINGARBPROC gl3wMinSampleShadingARB; +PFNGLNAMEDSTRINGARBPROC gl3wNamedStringARB; +PFNGLDELETENAMEDSTRINGARBPROC gl3wDeleteNamedStringARB; +PFNGLCOMPILESHADERINCLUDEARBPROC gl3wCompileShaderIncludeARB; +PFNGLISNAMEDSTRINGARBPROC gl3wIsNamedStringARB; +PFNGLGETNAMEDSTRINGARBPROC gl3wGetNamedStringARB; +PFNGLGETNAMEDSTRINGIVARBPROC gl3wGetNamedStringivARB; +PFNGLBINDFRAGDATALOCATIONINDEXEDPROC gl3wBindFragDataLocationIndexed; +PFNGLGETFRAGDATAINDEXPROC gl3wGetFragDataIndex; +PFNGLGENSAMPLERSPROC gl3wGenSamplers; +PFNGLDELETESAMPLERSPROC gl3wDeleteSamplers; +PFNGLISSAMPLERPROC gl3wIsSampler; +PFNGLBINDSAMPLERPROC gl3wBindSampler; +PFNGLSAMPLERPARAMETERIPROC gl3wSamplerParameteri; +PFNGLSAMPLERPARAMETERIVPROC gl3wSamplerParameteriv; +PFNGLSAMPLERPARAMETERFPROC gl3wSamplerParameterf; +PFNGLSAMPLERPARAMETERFVPROC gl3wSamplerParameterfv; +PFNGLSAMPLERPARAMETERIIVPROC gl3wSamplerParameterIiv; +PFNGLSAMPLERPARAMETERIUIVPROC gl3wSamplerParameterIuiv; +PFNGLGETSAMPLERPARAMETERIVPROC gl3wGetSamplerParameteriv; +PFNGLGETSAMPLERPARAMETERIIVPROC gl3wGetSamplerParameterIiv; +PFNGLGETSAMPLERPARAMETERFVPROC gl3wGetSamplerParameterfv; +PFNGLGETSAMPLERPARAMETERIUIVPROC gl3wGetSamplerParameterIuiv; +PFNGLQUERYCOUNTERPROC gl3wQueryCounter; +PFNGLGETQUERYOBJECTI64VPROC gl3wGetQueryObjecti64v; +PFNGLGETQUERYOBJECTUI64VPROC gl3wGetQueryObjectui64v; +PFNGLVERTEXP2UIPROC gl3wVertexP2ui; +PFNGLVERTEXP2UIVPROC gl3wVertexP2uiv; +PFNGLVERTEXP3UIPROC gl3wVertexP3ui; +PFNGLVERTEXP3UIVPROC gl3wVertexP3uiv; +PFNGLVERTEXP4UIPROC gl3wVertexP4ui; +PFNGLVERTEXP4UIVPROC gl3wVertexP4uiv; +PFNGLTEXCOORDP1UIPROC gl3wTexCoordP1ui; +PFNGLTEXCOORDP1UIVPROC gl3wTexCoordP1uiv; +PFNGLTEXCOORDP2UIPROC gl3wTexCoordP2ui; +PFNGLTEXCOORDP2UIVPROC gl3wTexCoordP2uiv; +PFNGLTEXCOORDP3UIPROC gl3wTexCoordP3ui; +PFNGLTEXCOORDP3UIVPROC gl3wTexCoordP3uiv; +PFNGLTEXCOORDP4UIPROC gl3wTexCoordP4ui; +PFNGLTEXCOORDP4UIVPROC gl3wTexCoordP4uiv; +PFNGLMULTITEXCOORDP1UIPROC gl3wMultiTexCoordP1ui; +PFNGLMULTITEXCOORDP1UIVPROC gl3wMultiTexCoordP1uiv; +PFNGLMULTITEXCOORDP2UIPROC gl3wMultiTexCoordP2ui; +PFNGLMULTITEXCOORDP2UIVPROC gl3wMultiTexCoordP2uiv; +PFNGLMULTITEXCOORDP3UIPROC gl3wMultiTexCoordP3ui; +PFNGLMULTITEXCOORDP3UIVPROC gl3wMultiTexCoordP3uiv; +PFNGLMULTITEXCOORDP4UIPROC gl3wMultiTexCoordP4ui; +PFNGLMULTITEXCOORDP4UIVPROC gl3wMultiTexCoordP4uiv; +PFNGLNORMALP3UIPROC gl3wNormalP3ui; +PFNGLNORMALP3UIVPROC gl3wNormalP3uiv; +PFNGLCOLORP3UIPROC gl3wColorP3ui; +PFNGLCOLORP3UIVPROC gl3wColorP3uiv; +PFNGLCOLORP4UIPROC gl3wColorP4ui; +PFNGLCOLORP4UIVPROC gl3wColorP4uiv; +PFNGLSECONDARYCOLORP3UIPROC gl3wSecondaryColorP3ui; +PFNGLSECONDARYCOLORP3UIVPROC gl3wSecondaryColorP3uiv; +PFNGLVERTEXATTRIBP1UIPROC gl3wVertexAttribP1ui; +PFNGLVERTEXATTRIBP1UIVPROC 
gl3wVertexAttribP1uiv; +PFNGLVERTEXATTRIBP2UIPROC gl3wVertexAttribP2ui; +PFNGLVERTEXATTRIBP2UIVPROC gl3wVertexAttribP2uiv; +PFNGLVERTEXATTRIBP3UIPROC gl3wVertexAttribP3ui; +PFNGLVERTEXATTRIBP3UIVPROC gl3wVertexAttribP3uiv; +PFNGLVERTEXATTRIBP4UIPROC gl3wVertexAttribP4ui; +PFNGLVERTEXATTRIBP4UIVPROC gl3wVertexAttribP4uiv; +PFNGLDRAWARRAYSINDIRECTPROC gl3wDrawArraysIndirect; +PFNGLDRAWELEMENTSINDIRECTPROC gl3wDrawElementsIndirect; +PFNGLUNIFORM1DPROC gl3wUniform1d; +PFNGLUNIFORM2DPROC gl3wUniform2d; +PFNGLUNIFORM3DPROC gl3wUniform3d; +PFNGLUNIFORM4DPROC gl3wUniform4d; +PFNGLUNIFORM1DVPROC gl3wUniform1dv; +PFNGLUNIFORM2DVPROC gl3wUniform2dv; +PFNGLUNIFORM3DVPROC gl3wUniform3dv; +PFNGLUNIFORM4DVPROC gl3wUniform4dv; +PFNGLUNIFORMMATRIX2DVPROC gl3wUniformMatrix2dv; +PFNGLUNIFORMMATRIX3DVPROC gl3wUniformMatrix3dv; +PFNGLUNIFORMMATRIX4DVPROC gl3wUniformMatrix4dv; +PFNGLUNIFORMMATRIX2X3DVPROC gl3wUniformMatrix2x3dv; +PFNGLUNIFORMMATRIX2X4DVPROC gl3wUniformMatrix2x4dv; +PFNGLUNIFORMMATRIX3X2DVPROC gl3wUniformMatrix3x2dv; +PFNGLUNIFORMMATRIX3X4DVPROC gl3wUniformMatrix3x4dv; +PFNGLUNIFORMMATRIX4X2DVPROC gl3wUniformMatrix4x2dv; +PFNGLUNIFORMMATRIX4X3DVPROC gl3wUniformMatrix4x3dv; +PFNGLGETUNIFORMDVPROC gl3wGetUniformdv; +PFNGLGETSUBROUTINEUNIFORMLOCATIONPROC gl3wGetSubroutineUniformLocation; +PFNGLGETSUBROUTINEINDEXPROC gl3wGetSubroutineIndex; +PFNGLGETACTIVESUBROUTINEUNIFORMIVPROC gl3wGetActiveSubroutineUniformiv; +PFNGLGETACTIVESUBROUTINEUNIFORMNAMEPROC gl3wGetActiveSubroutineUniformName; +PFNGLGETACTIVESUBROUTINENAMEPROC gl3wGetActiveSubroutineName; +PFNGLUNIFORMSUBROUTINESUIVPROC gl3wUniformSubroutinesuiv; +PFNGLGETUNIFORMSUBROUTINEUIVPROC gl3wGetUniformSubroutineuiv; +PFNGLGETPROGRAMSTAGEIVPROC gl3wGetProgramStageiv; +PFNGLPATCHPARAMETERIPROC gl3wPatchParameteri; +PFNGLPATCHPARAMETERFVPROC gl3wPatchParameterfv; +PFNGLBINDTRANSFORMFEEDBACKPROC gl3wBindTransformFeedback; +PFNGLDELETETRANSFORMFEEDBACKSPROC gl3wDeleteTransformFeedbacks; +PFNGLGENTRANSFORMFEEDBACKSPROC gl3wGenTransformFeedbacks; +PFNGLISTRANSFORMFEEDBACKPROC gl3wIsTransformFeedback; +PFNGLPAUSETRANSFORMFEEDBACKPROC gl3wPauseTransformFeedback; +PFNGLRESUMETRANSFORMFEEDBACKPROC gl3wResumeTransformFeedback; +PFNGLDRAWTRANSFORMFEEDBACKPROC gl3wDrawTransformFeedback; +PFNGLDRAWTRANSFORMFEEDBACKSTREAMPROC gl3wDrawTransformFeedbackStream; +PFNGLBEGINQUERYINDEXEDPROC gl3wBeginQueryIndexed; +PFNGLENDQUERYINDEXEDPROC gl3wEndQueryIndexed; +PFNGLGETQUERYINDEXEDIVPROC gl3wGetQueryIndexediv; +PFNGLRELEASESHADERCOMPILERPROC gl3wReleaseShaderCompiler; +PFNGLSHADERBINARYPROC gl3wShaderBinary; +PFNGLGETSHADERPRECISIONFORMATPROC gl3wGetShaderPrecisionFormat; +PFNGLDEPTHRANGEFPROC gl3wDepthRangef; +PFNGLCLEARDEPTHFPROC gl3wClearDepthf; +PFNGLGETPROGRAMBINARYPROC gl3wGetProgramBinary; +PFNGLPROGRAMBINARYPROC gl3wProgramBinary; +PFNGLPROGRAMPARAMETERIPROC gl3wProgramParameteri; +PFNGLUSEPROGRAMSTAGESPROC gl3wUseProgramStages; +PFNGLACTIVESHADERPROGRAMPROC gl3wActiveShaderProgram; +PFNGLCREATESHADERPROGRAMVPROC gl3wCreateShaderProgramv; +PFNGLBINDPROGRAMPIPELINEPROC gl3wBindProgramPipeline; +PFNGLDELETEPROGRAMPIPELINESPROC gl3wDeleteProgramPipelines; +PFNGLGENPROGRAMPIPELINESPROC gl3wGenProgramPipelines; +PFNGLISPROGRAMPIPELINEPROC gl3wIsProgramPipeline; +PFNGLGETPROGRAMPIPELINEIVPROC gl3wGetProgramPipelineiv; +PFNGLPROGRAMUNIFORM1IPROC gl3wProgramUniform1i; +PFNGLPROGRAMUNIFORM1IVPROC gl3wProgramUniform1iv; +PFNGLPROGRAMUNIFORM1FPROC gl3wProgramUniform1f; +PFNGLPROGRAMUNIFORM1FVPROC gl3wProgramUniform1fv; +PFNGLPROGRAMUNIFORM1DPROC gl3wProgramUniform1d; 
+PFNGLPROGRAMUNIFORM1DVPROC gl3wProgramUniform1dv; +PFNGLPROGRAMUNIFORM1UIPROC gl3wProgramUniform1ui; +PFNGLPROGRAMUNIFORM1UIVPROC gl3wProgramUniform1uiv; +PFNGLPROGRAMUNIFORM2IPROC gl3wProgramUniform2i; +PFNGLPROGRAMUNIFORM2IVPROC gl3wProgramUniform2iv; +PFNGLPROGRAMUNIFORM2FPROC gl3wProgramUniform2f; +PFNGLPROGRAMUNIFORM2FVPROC gl3wProgramUniform2fv; +PFNGLPROGRAMUNIFORM2DPROC gl3wProgramUniform2d; +PFNGLPROGRAMUNIFORM2DVPROC gl3wProgramUniform2dv; +PFNGLPROGRAMUNIFORM2UIPROC gl3wProgramUniform2ui; +PFNGLPROGRAMUNIFORM2UIVPROC gl3wProgramUniform2uiv; +PFNGLPROGRAMUNIFORM3IPROC gl3wProgramUniform3i; +PFNGLPROGRAMUNIFORM3IVPROC gl3wProgramUniform3iv; +PFNGLPROGRAMUNIFORM3FPROC gl3wProgramUniform3f; +PFNGLPROGRAMUNIFORM3FVPROC gl3wProgramUniform3fv; +PFNGLPROGRAMUNIFORM3DPROC gl3wProgramUniform3d; +PFNGLPROGRAMUNIFORM3DVPROC gl3wProgramUniform3dv; +PFNGLPROGRAMUNIFORM3UIPROC gl3wProgramUniform3ui; +PFNGLPROGRAMUNIFORM3UIVPROC gl3wProgramUniform3uiv; +PFNGLPROGRAMUNIFORM4IPROC gl3wProgramUniform4i; +PFNGLPROGRAMUNIFORM4IVPROC gl3wProgramUniform4iv; +PFNGLPROGRAMUNIFORM4FPROC gl3wProgramUniform4f; +PFNGLPROGRAMUNIFORM4FVPROC gl3wProgramUniform4fv; +PFNGLPROGRAMUNIFORM4DPROC gl3wProgramUniform4d; +PFNGLPROGRAMUNIFORM4DVPROC gl3wProgramUniform4dv; +PFNGLPROGRAMUNIFORM4UIPROC gl3wProgramUniform4ui; +PFNGLPROGRAMUNIFORM4UIVPROC gl3wProgramUniform4uiv; +PFNGLPROGRAMUNIFORMMATRIX2FVPROC gl3wProgramUniformMatrix2fv; +PFNGLPROGRAMUNIFORMMATRIX3FVPROC gl3wProgramUniformMatrix3fv; +PFNGLPROGRAMUNIFORMMATRIX4FVPROC gl3wProgramUniformMatrix4fv; +PFNGLPROGRAMUNIFORMMATRIX2DVPROC gl3wProgramUniformMatrix2dv; +PFNGLPROGRAMUNIFORMMATRIX3DVPROC gl3wProgramUniformMatrix3dv; +PFNGLPROGRAMUNIFORMMATRIX4DVPROC gl3wProgramUniformMatrix4dv; +PFNGLPROGRAMUNIFORMMATRIX2X3FVPROC gl3wProgramUniformMatrix2x3fv; +PFNGLPROGRAMUNIFORMMATRIX3X2FVPROC gl3wProgramUniformMatrix3x2fv; +PFNGLPROGRAMUNIFORMMATRIX2X4FVPROC gl3wProgramUniformMatrix2x4fv; +PFNGLPROGRAMUNIFORMMATRIX4X2FVPROC gl3wProgramUniformMatrix4x2fv; +PFNGLPROGRAMUNIFORMMATRIX3X4FVPROC gl3wProgramUniformMatrix3x4fv; +PFNGLPROGRAMUNIFORMMATRIX4X3FVPROC gl3wProgramUniformMatrix4x3fv; +PFNGLPROGRAMUNIFORMMATRIX2X3DVPROC gl3wProgramUniformMatrix2x3dv; +PFNGLPROGRAMUNIFORMMATRIX3X2DVPROC gl3wProgramUniformMatrix3x2dv; +PFNGLPROGRAMUNIFORMMATRIX2X4DVPROC gl3wProgramUniformMatrix2x4dv; +PFNGLPROGRAMUNIFORMMATRIX4X2DVPROC gl3wProgramUniformMatrix4x2dv; +PFNGLPROGRAMUNIFORMMATRIX3X4DVPROC gl3wProgramUniformMatrix3x4dv; +PFNGLPROGRAMUNIFORMMATRIX4X3DVPROC gl3wProgramUniformMatrix4x3dv; +PFNGLVALIDATEPROGRAMPIPELINEPROC gl3wValidateProgramPipeline; +PFNGLGETPROGRAMPIPELINEINFOLOGPROC gl3wGetProgramPipelineInfoLog; +PFNGLVERTEXATTRIBL1DPROC gl3wVertexAttribL1d; +PFNGLVERTEXATTRIBL2DPROC gl3wVertexAttribL2d; +PFNGLVERTEXATTRIBL3DPROC gl3wVertexAttribL3d; +PFNGLVERTEXATTRIBL4DPROC gl3wVertexAttribL4d; +PFNGLVERTEXATTRIBL1DVPROC gl3wVertexAttribL1dv; +PFNGLVERTEXATTRIBL2DVPROC gl3wVertexAttribL2dv; +PFNGLVERTEXATTRIBL3DVPROC gl3wVertexAttribL3dv; +PFNGLVERTEXATTRIBL4DVPROC gl3wVertexAttribL4dv; +PFNGLVERTEXATTRIBLPOINTERPROC gl3wVertexAttribLPointer; +PFNGLGETVERTEXATTRIBLDVPROC gl3wGetVertexAttribLdv; +PFNGLVIEWPORTARRAYVPROC gl3wViewportArrayv; +PFNGLVIEWPORTINDEXEDFPROC gl3wViewportIndexedf; +PFNGLVIEWPORTINDEXEDFVPROC gl3wViewportIndexedfv; +PFNGLSCISSORARRAYVPROC gl3wScissorArrayv; +PFNGLSCISSORINDEXEDPROC gl3wScissorIndexed; +PFNGLSCISSORINDEXEDVPROC gl3wScissorIndexedv; +PFNGLDEPTHRANGEARRAYVPROC gl3wDepthRangeArrayv; +PFNGLDEPTHRANGEINDEXEDPROC gl3wDepthRangeIndexed; 
+PFNGLGETFLOATI_VPROC gl3wGetFloati_v; +PFNGLGETDOUBLEI_VPROC gl3wGetDoublei_v; +PFNGLCREATESYNCFROMCLEVENTARBPROC gl3wCreateSyncFromCLeventARB; +PFNGLDEBUGMESSAGECONTROLARBPROC gl3wDebugMessageControlARB; +PFNGLDEBUGMESSAGEINSERTARBPROC gl3wDebugMessageInsertARB; +PFNGLDEBUGMESSAGECALLBACKARBPROC gl3wDebugMessageCallbackARB; +PFNGLGETDEBUGMESSAGELOGARBPROC gl3wGetDebugMessageLogARB; +PFNGLGETGRAPHICSRESETSTATUSARBPROC gl3wGetGraphicsResetStatusARB; +PFNGLGETNTEXIMAGEARBPROC gl3wGetnTexImageARB; +PFNGLREADNPIXELSARBPROC gl3wReadnPixelsARB; +PFNGLGETNCOMPRESSEDTEXIMAGEARBPROC gl3wGetnCompressedTexImageARB; +PFNGLGETNUNIFORMFVARBPROC gl3wGetnUniformfvARB; +PFNGLGETNUNIFORMIVARBPROC gl3wGetnUniformivARB; +PFNGLGETNUNIFORMUIVARBPROC gl3wGetnUniformuivARB; +PFNGLGETNUNIFORMDVARBPROC gl3wGetnUniformdvARB; +PFNGLDRAWARRAYSINSTANCEDBASEINSTANCEPROC gl3wDrawArraysInstancedBaseInstance; +PFNGLDRAWELEMENTSINSTANCEDBASEINSTANCEPROC gl3wDrawElementsInstancedBaseInstance; +PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXBASEINSTANCEPROC gl3wDrawElementsInstancedBaseVertexBaseInstance; +PFNGLDRAWTRANSFORMFEEDBACKINSTANCEDPROC gl3wDrawTransformFeedbackInstanced; +PFNGLDRAWTRANSFORMFEEDBACKSTREAMINSTANCEDPROC gl3wDrawTransformFeedbackStreamInstanced; +PFNGLGETINTERNALFORMATIVPROC gl3wGetInternalformativ; +PFNGLGETACTIVEATOMICCOUNTERBUFFERIVPROC gl3wGetActiveAtomicCounterBufferiv; +PFNGLBINDIMAGETEXTUREPROC gl3wBindImageTexture; +PFNGLMEMORYBARRIERPROC gl3wMemoryBarrier; +PFNGLTEXSTORAGE1DPROC gl3wTexStorage1D; +PFNGLTEXSTORAGE2DPROC gl3wTexStorage2D; +PFNGLTEXSTORAGE3DPROC gl3wTexStorage3D; +PFNGLTEXTURESTORAGE1DEXTPROC gl3wTextureStorage1DEXT; +PFNGLTEXTURESTORAGE2DEXTPROC gl3wTextureStorage2DEXT; +PFNGLTEXTURESTORAGE3DEXTPROC gl3wTextureStorage3DEXT; +PFNGLDEBUGMESSAGECONTROLPROC gl3wDebugMessageControl; +PFNGLDEBUGMESSAGEINSERTPROC gl3wDebugMessageInsert; +PFNGLDEBUGMESSAGECALLBACKPROC gl3wDebugMessageCallback; +PFNGLGETDEBUGMESSAGELOGPROC gl3wGetDebugMessageLog; +PFNGLPUSHDEBUGGROUPPROC gl3wPushDebugGroup; +PFNGLPOPDEBUGGROUPPROC gl3wPopDebugGroup; +PFNGLOBJECTLABELPROC gl3wObjectLabel; +PFNGLGETOBJECTLABELPROC gl3wGetObjectLabel; +PFNGLOBJECTPTRLABELPROC gl3wObjectPtrLabel; +PFNGLGETOBJECTPTRLABELPROC gl3wGetObjectPtrLabel; +PFNGLCLEARBUFFERDATAPROC gl3wClearBufferData; +PFNGLCLEARBUFFERSUBDATAPROC gl3wClearBufferSubData; +PFNGLCLEARNAMEDBUFFERDATAEXTPROC gl3wClearNamedBufferDataEXT; +PFNGLCLEARNAMEDBUFFERSUBDATAEXTPROC gl3wClearNamedBufferSubDataEXT; +PFNGLDISPATCHCOMPUTEPROC gl3wDispatchCompute; +PFNGLDISPATCHCOMPUTEINDIRECTPROC gl3wDispatchComputeIndirect; +PFNGLCOPYIMAGESUBDATAPROC gl3wCopyImageSubData; +PFNGLTEXTUREVIEWPROC gl3wTextureView; +PFNGLBINDVERTEXBUFFERPROC gl3wBindVertexBuffer; +PFNGLVERTEXATTRIBFORMATPROC gl3wVertexAttribFormat; +PFNGLVERTEXATTRIBIFORMATPROC gl3wVertexAttribIFormat; +PFNGLVERTEXATTRIBLFORMATPROC gl3wVertexAttribLFormat; +PFNGLVERTEXATTRIBBINDINGPROC gl3wVertexAttribBinding; +PFNGLVERTEXBINDINGDIVISORPROC gl3wVertexBindingDivisor; +PFNGLVERTEXARRAYBINDVERTEXBUFFEREXTPROC gl3wVertexArrayBindVertexBufferEXT; +PFNGLVERTEXARRAYVERTEXATTRIBFORMATEXTPROC gl3wVertexArrayVertexAttribFormatEXT; +PFNGLVERTEXARRAYVERTEXATTRIBIFORMATEXTPROC gl3wVertexArrayVertexAttribIFormatEXT; +PFNGLVERTEXARRAYVERTEXATTRIBLFORMATEXTPROC gl3wVertexArrayVertexAttribLFormatEXT; +PFNGLVERTEXARRAYVERTEXATTRIBBINDINGEXTPROC gl3wVertexArrayVertexAttribBindingEXT; +PFNGLVERTEXARRAYVERTEXBINDINGDIVISOREXTPROC gl3wVertexArrayVertexBindingDivisorEXT; +PFNGLFRAMEBUFFERPARAMETERIPROC gl3wFramebufferParameteri; 
+PFNGLGETFRAMEBUFFERPARAMETERIVPROC gl3wGetFramebufferParameteriv; +PFNGLNAMEDFRAMEBUFFERPARAMETERIEXTPROC gl3wNamedFramebufferParameteriEXT; +PFNGLGETNAMEDFRAMEBUFFERPARAMETERIVEXTPROC gl3wGetNamedFramebufferParameterivEXT; +PFNGLGETINTERNALFORMATI64VPROC gl3wGetInternalformati64v; +PFNGLINVALIDATETEXSUBIMAGEPROC gl3wInvalidateTexSubImage; +PFNGLINVALIDATETEXIMAGEPROC gl3wInvalidateTexImage; +PFNGLINVALIDATEBUFFERSUBDATAPROC gl3wInvalidateBufferSubData; +PFNGLINVALIDATEBUFFERDATAPROC gl3wInvalidateBufferData; +PFNGLINVALIDATEFRAMEBUFFERPROC gl3wInvalidateFramebuffer; +PFNGLINVALIDATESUBFRAMEBUFFERPROC gl3wInvalidateSubFramebuffer; +PFNGLMULTIDRAWARRAYSINDIRECTPROC gl3wMultiDrawArraysIndirect; +PFNGLMULTIDRAWELEMENTSINDIRECTPROC gl3wMultiDrawElementsIndirect; +PFNGLGETPROGRAMINTERFACEIVPROC gl3wGetProgramInterfaceiv; +PFNGLGETPROGRAMRESOURCEINDEXPROC gl3wGetProgramResourceIndex; +PFNGLGETPROGRAMRESOURCENAMEPROC gl3wGetProgramResourceName; +PFNGLGETPROGRAMRESOURCEIVPROC gl3wGetProgramResourceiv; +PFNGLGETPROGRAMRESOURCELOCATIONPROC gl3wGetProgramResourceLocation; +PFNGLGETPROGRAMRESOURCELOCATIONINDEXPROC gl3wGetProgramResourceLocationIndex; +PFNGLSHADERSTORAGEBLOCKBINDINGPROC gl3wShaderStorageBlockBinding; +PFNGLTEXBUFFERRANGEPROC gl3wTexBufferRange; +PFNGLTEXTUREBUFFERRANGEEXTPROC gl3wTextureBufferRangeEXT; +PFNGLTEXSTORAGE2DMULTISAMPLEPROC gl3wTexStorage2DMultisample; +PFNGLTEXSTORAGE3DMULTISAMPLEPROC gl3wTexStorage3DMultisample; +PFNGLTEXTURESTORAGE2DMULTISAMPLEEXTPROC gl3wTextureStorage2DMultisampleEXT; +PFNGLTEXTURESTORAGE3DMULTISAMPLEEXTPROC gl3wTextureStorage3DMultisampleEXT; + +static void load_procs(void) +{ + gl3wCullFace = (PFNGLCULLFACEPROC) get_proc("glCullFace"); + gl3wFrontFace = (PFNGLFRONTFACEPROC) get_proc("glFrontFace"); + gl3wHint = (PFNGLHINTPROC) get_proc("glHint"); + gl3wLineWidth = (PFNGLLINEWIDTHPROC) get_proc("glLineWidth"); + gl3wPointSize = (PFNGLPOINTSIZEPROC) get_proc("glPointSize"); + gl3wPolygonMode = (PFNGLPOLYGONMODEPROC) get_proc("glPolygonMode"); + gl3wScissor = (PFNGLSCISSORPROC) get_proc("glScissor"); + gl3wTexParameterf = (PFNGLTEXPARAMETERFPROC) get_proc("glTexParameterf"); + gl3wTexParameterfv = (PFNGLTEXPARAMETERFVPROC) get_proc("glTexParameterfv"); + gl3wTexParameteri = (PFNGLTEXPARAMETERIPROC) get_proc("glTexParameteri"); + gl3wTexParameteriv = (PFNGLTEXPARAMETERIVPROC) get_proc("glTexParameteriv"); + gl3wTexImage1D = (PFNGLTEXIMAGE1DPROC) get_proc("glTexImage1D"); + gl3wTexImage2D = (PFNGLTEXIMAGE2DPROC) get_proc("glTexImage2D"); + gl3wDrawBuffer = (PFNGLDRAWBUFFERPROC) get_proc("glDrawBuffer"); + gl3wClear = (PFNGLCLEARPROC) get_proc("glClear"); + gl3wClearColor = (PFNGLCLEARCOLORPROC) get_proc("glClearColor"); + gl3wClearStencil = (PFNGLCLEARSTENCILPROC) get_proc("glClearStencil"); + gl3wClearDepth = (PFNGLCLEARDEPTHPROC) get_proc("glClearDepth"); + gl3wStencilMask = (PFNGLSTENCILMASKPROC) get_proc("glStencilMask"); + gl3wColorMask = (PFNGLCOLORMASKPROC) get_proc("glColorMask"); + gl3wDepthMask = (PFNGLDEPTHMASKPROC) get_proc("glDepthMask"); + gl3wDisable = (PFNGLDISABLEPROC) get_proc("glDisable"); + gl3wEnable = (PFNGLENABLEPROC) get_proc("glEnable"); + gl3wFinish = (PFNGLFINISHPROC) get_proc("glFinish"); + gl3wFlush = (PFNGLFLUSHPROC) get_proc("glFlush"); + gl3wBlendFunc = (PFNGLBLENDFUNCPROC) get_proc("glBlendFunc"); + gl3wLogicOp = (PFNGLLOGICOPPROC) get_proc("glLogicOp"); + gl3wStencilFunc = (PFNGLSTENCILFUNCPROC) get_proc("glStencilFunc"); + gl3wStencilOp = (PFNGLSTENCILOPPROC) get_proc("glStencilOp"); + gl3wDepthFunc = 
(PFNGLDEPTHFUNCPROC) get_proc("glDepthFunc"); + gl3wPixelStoref = (PFNGLPIXELSTOREFPROC) get_proc("glPixelStoref"); + gl3wPixelStorei = (PFNGLPIXELSTOREIPROC) get_proc("glPixelStorei"); + gl3wReadBuffer = (PFNGLREADBUFFERPROC) get_proc("glReadBuffer"); + gl3wReadPixels = (PFNGLREADPIXELSPROC) get_proc("glReadPixels"); + gl3wGetBooleanv = (PFNGLGETBOOLEANVPROC) get_proc("glGetBooleanv"); + gl3wGetDoublev = (PFNGLGETDOUBLEVPROC) get_proc("glGetDoublev"); + gl3wGetError = (PFNGLGETERRORPROC) get_proc("glGetError"); + gl3wGetFloatv = (PFNGLGETFLOATVPROC) get_proc("glGetFloatv"); + gl3wGetIntegerv = (PFNGLGETINTEGERVPROC) get_proc("glGetIntegerv"); + gl3wGetString = (PFNGLGETSTRINGPROC) get_proc("glGetString"); + gl3wGetTexImage = (PFNGLGETTEXIMAGEPROC) get_proc("glGetTexImage"); + gl3wGetTexParameterfv = (PFNGLGETTEXPARAMETERFVPROC) get_proc("glGetTexParameterfv"); + gl3wGetTexParameteriv = (PFNGLGETTEXPARAMETERIVPROC) get_proc("glGetTexParameteriv"); + gl3wGetTexLevelParameterfv = (PFNGLGETTEXLEVELPARAMETERFVPROC) get_proc("glGetTexLevelParameterfv"); + gl3wGetTexLevelParameteriv = (PFNGLGETTEXLEVELPARAMETERIVPROC) get_proc("glGetTexLevelParameteriv"); + gl3wIsEnabled = (PFNGLISENABLEDPROC) get_proc("glIsEnabled"); + gl3wDepthRange = (PFNGLDEPTHRANGEPROC) get_proc("glDepthRange"); + gl3wViewport = (PFNGLVIEWPORTPROC) get_proc("glViewport"); + gl3wDrawArrays = (PFNGLDRAWARRAYSPROC) get_proc("glDrawArrays"); + gl3wDrawElements = (PFNGLDRAWELEMENTSPROC) get_proc("glDrawElements"); + gl3wGetPointerv = (PFNGLGETPOINTERVPROC) get_proc("glGetPointerv"); + gl3wPolygonOffset = (PFNGLPOLYGONOFFSETPROC) get_proc("glPolygonOffset"); + gl3wCopyTexImage1D = (PFNGLCOPYTEXIMAGE1DPROC) get_proc("glCopyTexImage1D"); + gl3wCopyTexImage2D = (PFNGLCOPYTEXIMAGE2DPROC) get_proc("glCopyTexImage2D"); + gl3wCopyTexSubImage1D = (PFNGLCOPYTEXSUBIMAGE1DPROC) get_proc("glCopyTexSubImage1D"); + gl3wCopyTexSubImage2D = (PFNGLCOPYTEXSUBIMAGE2DPROC) get_proc("glCopyTexSubImage2D"); + gl3wTexSubImage1D = (PFNGLTEXSUBIMAGE1DPROC) get_proc("glTexSubImage1D"); + gl3wTexSubImage2D = (PFNGLTEXSUBIMAGE2DPROC) get_proc("glTexSubImage2D"); + gl3wBindTexture = (PFNGLBINDTEXTUREPROC) get_proc("glBindTexture"); + gl3wDeleteTextures = (PFNGLDELETETEXTURESPROC) get_proc("glDeleteTextures"); + gl3wGenTextures = (PFNGLGENTEXTURESPROC) get_proc("glGenTextures"); + gl3wIsTexture = (PFNGLISTEXTUREPROC) get_proc("glIsTexture"); + gl3wBlendColor = (PFNGLBLENDCOLORPROC) get_proc("glBlendColor"); + gl3wBlendEquation = (PFNGLBLENDEQUATIONPROC) get_proc("glBlendEquation"); + gl3wDrawRangeElements = (PFNGLDRAWRANGEELEMENTSPROC) get_proc("glDrawRangeElements"); + gl3wTexImage3D = (PFNGLTEXIMAGE3DPROC) get_proc("glTexImage3D"); + gl3wTexSubImage3D = (PFNGLTEXSUBIMAGE3DPROC) get_proc("glTexSubImage3D"); + gl3wCopyTexSubImage3D = (PFNGLCOPYTEXSUBIMAGE3DPROC) get_proc("glCopyTexSubImage3D"); + gl3wActiveTexture = (PFNGLACTIVETEXTUREPROC) get_proc("glActiveTexture"); + gl3wSampleCoverage = (PFNGLSAMPLECOVERAGEPROC) get_proc("glSampleCoverage"); + gl3wCompressedTexImage3D = (PFNGLCOMPRESSEDTEXIMAGE3DPROC) get_proc("glCompressedTexImage3D"); + gl3wCompressedTexImage2D = (PFNGLCOMPRESSEDTEXIMAGE2DPROC) get_proc("glCompressedTexImage2D"); + gl3wCompressedTexImage1D = (PFNGLCOMPRESSEDTEXIMAGE1DPROC) get_proc("glCompressedTexImage1D"); + gl3wCompressedTexSubImage3D = (PFNGLCOMPRESSEDTEXSUBIMAGE3DPROC) get_proc("glCompressedTexSubImage3D"); + gl3wCompressedTexSubImage2D = (PFNGLCOMPRESSEDTEXSUBIMAGE2DPROC) get_proc("glCompressedTexSubImage2D"); + 
gl3wCompressedTexSubImage1D = (PFNGLCOMPRESSEDTEXSUBIMAGE1DPROC) get_proc("glCompressedTexSubImage1D"); + gl3wGetCompressedTexImage = (PFNGLGETCOMPRESSEDTEXIMAGEPROC) get_proc("glGetCompressedTexImage"); + gl3wBlendFuncSeparate = (PFNGLBLENDFUNCSEPARATEPROC) get_proc("glBlendFuncSeparate"); + gl3wMultiDrawArrays = (PFNGLMULTIDRAWARRAYSPROC) get_proc("glMultiDrawArrays"); + gl3wMultiDrawElements = (PFNGLMULTIDRAWELEMENTSPROC) get_proc("glMultiDrawElements"); + gl3wPointParameterf = (PFNGLPOINTPARAMETERFPROC) get_proc("glPointParameterf"); + gl3wPointParameterfv = (PFNGLPOINTPARAMETERFVPROC) get_proc("glPointParameterfv"); + gl3wPointParameteri = (PFNGLPOINTPARAMETERIPROC) get_proc("glPointParameteri"); + gl3wPointParameteriv = (PFNGLPOINTPARAMETERIVPROC) get_proc("glPointParameteriv"); + gl3wGenQueries = (PFNGLGENQUERIESPROC) get_proc("glGenQueries"); + gl3wDeleteQueries = (PFNGLDELETEQUERIESPROC) get_proc("glDeleteQueries"); + gl3wIsQuery = (PFNGLISQUERYPROC) get_proc("glIsQuery"); + gl3wBeginQuery = (PFNGLBEGINQUERYPROC) get_proc("glBeginQuery"); + gl3wEndQuery = (PFNGLENDQUERYPROC) get_proc("glEndQuery"); + gl3wGetQueryiv = (PFNGLGETQUERYIVPROC) get_proc("glGetQueryiv"); + gl3wGetQueryObjectiv = (PFNGLGETQUERYOBJECTIVPROC) get_proc("glGetQueryObjectiv"); + gl3wGetQueryObjectuiv = (PFNGLGETQUERYOBJECTUIVPROC) get_proc("glGetQueryObjectuiv"); + gl3wBindBuffer = (PFNGLBINDBUFFERPROC) get_proc("glBindBuffer"); + gl3wDeleteBuffers = (PFNGLDELETEBUFFERSPROC) get_proc("glDeleteBuffers"); + gl3wGenBuffers = (PFNGLGENBUFFERSPROC) get_proc("glGenBuffers"); + gl3wIsBuffer = (PFNGLISBUFFERPROC) get_proc("glIsBuffer"); + gl3wBufferData = (PFNGLBUFFERDATAPROC) get_proc("glBufferData"); + gl3wBufferSubData = (PFNGLBUFFERSUBDATAPROC) get_proc("glBufferSubData"); + gl3wGetBufferSubData = (PFNGLGETBUFFERSUBDATAPROC) get_proc("glGetBufferSubData"); + gl3wMapBuffer = (PFNGLMAPBUFFERPROC) get_proc("glMapBuffer"); + gl3wUnmapBuffer = (PFNGLUNMAPBUFFERPROC) get_proc("glUnmapBuffer"); + gl3wGetBufferParameteriv = (PFNGLGETBUFFERPARAMETERIVPROC) get_proc("glGetBufferParameteriv"); + gl3wGetBufferPointerv = (PFNGLGETBUFFERPOINTERVPROC) get_proc("glGetBufferPointerv"); + gl3wBlendEquationSeparate = (PFNGLBLENDEQUATIONSEPARATEPROC) get_proc("glBlendEquationSeparate"); + gl3wDrawBuffers = (PFNGLDRAWBUFFERSPROC) get_proc("glDrawBuffers"); + gl3wStencilOpSeparate = (PFNGLSTENCILOPSEPARATEPROC) get_proc("glStencilOpSeparate"); + gl3wStencilFuncSeparate = (PFNGLSTENCILFUNCSEPARATEPROC) get_proc("glStencilFuncSeparate"); + gl3wStencilMaskSeparate = (PFNGLSTENCILMASKSEPARATEPROC) get_proc("glStencilMaskSeparate"); + gl3wAttachShader = (PFNGLATTACHSHADERPROC) get_proc("glAttachShader"); + gl3wBindAttribLocation = (PFNGLBINDATTRIBLOCATIONPROC) get_proc("glBindAttribLocation"); + gl3wCompileShader = (PFNGLCOMPILESHADERPROC) get_proc("glCompileShader"); + gl3wCreateProgram = (PFNGLCREATEPROGRAMPROC) get_proc("glCreateProgram"); + gl3wCreateShader = (PFNGLCREATESHADERPROC) get_proc("glCreateShader"); + gl3wDeleteProgram = (PFNGLDELETEPROGRAMPROC) get_proc("glDeleteProgram"); + gl3wDeleteShader = (PFNGLDELETESHADERPROC) get_proc("glDeleteShader"); + gl3wDetachShader = (PFNGLDETACHSHADERPROC) get_proc("glDetachShader"); + gl3wDisableVertexAttribArray = (PFNGLDISABLEVERTEXATTRIBARRAYPROC) get_proc("glDisableVertexAttribArray"); + gl3wEnableVertexAttribArray = (PFNGLENABLEVERTEXATTRIBARRAYPROC) get_proc("glEnableVertexAttribArray"); + gl3wGetActiveAttrib = (PFNGLGETACTIVEATTRIBPROC) get_proc("glGetActiveAttrib"); + 
gl3wGetActiveUniform = (PFNGLGETACTIVEUNIFORMPROC) get_proc("glGetActiveUniform"); + gl3wGetAttachedShaders = (PFNGLGETATTACHEDSHADERSPROC) get_proc("glGetAttachedShaders"); + gl3wGetAttribLocation = (PFNGLGETATTRIBLOCATIONPROC) get_proc("glGetAttribLocation"); + gl3wGetProgramiv = (PFNGLGETPROGRAMIVPROC) get_proc("glGetProgramiv"); + gl3wGetProgramInfoLog = (PFNGLGETPROGRAMINFOLOGPROC) get_proc("glGetProgramInfoLog"); + gl3wGetShaderiv = (PFNGLGETSHADERIVPROC) get_proc("glGetShaderiv"); + gl3wGetShaderInfoLog = (PFNGLGETSHADERINFOLOGPROC) get_proc("glGetShaderInfoLog"); + gl3wGetShaderSource = (PFNGLGETSHADERSOURCEPROC) get_proc("glGetShaderSource"); + gl3wGetUniformLocation = (PFNGLGETUNIFORMLOCATIONPROC) get_proc("glGetUniformLocation"); + gl3wGetUniformfv = (PFNGLGETUNIFORMFVPROC) get_proc("glGetUniformfv"); + gl3wGetUniformiv = (PFNGLGETUNIFORMIVPROC) get_proc("glGetUniformiv"); + gl3wGetVertexAttribdv = (PFNGLGETVERTEXATTRIBDVPROC) get_proc("glGetVertexAttribdv"); + gl3wGetVertexAttribfv = (PFNGLGETVERTEXATTRIBFVPROC) get_proc("glGetVertexAttribfv"); + gl3wGetVertexAttribiv = (PFNGLGETVERTEXATTRIBIVPROC) get_proc("glGetVertexAttribiv"); + gl3wGetVertexAttribPointerv = (PFNGLGETVERTEXATTRIBPOINTERVPROC) get_proc("glGetVertexAttribPointerv"); + gl3wIsProgram = (PFNGLISPROGRAMPROC) get_proc("glIsProgram"); + gl3wIsShader = (PFNGLISSHADERPROC) get_proc("glIsShader"); + gl3wLinkProgram = (PFNGLLINKPROGRAMPROC) get_proc("glLinkProgram"); + gl3wShaderSource = (PFNGLSHADERSOURCEPROC) get_proc("glShaderSource"); + gl3wUseProgram = (PFNGLUSEPROGRAMPROC) get_proc("glUseProgram"); + gl3wUniform1f = (PFNGLUNIFORM1FPROC) get_proc("glUniform1f"); + gl3wUniform2f = (PFNGLUNIFORM2FPROC) get_proc("glUniform2f"); + gl3wUniform3f = (PFNGLUNIFORM3FPROC) get_proc("glUniform3f"); + gl3wUniform4f = (PFNGLUNIFORM4FPROC) get_proc("glUniform4f"); + gl3wUniform1i = (PFNGLUNIFORM1IPROC) get_proc("glUniform1i"); + gl3wUniform2i = (PFNGLUNIFORM2IPROC) get_proc("glUniform2i"); + gl3wUniform3i = (PFNGLUNIFORM3IPROC) get_proc("glUniform3i"); + gl3wUniform4i = (PFNGLUNIFORM4IPROC) get_proc("glUniform4i"); + gl3wUniform1fv = (PFNGLUNIFORM1FVPROC) get_proc("glUniform1fv"); + gl3wUniform2fv = (PFNGLUNIFORM2FVPROC) get_proc("glUniform2fv"); + gl3wUniform3fv = (PFNGLUNIFORM3FVPROC) get_proc("glUniform3fv"); + gl3wUniform4fv = (PFNGLUNIFORM4FVPROC) get_proc("glUniform4fv"); + gl3wUniform1iv = (PFNGLUNIFORM1IVPROC) get_proc("glUniform1iv"); + gl3wUniform2iv = (PFNGLUNIFORM2IVPROC) get_proc("glUniform2iv"); + gl3wUniform3iv = (PFNGLUNIFORM3IVPROC) get_proc("glUniform3iv"); + gl3wUniform4iv = (PFNGLUNIFORM4IVPROC) get_proc("glUniform4iv"); + gl3wUniformMatrix2fv = (PFNGLUNIFORMMATRIX2FVPROC) get_proc("glUniformMatrix2fv"); + gl3wUniformMatrix3fv = (PFNGLUNIFORMMATRIX3FVPROC) get_proc("glUniformMatrix3fv"); + gl3wUniformMatrix4fv = (PFNGLUNIFORMMATRIX4FVPROC) get_proc("glUniformMatrix4fv"); + gl3wValidateProgram = (PFNGLVALIDATEPROGRAMPROC) get_proc("glValidateProgram"); + gl3wVertexAttrib1d = (PFNGLVERTEXATTRIB1DPROC) get_proc("glVertexAttrib1d"); + gl3wVertexAttrib1dv = (PFNGLVERTEXATTRIB1DVPROC) get_proc("glVertexAttrib1dv"); + gl3wVertexAttrib1f = (PFNGLVERTEXATTRIB1FPROC) get_proc("glVertexAttrib1f"); + gl3wVertexAttrib1fv = (PFNGLVERTEXATTRIB1FVPROC) get_proc("glVertexAttrib1fv"); + gl3wVertexAttrib1s = (PFNGLVERTEXATTRIB1SPROC) get_proc("glVertexAttrib1s"); + gl3wVertexAttrib1sv = (PFNGLVERTEXATTRIB1SVPROC) get_proc("glVertexAttrib1sv"); + gl3wVertexAttrib2d = (PFNGLVERTEXATTRIB2DPROC) get_proc("glVertexAttrib2d"); + 
gl3wVertexAttrib2dv = (PFNGLVERTEXATTRIB2DVPROC) get_proc("glVertexAttrib2dv"); + gl3wVertexAttrib2f = (PFNGLVERTEXATTRIB2FPROC) get_proc("glVertexAttrib2f"); + gl3wVertexAttrib2fv = (PFNGLVERTEXATTRIB2FVPROC) get_proc("glVertexAttrib2fv"); + gl3wVertexAttrib2s = (PFNGLVERTEXATTRIB2SPROC) get_proc("glVertexAttrib2s"); + gl3wVertexAttrib2sv = (PFNGLVERTEXATTRIB2SVPROC) get_proc("glVertexAttrib2sv"); + gl3wVertexAttrib3d = (PFNGLVERTEXATTRIB3DPROC) get_proc("glVertexAttrib3d"); + gl3wVertexAttrib3dv = (PFNGLVERTEXATTRIB3DVPROC) get_proc("glVertexAttrib3dv"); + gl3wVertexAttrib3f = (PFNGLVERTEXATTRIB3FPROC) get_proc("glVertexAttrib3f"); + gl3wVertexAttrib3fv = (PFNGLVERTEXATTRIB3FVPROC) get_proc("glVertexAttrib3fv"); + gl3wVertexAttrib3s = (PFNGLVERTEXATTRIB3SPROC) get_proc("glVertexAttrib3s"); + gl3wVertexAttrib3sv = (PFNGLVERTEXATTRIB3SVPROC) get_proc("glVertexAttrib3sv"); + gl3wVertexAttrib4Nbv = (PFNGLVERTEXATTRIB4NBVPROC) get_proc("glVertexAttrib4Nbv"); + gl3wVertexAttrib4Niv = (PFNGLVERTEXATTRIB4NIVPROC) get_proc("glVertexAttrib4Niv"); + gl3wVertexAttrib4Nsv = (PFNGLVERTEXATTRIB4NSVPROC) get_proc("glVertexAttrib4Nsv"); + gl3wVertexAttrib4Nub = (PFNGLVERTEXATTRIB4NUBPROC) get_proc("glVertexAttrib4Nub"); + gl3wVertexAttrib4Nubv = (PFNGLVERTEXATTRIB4NUBVPROC) get_proc("glVertexAttrib4Nubv"); + gl3wVertexAttrib4Nuiv = (PFNGLVERTEXATTRIB4NUIVPROC) get_proc("glVertexAttrib4Nuiv"); + gl3wVertexAttrib4Nusv = (PFNGLVERTEXATTRIB4NUSVPROC) get_proc("glVertexAttrib4Nusv"); + gl3wVertexAttrib4bv = (PFNGLVERTEXATTRIB4BVPROC) get_proc("glVertexAttrib4bv"); + gl3wVertexAttrib4d = (PFNGLVERTEXATTRIB4DPROC) get_proc("glVertexAttrib4d"); + gl3wVertexAttrib4dv = (PFNGLVERTEXATTRIB4DVPROC) get_proc("glVertexAttrib4dv"); + gl3wVertexAttrib4f = (PFNGLVERTEXATTRIB4FPROC) get_proc("glVertexAttrib4f"); + gl3wVertexAttrib4fv = (PFNGLVERTEXATTRIB4FVPROC) get_proc("glVertexAttrib4fv"); + gl3wVertexAttrib4iv = (PFNGLVERTEXATTRIB4IVPROC) get_proc("glVertexAttrib4iv"); + gl3wVertexAttrib4s = (PFNGLVERTEXATTRIB4SPROC) get_proc("glVertexAttrib4s"); + gl3wVertexAttrib4sv = (PFNGLVERTEXATTRIB4SVPROC) get_proc("glVertexAttrib4sv"); + gl3wVertexAttrib4ubv = (PFNGLVERTEXATTRIB4UBVPROC) get_proc("glVertexAttrib4ubv"); + gl3wVertexAttrib4uiv = (PFNGLVERTEXATTRIB4UIVPROC) get_proc("glVertexAttrib4uiv"); + gl3wVertexAttrib4usv = (PFNGLVERTEXATTRIB4USVPROC) get_proc("glVertexAttrib4usv"); + gl3wVertexAttribPointer = (PFNGLVERTEXATTRIBPOINTERPROC) get_proc("glVertexAttribPointer"); + gl3wUniformMatrix2x3fv = (PFNGLUNIFORMMATRIX2X3FVPROC) get_proc("glUniformMatrix2x3fv"); + gl3wUniformMatrix3x2fv = (PFNGLUNIFORMMATRIX3X2FVPROC) get_proc("glUniformMatrix3x2fv"); + gl3wUniformMatrix2x4fv = (PFNGLUNIFORMMATRIX2X4FVPROC) get_proc("glUniformMatrix2x4fv"); + gl3wUniformMatrix4x2fv = (PFNGLUNIFORMMATRIX4X2FVPROC) get_proc("glUniformMatrix4x2fv"); + gl3wUniformMatrix3x4fv = (PFNGLUNIFORMMATRIX3X4FVPROC) get_proc("glUniformMatrix3x4fv"); + gl3wUniformMatrix4x3fv = (PFNGLUNIFORMMATRIX4X3FVPROC) get_proc("glUniformMatrix4x3fv"); + gl3wColorMaski = (PFNGLCOLORMASKIPROC) get_proc("glColorMaski"); + gl3wGetBooleani_v = (PFNGLGETBOOLEANI_VPROC) get_proc("glGetBooleani_v"); + gl3wGetIntegeri_v = (PFNGLGETINTEGERI_VPROC) get_proc("glGetIntegeri_v"); + gl3wEnablei = (PFNGLENABLEIPROC) get_proc("glEnablei"); + gl3wDisablei = (PFNGLDISABLEIPROC) get_proc("glDisablei"); + gl3wIsEnabledi = (PFNGLISENABLEDIPROC) get_proc("glIsEnabledi"); + gl3wBeginTransformFeedback = (PFNGLBEGINTRANSFORMFEEDBACKPROC) get_proc("glBeginTransformFeedback"); + 
gl3wEndTransformFeedback = (PFNGLENDTRANSFORMFEEDBACKPROC) get_proc("glEndTransformFeedback"); + gl3wBindBufferRange = (PFNGLBINDBUFFERRANGEPROC) get_proc("glBindBufferRange"); + gl3wBindBufferBase = (PFNGLBINDBUFFERBASEPROC) get_proc("glBindBufferBase"); + gl3wTransformFeedbackVaryings = (PFNGLTRANSFORMFEEDBACKVARYINGSPROC) get_proc("glTransformFeedbackVaryings"); + gl3wGetTransformFeedbackVarying = (PFNGLGETTRANSFORMFEEDBACKVARYINGPROC) get_proc("glGetTransformFeedbackVarying"); + gl3wClampColor = (PFNGLCLAMPCOLORPROC) get_proc("glClampColor"); + gl3wBeginConditionalRender = (PFNGLBEGINCONDITIONALRENDERPROC) get_proc("glBeginConditionalRender"); + gl3wEndConditionalRender = (PFNGLENDCONDITIONALRENDERPROC) get_proc("glEndConditionalRender"); + gl3wVertexAttribIPointer = (PFNGLVERTEXATTRIBIPOINTERPROC) get_proc("glVertexAttribIPointer"); + gl3wGetVertexAttribIiv = (PFNGLGETVERTEXATTRIBIIVPROC) get_proc("glGetVertexAttribIiv"); + gl3wGetVertexAttribIuiv = (PFNGLGETVERTEXATTRIBIUIVPROC) get_proc("glGetVertexAttribIuiv"); + gl3wVertexAttribI1i = (PFNGLVERTEXATTRIBI1IPROC) get_proc("glVertexAttribI1i"); + gl3wVertexAttribI2i = (PFNGLVERTEXATTRIBI2IPROC) get_proc("glVertexAttribI2i"); + gl3wVertexAttribI3i = (PFNGLVERTEXATTRIBI3IPROC) get_proc("glVertexAttribI3i"); + gl3wVertexAttribI4i = (PFNGLVERTEXATTRIBI4IPROC) get_proc("glVertexAttribI4i"); + gl3wVertexAttribI1ui = (PFNGLVERTEXATTRIBI1UIPROC) get_proc("glVertexAttribI1ui"); + gl3wVertexAttribI2ui = (PFNGLVERTEXATTRIBI2UIPROC) get_proc("glVertexAttribI2ui"); + gl3wVertexAttribI3ui = (PFNGLVERTEXATTRIBI3UIPROC) get_proc("glVertexAttribI3ui"); + gl3wVertexAttribI4ui = (PFNGLVERTEXATTRIBI4UIPROC) get_proc("glVertexAttribI4ui"); + gl3wVertexAttribI1iv = (PFNGLVERTEXATTRIBI1IVPROC) get_proc("glVertexAttribI1iv"); + gl3wVertexAttribI2iv = (PFNGLVERTEXATTRIBI2IVPROC) get_proc("glVertexAttribI2iv"); + gl3wVertexAttribI3iv = (PFNGLVERTEXATTRIBI3IVPROC) get_proc("glVertexAttribI3iv"); + gl3wVertexAttribI4iv = (PFNGLVERTEXATTRIBI4IVPROC) get_proc("glVertexAttribI4iv"); + gl3wVertexAttribI1uiv = (PFNGLVERTEXATTRIBI1UIVPROC) get_proc("glVertexAttribI1uiv"); + gl3wVertexAttribI2uiv = (PFNGLVERTEXATTRIBI2UIVPROC) get_proc("glVertexAttribI2uiv"); + gl3wVertexAttribI3uiv = (PFNGLVERTEXATTRIBI3UIVPROC) get_proc("glVertexAttribI3uiv"); + gl3wVertexAttribI4uiv = (PFNGLVERTEXATTRIBI4UIVPROC) get_proc("glVertexAttribI4uiv"); + gl3wVertexAttribI4bv = (PFNGLVERTEXATTRIBI4BVPROC) get_proc("glVertexAttribI4bv"); + gl3wVertexAttribI4sv = (PFNGLVERTEXATTRIBI4SVPROC) get_proc("glVertexAttribI4sv"); + gl3wVertexAttribI4ubv = (PFNGLVERTEXATTRIBI4UBVPROC) get_proc("glVertexAttribI4ubv"); + gl3wVertexAttribI4usv = (PFNGLVERTEXATTRIBI4USVPROC) get_proc("glVertexAttribI4usv"); + gl3wGetUniformuiv = (PFNGLGETUNIFORMUIVPROC) get_proc("glGetUniformuiv"); + gl3wBindFragDataLocation = (PFNGLBINDFRAGDATALOCATIONPROC) get_proc("glBindFragDataLocation"); + gl3wGetFragDataLocation = (PFNGLGETFRAGDATALOCATIONPROC) get_proc("glGetFragDataLocation"); + gl3wUniform1ui = (PFNGLUNIFORM1UIPROC) get_proc("glUniform1ui"); + gl3wUniform2ui = (PFNGLUNIFORM2UIPROC) get_proc("glUniform2ui"); + gl3wUniform3ui = (PFNGLUNIFORM3UIPROC) get_proc("glUniform3ui"); + gl3wUniform4ui = (PFNGLUNIFORM4UIPROC) get_proc("glUniform4ui"); + gl3wUniform1uiv = (PFNGLUNIFORM1UIVPROC) get_proc("glUniform1uiv"); + gl3wUniform2uiv = (PFNGLUNIFORM2UIVPROC) get_proc("glUniform2uiv"); + gl3wUniform3uiv = (PFNGLUNIFORM3UIVPROC) get_proc("glUniform3uiv"); + gl3wUniform4uiv = (PFNGLUNIFORM4UIVPROC) 
get_proc("glUniform4uiv"); + gl3wTexParameterIiv = (PFNGLTEXPARAMETERIIVPROC) get_proc("glTexParameterIiv"); + gl3wTexParameterIuiv = (PFNGLTEXPARAMETERIUIVPROC) get_proc("glTexParameterIuiv"); + gl3wGetTexParameterIiv = (PFNGLGETTEXPARAMETERIIVPROC) get_proc("glGetTexParameterIiv"); + gl3wGetTexParameterIuiv = (PFNGLGETTEXPARAMETERIUIVPROC) get_proc("glGetTexParameterIuiv"); + gl3wClearBufferiv = (PFNGLCLEARBUFFERIVPROC) get_proc("glClearBufferiv"); + gl3wClearBufferuiv = (PFNGLCLEARBUFFERUIVPROC) get_proc("glClearBufferuiv"); + gl3wClearBufferfv = (PFNGLCLEARBUFFERFVPROC) get_proc("glClearBufferfv"); + gl3wClearBufferfi = (PFNGLCLEARBUFFERFIPROC) get_proc("glClearBufferfi"); + gl3wGetStringi = (PFNGLGETSTRINGIPROC) get_proc("glGetStringi"); + gl3wDrawArraysInstanced = (PFNGLDRAWARRAYSINSTANCEDPROC) get_proc("glDrawArraysInstanced"); + gl3wDrawElementsInstanced = (PFNGLDRAWELEMENTSINSTANCEDPROC) get_proc("glDrawElementsInstanced"); + gl3wTexBuffer = (PFNGLTEXBUFFERPROC) get_proc("glTexBuffer"); + gl3wPrimitiveRestartIndex = (PFNGLPRIMITIVERESTARTINDEXPROC) get_proc("glPrimitiveRestartIndex"); + gl3wGetInteger64i_v = (PFNGLGETINTEGER64I_VPROC) get_proc("glGetInteger64i_v"); + gl3wGetBufferParameteri64v = (PFNGLGETBUFFERPARAMETERI64VPROC) get_proc("glGetBufferParameteri64v"); + gl3wFramebufferTexture = (PFNGLFRAMEBUFFERTEXTUREPROC) get_proc("glFramebufferTexture"); + gl3wVertexAttribDivisor = (PFNGLVERTEXATTRIBDIVISORPROC) get_proc("glVertexAttribDivisor"); + gl3wMinSampleShading = (PFNGLMINSAMPLESHADINGPROC) get_proc("glMinSampleShading"); + gl3wBlendEquationi = (PFNGLBLENDEQUATIONIPROC) get_proc("glBlendEquationi"); + gl3wBlendEquationSeparatei = (PFNGLBLENDEQUATIONSEPARATEIPROC) get_proc("glBlendEquationSeparatei"); + gl3wBlendFunci = (PFNGLBLENDFUNCIPROC) get_proc("glBlendFunci"); + gl3wBlendFuncSeparatei = (PFNGLBLENDFUNCSEPARATEIPROC) get_proc("glBlendFuncSeparatei"); + gl3wIsRenderbuffer = (PFNGLISRENDERBUFFERPROC) get_proc("glIsRenderbuffer"); + gl3wBindRenderbuffer = (PFNGLBINDRENDERBUFFERPROC) get_proc("glBindRenderbuffer"); + gl3wDeleteRenderbuffers = (PFNGLDELETERENDERBUFFERSPROC) get_proc("glDeleteRenderbuffers"); + gl3wGenRenderbuffers = (PFNGLGENRENDERBUFFERSPROC) get_proc("glGenRenderbuffers"); + gl3wRenderbufferStorage = (PFNGLRENDERBUFFERSTORAGEPROC) get_proc("glRenderbufferStorage"); + gl3wGetRenderbufferParameteriv = (PFNGLGETRENDERBUFFERPARAMETERIVPROC) get_proc("glGetRenderbufferParameteriv"); + gl3wIsFramebuffer = (PFNGLISFRAMEBUFFERPROC) get_proc("glIsFramebuffer"); + gl3wBindFramebuffer = (PFNGLBINDFRAMEBUFFERPROC) get_proc("glBindFramebuffer"); + gl3wDeleteFramebuffers = (PFNGLDELETEFRAMEBUFFERSPROC) get_proc("glDeleteFramebuffers"); + gl3wGenFramebuffers = (PFNGLGENFRAMEBUFFERSPROC) get_proc("glGenFramebuffers"); + gl3wCheckFramebufferStatus = (PFNGLCHECKFRAMEBUFFERSTATUSPROC) get_proc("glCheckFramebufferStatus"); + gl3wFramebufferTexture1D = (PFNGLFRAMEBUFFERTEXTURE1DPROC) get_proc("glFramebufferTexture1D"); + gl3wFramebufferTexture2D = (PFNGLFRAMEBUFFERTEXTURE2DPROC) get_proc("glFramebufferTexture2D"); + gl3wFramebufferTexture3D = (PFNGLFRAMEBUFFERTEXTURE3DPROC) get_proc("glFramebufferTexture3D"); + gl3wFramebufferRenderbuffer = (PFNGLFRAMEBUFFERRENDERBUFFERPROC) get_proc("glFramebufferRenderbuffer"); + gl3wGetFramebufferAttachmentParameteriv = (PFNGLGETFRAMEBUFFERATTACHMENTPARAMETERIVPROC) get_proc("glGetFramebufferAttachmentParameteriv"); + gl3wGenerateMipmap = (PFNGLGENERATEMIPMAPPROC) get_proc("glGenerateMipmap"); + gl3wBlitFramebuffer = 
(PFNGLBLITFRAMEBUFFERPROC) get_proc("glBlitFramebuffer"); + gl3wRenderbufferStorageMultisample = (PFNGLRENDERBUFFERSTORAGEMULTISAMPLEPROC) get_proc("glRenderbufferStorageMultisample"); + gl3wFramebufferTextureLayer = (PFNGLFRAMEBUFFERTEXTURELAYERPROC) get_proc("glFramebufferTextureLayer"); + gl3wMapBufferRange = (PFNGLMAPBUFFERRANGEPROC) get_proc("glMapBufferRange"); + gl3wFlushMappedBufferRange = (PFNGLFLUSHMAPPEDBUFFERRANGEPROC) get_proc("glFlushMappedBufferRange"); + gl3wBindVertexArray = (PFNGLBINDVERTEXARRAYPROC) get_proc("glBindVertexArray"); + gl3wDeleteVertexArrays = (PFNGLDELETEVERTEXARRAYSPROC) get_proc("glDeleteVertexArrays"); + gl3wGenVertexArrays = (PFNGLGENVERTEXARRAYSPROC) get_proc("glGenVertexArrays"); + gl3wIsVertexArray = (PFNGLISVERTEXARRAYPROC) get_proc("glIsVertexArray"); + gl3wGetUniformIndices = (PFNGLGETUNIFORMINDICESPROC) get_proc("glGetUniformIndices"); + gl3wGetActiveUniformsiv = (PFNGLGETACTIVEUNIFORMSIVPROC) get_proc("glGetActiveUniformsiv"); + gl3wGetActiveUniformName = (PFNGLGETACTIVEUNIFORMNAMEPROC) get_proc("glGetActiveUniformName"); + gl3wGetUniformBlockIndex = (PFNGLGETUNIFORMBLOCKINDEXPROC) get_proc("glGetUniformBlockIndex"); + gl3wGetActiveUniformBlockiv = (PFNGLGETACTIVEUNIFORMBLOCKIVPROC) get_proc("glGetActiveUniformBlockiv"); + gl3wGetActiveUniformBlockName = (PFNGLGETACTIVEUNIFORMBLOCKNAMEPROC) get_proc("glGetActiveUniformBlockName"); + gl3wUniformBlockBinding = (PFNGLUNIFORMBLOCKBINDINGPROC) get_proc("glUniformBlockBinding"); + gl3wCopyBufferSubData = (PFNGLCOPYBUFFERSUBDATAPROC) get_proc("glCopyBufferSubData"); + gl3wDrawElementsBaseVertex = (PFNGLDRAWELEMENTSBASEVERTEXPROC) get_proc("glDrawElementsBaseVertex"); + gl3wDrawRangeElementsBaseVertex = (PFNGLDRAWRANGEELEMENTSBASEVERTEXPROC) get_proc("glDrawRangeElementsBaseVertex"); + gl3wDrawElementsInstancedBaseVertex = (PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXPROC) get_proc("glDrawElementsInstancedBaseVertex"); + gl3wMultiDrawElementsBaseVertex = (PFNGLMULTIDRAWELEMENTSBASEVERTEXPROC) get_proc("glMultiDrawElementsBaseVertex"); + gl3wProvokingVertex = (PFNGLPROVOKINGVERTEXPROC) get_proc("glProvokingVertex"); + gl3wFenceSync = (PFNGLFENCESYNCPROC) get_proc("glFenceSync"); + gl3wIsSync = (PFNGLISSYNCPROC) get_proc("glIsSync"); + gl3wDeleteSync = (PFNGLDELETESYNCPROC) get_proc("glDeleteSync"); + gl3wClientWaitSync = (PFNGLCLIENTWAITSYNCPROC) get_proc("glClientWaitSync"); + gl3wWaitSync = (PFNGLWAITSYNCPROC) get_proc("glWaitSync"); + gl3wGetInteger64v = (PFNGLGETINTEGER64VPROC) get_proc("glGetInteger64v"); + gl3wGetSynciv = (PFNGLGETSYNCIVPROC) get_proc("glGetSynciv"); + gl3wTexImage2DMultisample = (PFNGLTEXIMAGE2DMULTISAMPLEPROC) get_proc("glTexImage2DMultisample"); + gl3wTexImage3DMultisample = (PFNGLTEXIMAGE3DMULTISAMPLEPROC) get_proc("glTexImage3DMultisample"); + gl3wGetMultisamplefv = (PFNGLGETMULTISAMPLEFVPROC) get_proc("glGetMultisamplefv"); + gl3wSampleMaski = (PFNGLSAMPLEMASKIPROC) get_proc("glSampleMaski"); + gl3wBlendEquationiARB = (PFNGLBLENDEQUATIONIARBPROC) get_proc("glBlendEquationiARB"); + gl3wBlendEquationSeparateiARB = (PFNGLBLENDEQUATIONSEPARATEIARBPROC) get_proc("glBlendEquationSeparateiARB"); + gl3wBlendFunciARB = (PFNGLBLENDFUNCIARBPROC) get_proc("glBlendFunciARB"); + gl3wBlendFuncSeparateiARB = (PFNGLBLENDFUNCSEPARATEIARBPROC) get_proc("glBlendFuncSeparateiARB"); + gl3wMinSampleShadingARB = (PFNGLMINSAMPLESHADINGARBPROC) get_proc("glMinSampleShadingARB"); + gl3wNamedStringARB = (PFNGLNAMEDSTRINGARBPROC) get_proc("glNamedStringARB"); + gl3wDeleteNamedStringARB = 
(PFNGLDELETENAMEDSTRINGARBPROC) get_proc("glDeleteNamedStringARB"); + gl3wCompileShaderIncludeARB = (PFNGLCOMPILESHADERINCLUDEARBPROC) get_proc("glCompileShaderIncludeARB"); + gl3wIsNamedStringARB = (PFNGLISNAMEDSTRINGARBPROC) get_proc("glIsNamedStringARB"); + gl3wGetNamedStringARB = (PFNGLGETNAMEDSTRINGARBPROC) get_proc("glGetNamedStringARB"); + gl3wGetNamedStringivARB = (PFNGLGETNAMEDSTRINGIVARBPROC) get_proc("glGetNamedStringivARB"); + gl3wBindFragDataLocationIndexed = (PFNGLBINDFRAGDATALOCATIONINDEXEDPROC) get_proc("glBindFragDataLocationIndexed"); + gl3wGetFragDataIndex = (PFNGLGETFRAGDATAINDEXPROC) get_proc("glGetFragDataIndex"); + gl3wGenSamplers = (PFNGLGENSAMPLERSPROC) get_proc("glGenSamplers"); + gl3wDeleteSamplers = (PFNGLDELETESAMPLERSPROC) get_proc("glDeleteSamplers"); + gl3wIsSampler = (PFNGLISSAMPLERPROC) get_proc("glIsSampler"); + gl3wBindSampler = (PFNGLBINDSAMPLERPROC) get_proc("glBindSampler"); + gl3wSamplerParameteri = (PFNGLSAMPLERPARAMETERIPROC) get_proc("glSamplerParameteri"); + gl3wSamplerParameteriv = (PFNGLSAMPLERPARAMETERIVPROC) get_proc("glSamplerParameteriv"); + gl3wSamplerParameterf = (PFNGLSAMPLERPARAMETERFPROC) get_proc("glSamplerParameterf"); + gl3wSamplerParameterfv = (PFNGLSAMPLERPARAMETERFVPROC) get_proc("glSamplerParameterfv"); + gl3wSamplerParameterIiv = (PFNGLSAMPLERPARAMETERIIVPROC) get_proc("glSamplerParameterIiv"); + gl3wSamplerParameterIuiv = (PFNGLSAMPLERPARAMETERIUIVPROC) get_proc("glSamplerParameterIuiv"); + gl3wGetSamplerParameteriv = (PFNGLGETSAMPLERPARAMETERIVPROC) get_proc("glGetSamplerParameteriv"); + gl3wGetSamplerParameterIiv = (PFNGLGETSAMPLERPARAMETERIIVPROC) get_proc("glGetSamplerParameterIiv"); + gl3wGetSamplerParameterfv = (PFNGLGETSAMPLERPARAMETERFVPROC) get_proc("glGetSamplerParameterfv"); + gl3wGetSamplerParameterIuiv = (PFNGLGETSAMPLERPARAMETERIUIVPROC) get_proc("glGetSamplerParameterIuiv"); + gl3wQueryCounter = (PFNGLQUERYCOUNTERPROC) get_proc("glQueryCounter"); + gl3wGetQueryObjecti64v = (PFNGLGETQUERYOBJECTI64VPROC) get_proc("glGetQueryObjecti64v"); + gl3wGetQueryObjectui64v = (PFNGLGETQUERYOBJECTUI64VPROC) get_proc("glGetQueryObjectui64v"); + gl3wVertexP2ui = (PFNGLVERTEXP2UIPROC) get_proc("glVertexP2ui"); + gl3wVertexP2uiv = (PFNGLVERTEXP2UIVPROC) get_proc("glVertexP2uiv"); + gl3wVertexP3ui = (PFNGLVERTEXP3UIPROC) get_proc("glVertexP3ui"); + gl3wVertexP3uiv = (PFNGLVERTEXP3UIVPROC) get_proc("glVertexP3uiv"); + gl3wVertexP4ui = (PFNGLVERTEXP4UIPROC) get_proc("glVertexP4ui"); + gl3wVertexP4uiv = (PFNGLVERTEXP4UIVPROC) get_proc("glVertexP4uiv"); + gl3wTexCoordP1ui = (PFNGLTEXCOORDP1UIPROC) get_proc("glTexCoordP1ui"); + gl3wTexCoordP1uiv = (PFNGLTEXCOORDP1UIVPROC) get_proc("glTexCoordP1uiv"); + gl3wTexCoordP2ui = (PFNGLTEXCOORDP2UIPROC) get_proc("glTexCoordP2ui"); + gl3wTexCoordP2uiv = (PFNGLTEXCOORDP2UIVPROC) get_proc("glTexCoordP2uiv"); + gl3wTexCoordP3ui = (PFNGLTEXCOORDP3UIPROC) get_proc("glTexCoordP3ui"); + gl3wTexCoordP3uiv = (PFNGLTEXCOORDP3UIVPROC) get_proc("glTexCoordP3uiv"); + gl3wTexCoordP4ui = (PFNGLTEXCOORDP4UIPROC) get_proc("glTexCoordP4ui"); + gl3wTexCoordP4uiv = (PFNGLTEXCOORDP4UIVPROC) get_proc("glTexCoordP4uiv"); + gl3wMultiTexCoordP1ui = (PFNGLMULTITEXCOORDP1UIPROC) get_proc("glMultiTexCoordP1ui"); + gl3wMultiTexCoordP1uiv = (PFNGLMULTITEXCOORDP1UIVPROC) get_proc("glMultiTexCoordP1uiv"); + gl3wMultiTexCoordP2ui = (PFNGLMULTITEXCOORDP2UIPROC) get_proc("glMultiTexCoordP2ui"); + gl3wMultiTexCoordP2uiv = (PFNGLMULTITEXCOORDP2UIVPROC) get_proc("glMultiTexCoordP2uiv"); + gl3wMultiTexCoordP3ui = 
(PFNGLMULTITEXCOORDP3UIPROC) get_proc("glMultiTexCoordP3ui"); + gl3wMultiTexCoordP3uiv = (PFNGLMULTITEXCOORDP3UIVPROC) get_proc("glMultiTexCoordP3uiv"); + gl3wMultiTexCoordP4ui = (PFNGLMULTITEXCOORDP4UIPROC) get_proc("glMultiTexCoordP4ui"); + gl3wMultiTexCoordP4uiv = (PFNGLMULTITEXCOORDP4UIVPROC) get_proc("glMultiTexCoordP4uiv"); + gl3wNormalP3ui = (PFNGLNORMALP3UIPROC) get_proc("glNormalP3ui"); + gl3wNormalP3uiv = (PFNGLNORMALP3UIVPROC) get_proc("glNormalP3uiv"); + gl3wColorP3ui = (PFNGLCOLORP3UIPROC) get_proc("glColorP3ui"); + gl3wColorP3uiv = (PFNGLCOLORP3UIVPROC) get_proc("glColorP3uiv"); + gl3wColorP4ui = (PFNGLCOLORP4UIPROC) get_proc("glColorP4ui"); + gl3wColorP4uiv = (PFNGLCOLORP4UIVPROC) get_proc("glColorP4uiv"); + gl3wSecondaryColorP3ui = (PFNGLSECONDARYCOLORP3UIPROC) get_proc("glSecondaryColorP3ui"); + gl3wSecondaryColorP3uiv = (PFNGLSECONDARYCOLORP3UIVPROC) get_proc("glSecondaryColorP3uiv"); + gl3wVertexAttribP1ui = (PFNGLVERTEXATTRIBP1UIPROC) get_proc("glVertexAttribP1ui"); + gl3wVertexAttribP1uiv = (PFNGLVERTEXATTRIBP1UIVPROC) get_proc("glVertexAttribP1uiv"); + gl3wVertexAttribP2ui = (PFNGLVERTEXATTRIBP2UIPROC) get_proc("glVertexAttribP2ui"); + gl3wVertexAttribP2uiv = (PFNGLVERTEXATTRIBP2UIVPROC) get_proc("glVertexAttribP2uiv"); + gl3wVertexAttribP3ui = (PFNGLVERTEXATTRIBP3UIPROC) get_proc("glVertexAttribP3ui"); + gl3wVertexAttribP3uiv = (PFNGLVERTEXATTRIBP3UIVPROC) get_proc("glVertexAttribP3uiv"); + gl3wVertexAttribP4ui = (PFNGLVERTEXATTRIBP4UIPROC) get_proc("glVertexAttribP4ui"); + gl3wVertexAttribP4uiv = (PFNGLVERTEXATTRIBP4UIVPROC) get_proc("glVertexAttribP4uiv"); + gl3wDrawArraysIndirect = (PFNGLDRAWARRAYSINDIRECTPROC) get_proc("glDrawArraysIndirect"); + gl3wDrawElementsIndirect = (PFNGLDRAWELEMENTSINDIRECTPROC) get_proc("glDrawElementsIndirect"); + gl3wUniform1d = (PFNGLUNIFORM1DPROC) get_proc("glUniform1d"); + gl3wUniform2d = (PFNGLUNIFORM2DPROC) get_proc("glUniform2d"); + gl3wUniform3d = (PFNGLUNIFORM3DPROC) get_proc("glUniform3d"); + gl3wUniform4d = (PFNGLUNIFORM4DPROC) get_proc("glUniform4d"); + gl3wUniform1dv = (PFNGLUNIFORM1DVPROC) get_proc("glUniform1dv"); + gl3wUniform2dv = (PFNGLUNIFORM2DVPROC) get_proc("glUniform2dv"); + gl3wUniform3dv = (PFNGLUNIFORM3DVPROC) get_proc("glUniform3dv"); + gl3wUniform4dv = (PFNGLUNIFORM4DVPROC) get_proc("glUniform4dv"); + gl3wUniformMatrix2dv = (PFNGLUNIFORMMATRIX2DVPROC) get_proc("glUniformMatrix2dv"); + gl3wUniformMatrix3dv = (PFNGLUNIFORMMATRIX3DVPROC) get_proc("glUniformMatrix3dv"); + gl3wUniformMatrix4dv = (PFNGLUNIFORMMATRIX4DVPROC) get_proc("glUniformMatrix4dv"); + gl3wUniformMatrix2x3dv = (PFNGLUNIFORMMATRIX2X3DVPROC) get_proc("glUniformMatrix2x3dv"); + gl3wUniformMatrix2x4dv = (PFNGLUNIFORMMATRIX2X4DVPROC) get_proc("glUniformMatrix2x4dv"); + gl3wUniformMatrix3x2dv = (PFNGLUNIFORMMATRIX3X2DVPROC) get_proc("glUniformMatrix3x2dv"); + gl3wUniformMatrix3x4dv = (PFNGLUNIFORMMATRIX3X4DVPROC) get_proc("glUniformMatrix3x4dv"); + gl3wUniformMatrix4x2dv = (PFNGLUNIFORMMATRIX4X2DVPROC) get_proc("glUniformMatrix4x2dv"); + gl3wUniformMatrix4x3dv = (PFNGLUNIFORMMATRIX4X3DVPROC) get_proc("glUniformMatrix4x3dv"); + gl3wGetUniformdv = (PFNGLGETUNIFORMDVPROC) get_proc("glGetUniformdv"); + gl3wGetSubroutineUniformLocation = (PFNGLGETSUBROUTINEUNIFORMLOCATIONPROC) get_proc("glGetSubroutineUniformLocation"); + gl3wGetSubroutineIndex = (PFNGLGETSUBROUTINEINDEXPROC) get_proc("glGetSubroutineIndex"); + gl3wGetActiveSubroutineUniformiv = (PFNGLGETACTIVESUBROUTINEUNIFORMIVPROC) get_proc("glGetActiveSubroutineUniformiv"); + 
gl3wGetActiveSubroutineUniformName = (PFNGLGETACTIVESUBROUTINEUNIFORMNAMEPROC) get_proc("glGetActiveSubroutineUniformName"); + gl3wGetActiveSubroutineName = (PFNGLGETACTIVESUBROUTINENAMEPROC) get_proc("glGetActiveSubroutineName"); + gl3wUniformSubroutinesuiv = (PFNGLUNIFORMSUBROUTINESUIVPROC) get_proc("glUniformSubroutinesuiv"); + gl3wGetUniformSubroutineuiv = (PFNGLGETUNIFORMSUBROUTINEUIVPROC) get_proc("glGetUniformSubroutineuiv"); + gl3wGetProgramStageiv = (PFNGLGETPROGRAMSTAGEIVPROC) get_proc("glGetProgramStageiv"); + gl3wPatchParameteri = (PFNGLPATCHPARAMETERIPROC) get_proc("glPatchParameteri"); + gl3wPatchParameterfv = (PFNGLPATCHPARAMETERFVPROC) get_proc("glPatchParameterfv"); + gl3wBindTransformFeedback = (PFNGLBINDTRANSFORMFEEDBACKPROC) get_proc("glBindTransformFeedback"); + gl3wDeleteTransformFeedbacks = (PFNGLDELETETRANSFORMFEEDBACKSPROC) get_proc("glDeleteTransformFeedbacks"); + gl3wGenTransformFeedbacks = (PFNGLGENTRANSFORMFEEDBACKSPROC) get_proc("glGenTransformFeedbacks"); + gl3wIsTransformFeedback = (PFNGLISTRANSFORMFEEDBACKPROC) get_proc("glIsTransformFeedback"); + gl3wPauseTransformFeedback = (PFNGLPAUSETRANSFORMFEEDBACKPROC) get_proc("glPauseTransformFeedback"); + gl3wResumeTransformFeedback = (PFNGLRESUMETRANSFORMFEEDBACKPROC) get_proc("glResumeTransformFeedback"); + gl3wDrawTransformFeedback = (PFNGLDRAWTRANSFORMFEEDBACKPROC) get_proc("glDrawTransformFeedback"); + gl3wDrawTransformFeedbackStream = (PFNGLDRAWTRANSFORMFEEDBACKSTREAMPROC) get_proc("glDrawTransformFeedbackStream"); + gl3wBeginQueryIndexed = (PFNGLBEGINQUERYINDEXEDPROC) get_proc("glBeginQueryIndexed"); + gl3wEndQueryIndexed = (PFNGLENDQUERYINDEXEDPROC) get_proc("glEndQueryIndexed"); + gl3wGetQueryIndexediv = (PFNGLGETQUERYINDEXEDIVPROC) get_proc("glGetQueryIndexediv"); + gl3wReleaseShaderCompiler = (PFNGLRELEASESHADERCOMPILERPROC) get_proc("glReleaseShaderCompiler"); + gl3wShaderBinary = (PFNGLSHADERBINARYPROC) get_proc("glShaderBinary"); + gl3wGetShaderPrecisionFormat = (PFNGLGETSHADERPRECISIONFORMATPROC) get_proc("glGetShaderPrecisionFormat"); + gl3wDepthRangef = (PFNGLDEPTHRANGEFPROC) get_proc("glDepthRangef"); + gl3wClearDepthf = (PFNGLCLEARDEPTHFPROC) get_proc("glClearDepthf"); + gl3wGetProgramBinary = (PFNGLGETPROGRAMBINARYPROC) get_proc("glGetProgramBinary"); + gl3wProgramBinary = (PFNGLPROGRAMBINARYPROC) get_proc("glProgramBinary"); + gl3wProgramParameteri = (PFNGLPROGRAMPARAMETERIPROC) get_proc("glProgramParameteri"); + gl3wUseProgramStages = (PFNGLUSEPROGRAMSTAGESPROC) get_proc("glUseProgramStages"); + gl3wActiveShaderProgram = (PFNGLACTIVESHADERPROGRAMPROC) get_proc("glActiveShaderProgram"); + gl3wCreateShaderProgramv = (PFNGLCREATESHADERPROGRAMVPROC) get_proc("glCreateShaderProgramv"); + gl3wBindProgramPipeline = (PFNGLBINDPROGRAMPIPELINEPROC) get_proc("glBindProgramPipeline"); + gl3wDeleteProgramPipelines = (PFNGLDELETEPROGRAMPIPELINESPROC) get_proc("glDeleteProgramPipelines"); + gl3wGenProgramPipelines = (PFNGLGENPROGRAMPIPELINESPROC) get_proc("glGenProgramPipelines"); + gl3wIsProgramPipeline = (PFNGLISPROGRAMPIPELINEPROC) get_proc("glIsProgramPipeline"); + gl3wGetProgramPipelineiv = (PFNGLGETPROGRAMPIPELINEIVPROC) get_proc("glGetProgramPipelineiv"); + gl3wProgramUniform1i = (PFNGLPROGRAMUNIFORM1IPROC) get_proc("glProgramUniform1i"); + gl3wProgramUniform1iv = (PFNGLPROGRAMUNIFORM1IVPROC) get_proc("glProgramUniform1iv"); + gl3wProgramUniform1f = (PFNGLPROGRAMUNIFORM1FPROC) get_proc("glProgramUniform1f"); + gl3wProgramUniform1fv = (PFNGLPROGRAMUNIFORM1FVPROC) get_proc("glProgramUniform1fv"); + 
gl3wProgramUniform1d = (PFNGLPROGRAMUNIFORM1DPROC) get_proc("glProgramUniform1d"); + gl3wProgramUniform1dv = (PFNGLPROGRAMUNIFORM1DVPROC) get_proc("glProgramUniform1dv"); + gl3wProgramUniform1ui = (PFNGLPROGRAMUNIFORM1UIPROC) get_proc("glProgramUniform1ui"); + gl3wProgramUniform1uiv = (PFNGLPROGRAMUNIFORM1UIVPROC) get_proc("glProgramUniform1uiv"); + gl3wProgramUniform2i = (PFNGLPROGRAMUNIFORM2IPROC) get_proc("glProgramUniform2i"); + gl3wProgramUniform2iv = (PFNGLPROGRAMUNIFORM2IVPROC) get_proc("glProgramUniform2iv"); + gl3wProgramUniform2f = (PFNGLPROGRAMUNIFORM2FPROC) get_proc("glProgramUniform2f"); + gl3wProgramUniform2fv = (PFNGLPROGRAMUNIFORM2FVPROC) get_proc("glProgramUniform2fv"); + gl3wProgramUniform2d = (PFNGLPROGRAMUNIFORM2DPROC) get_proc("glProgramUniform2d"); + gl3wProgramUniform2dv = (PFNGLPROGRAMUNIFORM2DVPROC) get_proc("glProgramUniform2dv"); + gl3wProgramUniform2ui = (PFNGLPROGRAMUNIFORM2UIPROC) get_proc("glProgramUniform2ui"); + gl3wProgramUniform2uiv = (PFNGLPROGRAMUNIFORM2UIVPROC) get_proc("glProgramUniform2uiv"); + gl3wProgramUniform3i = (PFNGLPROGRAMUNIFORM3IPROC) get_proc("glProgramUniform3i"); + gl3wProgramUniform3iv = (PFNGLPROGRAMUNIFORM3IVPROC) get_proc("glProgramUniform3iv"); + gl3wProgramUniform3f = (PFNGLPROGRAMUNIFORM3FPROC) get_proc("glProgramUniform3f"); + gl3wProgramUniform3fv = (PFNGLPROGRAMUNIFORM3FVPROC) get_proc("glProgramUniform3fv"); + gl3wProgramUniform3d = (PFNGLPROGRAMUNIFORM3DPROC) get_proc("glProgramUniform3d"); + gl3wProgramUniform3dv = (PFNGLPROGRAMUNIFORM3DVPROC) get_proc("glProgramUniform3dv"); + gl3wProgramUniform3ui = (PFNGLPROGRAMUNIFORM3UIPROC) get_proc("glProgramUniform3ui"); + gl3wProgramUniform3uiv = (PFNGLPROGRAMUNIFORM3UIVPROC) get_proc("glProgramUniform3uiv"); + gl3wProgramUniform4i = (PFNGLPROGRAMUNIFORM4IPROC) get_proc("glProgramUniform4i"); + gl3wProgramUniform4iv = (PFNGLPROGRAMUNIFORM4IVPROC) get_proc("glProgramUniform4iv"); + gl3wProgramUniform4f = (PFNGLPROGRAMUNIFORM4FPROC) get_proc("glProgramUniform4f"); + gl3wProgramUniform4fv = (PFNGLPROGRAMUNIFORM4FVPROC) get_proc("glProgramUniform4fv"); + gl3wProgramUniform4d = (PFNGLPROGRAMUNIFORM4DPROC) get_proc("glProgramUniform4d"); + gl3wProgramUniform4dv = (PFNGLPROGRAMUNIFORM4DVPROC) get_proc("glProgramUniform4dv"); + gl3wProgramUniform4ui = (PFNGLPROGRAMUNIFORM4UIPROC) get_proc("glProgramUniform4ui"); + gl3wProgramUniform4uiv = (PFNGLPROGRAMUNIFORM4UIVPROC) get_proc("glProgramUniform4uiv"); + gl3wProgramUniformMatrix2fv = (PFNGLPROGRAMUNIFORMMATRIX2FVPROC) get_proc("glProgramUniformMatrix2fv"); + gl3wProgramUniformMatrix3fv = (PFNGLPROGRAMUNIFORMMATRIX3FVPROC) get_proc("glProgramUniformMatrix3fv"); + gl3wProgramUniformMatrix4fv = (PFNGLPROGRAMUNIFORMMATRIX4FVPROC) get_proc("glProgramUniformMatrix4fv"); + gl3wProgramUniformMatrix2dv = (PFNGLPROGRAMUNIFORMMATRIX2DVPROC) get_proc("glProgramUniformMatrix2dv"); + gl3wProgramUniformMatrix3dv = (PFNGLPROGRAMUNIFORMMATRIX3DVPROC) get_proc("glProgramUniformMatrix3dv"); + gl3wProgramUniformMatrix4dv = (PFNGLPROGRAMUNIFORMMATRIX4DVPROC) get_proc("glProgramUniformMatrix4dv"); + gl3wProgramUniformMatrix2x3fv = (PFNGLPROGRAMUNIFORMMATRIX2X3FVPROC) get_proc("glProgramUniformMatrix2x3fv"); + gl3wProgramUniformMatrix3x2fv = (PFNGLPROGRAMUNIFORMMATRIX3X2FVPROC) get_proc("glProgramUniformMatrix3x2fv"); + gl3wProgramUniformMatrix2x4fv = (PFNGLPROGRAMUNIFORMMATRIX2X4FVPROC) get_proc("glProgramUniformMatrix2x4fv"); + gl3wProgramUniformMatrix4x2fv = (PFNGLPROGRAMUNIFORMMATRIX4X2FVPROC) get_proc("glProgramUniformMatrix4x2fv"); + 
gl3wProgramUniformMatrix3x4fv = (PFNGLPROGRAMUNIFORMMATRIX3X4FVPROC) get_proc("glProgramUniformMatrix3x4fv"); + gl3wProgramUniformMatrix4x3fv = (PFNGLPROGRAMUNIFORMMATRIX4X3FVPROC) get_proc("glProgramUniformMatrix4x3fv"); + gl3wProgramUniformMatrix2x3dv = (PFNGLPROGRAMUNIFORMMATRIX2X3DVPROC) get_proc("glProgramUniformMatrix2x3dv"); + gl3wProgramUniformMatrix3x2dv = (PFNGLPROGRAMUNIFORMMATRIX3X2DVPROC) get_proc("glProgramUniformMatrix3x2dv"); + gl3wProgramUniformMatrix2x4dv = (PFNGLPROGRAMUNIFORMMATRIX2X4DVPROC) get_proc("glProgramUniformMatrix2x4dv"); + gl3wProgramUniformMatrix4x2dv = (PFNGLPROGRAMUNIFORMMATRIX4X2DVPROC) get_proc("glProgramUniformMatrix4x2dv"); + gl3wProgramUniformMatrix3x4dv = (PFNGLPROGRAMUNIFORMMATRIX3X4DVPROC) get_proc("glProgramUniformMatrix3x4dv"); + gl3wProgramUniformMatrix4x3dv = (PFNGLPROGRAMUNIFORMMATRIX4X3DVPROC) get_proc("glProgramUniformMatrix4x3dv"); + gl3wValidateProgramPipeline = (PFNGLVALIDATEPROGRAMPIPELINEPROC) get_proc("glValidateProgramPipeline"); + gl3wGetProgramPipelineInfoLog = (PFNGLGETPROGRAMPIPELINEINFOLOGPROC) get_proc("glGetProgramPipelineInfoLog"); + gl3wVertexAttribL1d = (PFNGLVERTEXATTRIBL1DPROC) get_proc("glVertexAttribL1d"); + gl3wVertexAttribL2d = (PFNGLVERTEXATTRIBL2DPROC) get_proc("glVertexAttribL2d"); + gl3wVertexAttribL3d = (PFNGLVERTEXATTRIBL3DPROC) get_proc("glVertexAttribL3d"); + gl3wVertexAttribL4d = (PFNGLVERTEXATTRIBL4DPROC) get_proc("glVertexAttribL4d"); + gl3wVertexAttribL1dv = (PFNGLVERTEXATTRIBL1DVPROC) get_proc("glVertexAttribL1dv"); + gl3wVertexAttribL2dv = (PFNGLVERTEXATTRIBL2DVPROC) get_proc("glVertexAttribL2dv"); + gl3wVertexAttribL3dv = (PFNGLVERTEXATTRIBL3DVPROC) get_proc("glVertexAttribL3dv"); + gl3wVertexAttribL4dv = (PFNGLVERTEXATTRIBL4DVPROC) get_proc("glVertexAttribL4dv"); + gl3wVertexAttribLPointer = (PFNGLVERTEXATTRIBLPOINTERPROC) get_proc("glVertexAttribLPointer"); + gl3wGetVertexAttribLdv = (PFNGLGETVERTEXATTRIBLDVPROC) get_proc("glGetVertexAttribLdv"); + gl3wViewportArrayv = (PFNGLVIEWPORTARRAYVPROC) get_proc("glViewportArrayv"); + gl3wViewportIndexedf = (PFNGLVIEWPORTINDEXEDFPROC) get_proc("glViewportIndexedf"); + gl3wViewportIndexedfv = (PFNGLVIEWPORTINDEXEDFVPROC) get_proc("glViewportIndexedfv"); + gl3wScissorArrayv = (PFNGLSCISSORARRAYVPROC) get_proc("glScissorArrayv"); + gl3wScissorIndexed = (PFNGLSCISSORINDEXEDPROC) get_proc("glScissorIndexed"); + gl3wScissorIndexedv = (PFNGLSCISSORINDEXEDVPROC) get_proc("glScissorIndexedv"); + gl3wDepthRangeArrayv = (PFNGLDEPTHRANGEARRAYVPROC) get_proc("glDepthRangeArrayv"); + gl3wDepthRangeIndexed = (PFNGLDEPTHRANGEINDEXEDPROC) get_proc("glDepthRangeIndexed"); + gl3wGetFloati_v = (PFNGLGETFLOATI_VPROC) get_proc("glGetFloati_v"); + gl3wGetDoublei_v = (PFNGLGETDOUBLEI_VPROC) get_proc("glGetDoublei_v"); + gl3wCreateSyncFromCLeventARB = (PFNGLCREATESYNCFROMCLEVENTARBPROC) get_proc("glCreateSyncFromCLeventARB"); + gl3wDebugMessageControlARB = (PFNGLDEBUGMESSAGECONTROLARBPROC) get_proc("glDebugMessageControlARB"); + gl3wDebugMessageInsertARB = (PFNGLDEBUGMESSAGEINSERTARBPROC) get_proc("glDebugMessageInsertARB"); + gl3wDebugMessageCallbackARB = (PFNGLDEBUGMESSAGECALLBACKARBPROC) get_proc("glDebugMessageCallbackARB"); + gl3wGetDebugMessageLogARB = (PFNGLGETDEBUGMESSAGELOGARBPROC) get_proc("glGetDebugMessageLogARB"); + gl3wGetGraphicsResetStatusARB = (PFNGLGETGRAPHICSRESETSTATUSARBPROC) get_proc("glGetGraphicsResetStatusARB"); + gl3wGetnTexImageARB = (PFNGLGETNTEXIMAGEARBPROC) get_proc("glGetnTexImageARB"); + gl3wReadnPixelsARB = (PFNGLREADNPIXELSARBPROC) 
get_proc("glReadnPixelsARB"); + gl3wGetnCompressedTexImageARB = (PFNGLGETNCOMPRESSEDTEXIMAGEARBPROC) get_proc("glGetnCompressedTexImageARB"); + gl3wGetnUniformfvARB = (PFNGLGETNUNIFORMFVARBPROC) get_proc("glGetnUniformfvARB"); + gl3wGetnUniformivARB = (PFNGLGETNUNIFORMIVARBPROC) get_proc("glGetnUniformivARB"); + gl3wGetnUniformuivARB = (PFNGLGETNUNIFORMUIVARBPROC) get_proc("glGetnUniformuivARB"); + gl3wGetnUniformdvARB = (PFNGLGETNUNIFORMDVARBPROC) get_proc("glGetnUniformdvARB"); + gl3wDrawArraysInstancedBaseInstance = (PFNGLDRAWARRAYSINSTANCEDBASEINSTANCEPROC) get_proc("glDrawArraysInstancedBaseInstance"); + gl3wDrawElementsInstancedBaseInstance = (PFNGLDRAWELEMENTSINSTANCEDBASEINSTANCEPROC) get_proc("glDrawElementsInstancedBaseInstance"); + gl3wDrawElementsInstancedBaseVertexBaseInstance = (PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXBASEINSTANCEPROC) get_proc("glDrawElementsInstancedBaseVertexBaseInstance"); + gl3wDrawTransformFeedbackInstanced = (PFNGLDRAWTRANSFORMFEEDBACKINSTANCEDPROC) get_proc("glDrawTransformFeedbackInstanced"); + gl3wDrawTransformFeedbackStreamInstanced = (PFNGLDRAWTRANSFORMFEEDBACKSTREAMINSTANCEDPROC) get_proc("glDrawTransformFeedbackStreamInstanced"); + gl3wGetInternalformativ = (PFNGLGETINTERNALFORMATIVPROC) get_proc("glGetInternalformativ"); + gl3wGetActiveAtomicCounterBufferiv = (PFNGLGETACTIVEATOMICCOUNTERBUFFERIVPROC) get_proc("glGetActiveAtomicCounterBufferiv"); + gl3wBindImageTexture = (PFNGLBINDIMAGETEXTUREPROC) get_proc("glBindImageTexture"); + gl3wMemoryBarrier = (PFNGLMEMORYBARRIERPROC) get_proc("glMemoryBarrier"); + gl3wTexStorage1D = (PFNGLTEXSTORAGE1DPROC) get_proc("glTexStorage1D"); + gl3wTexStorage2D = (PFNGLTEXSTORAGE2DPROC) get_proc("glTexStorage2D"); + gl3wTexStorage3D = (PFNGLTEXSTORAGE3DPROC) get_proc("glTexStorage3D"); + gl3wTextureStorage1DEXT = (PFNGLTEXTURESTORAGE1DEXTPROC) get_proc("glTextureStorage1DEXT"); + gl3wTextureStorage2DEXT = (PFNGLTEXTURESTORAGE2DEXTPROC) get_proc("glTextureStorage2DEXT"); + gl3wTextureStorage3DEXT = (PFNGLTEXTURESTORAGE3DEXTPROC) get_proc("glTextureStorage3DEXT"); + gl3wDebugMessageControl = (PFNGLDEBUGMESSAGECONTROLPROC) get_proc("glDebugMessageControl"); + gl3wDebugMessageInsert = (PFNGLDEBUGMESSAGEINSERTPROC) get_proc("glDebugMessageInsert"); + gl3wDebugMessageCallback = (PFNGLDEBUGMESSAGECALLBACKPROC) get_proc("glDebugMessageCallback"); + gl3wGetDebugMessageLog = (PFNGLGETDEBUGMESSAGELOGPROC) get_proc("glGetDebugMessageLog"); + gl3wPushDebugGroup = (PFNGLPUSHDEBUGGROUPPROC) get_proc("glPushDebugGroup"); + gl3wPopDebugGroup = (PFNGLPOPDEBUGGROUPPROC) get_proc("glPopDebugGroup"); + gl3wObjectLabel = (PFNGLOBJECTLABELPROC) get_proc("glObjectLabel"); + gl3wGetObjectLabel = (PFNGLGETOBJECTLABELPROC) get_proc("glGetObjectLabel"); + gl3wObjectPtrLabel = (PFNGLOBJECTPTRLABELPROC) get_proc("glObjectPtrLabel"); + gl3wGetObjectPtrLabel = (PFNGLGETOBJECTPTRLABELPROC) get_proc("glGetObjectPtrLabel"); + gl3wClearBufferData = (PFNGLCLEARBUFFERDATAPROC) get_proc("glClearBufferData"); + gl3wClearBufferSubData = (PFNGLCLEARBUFFERSUBDATAPROC) get_proc("glClearBufferSubData"); + gl3wClearNamedBufferDataEXT = (PFNGLCLEARNAMEDBUFFERDATAEXTPROC) get_proc("glClearNamedBufferDataEXT"); + gl3wClearNamedBufferSubDataEXT = (PFNGLCLEARNAMEDBUFFERSUBDATAEXTPROC) get_proc("glClearNamedBufferSubDataEXT"); + gl3wDispatchCompute = (PFNGLDISPATCHCOMPUTEPROC) get_proc("glDispatchCompute"); + gl3wDispatchComputeIndirect = (PFNGLDISPATCHCOMPUTEINDIRECTPROC) get_proc("glDispatchComputeIndirect"); + gl3wCopyImageSubData = (PFNGLCOPYIMAGESUBDATAPROC) 
get_proc("glCopyImageSubData"); + gl3wTextureView = (PFNGLTEXTUREVIEWPROC) get_proc("glTextureView"); + gl3wBindVertexBuffer = (PFNGLBINDVERTEXBUFFERPROC) get_proc("glBindVertexBuffer"); + gl3wVertexAttribFormat = (PFNGLVERTEXATTRIBFORMATPROC) get_proc("glVertexAttribFormat"); + gl3wVertexAttribIFormat = (PFNGLVERTEXATTRIBIFORMATPROC) get_proc("glVertexAttribIFormat"); + gl3wVertexAttribLFormat = (PFNGLVERTEXATTRIBLFORMATPROC) get_proc("glVertexAttribLFormat"); + gl3wVertexAttribBinding = (PFNGLVERTEXATTRIBBINDINGPROC) get_proc("glVertexAttribBinding"); + gl3wVertexBindingDivisor = (PFNGLVERTEXBINDINGDIVISORPROC) get_proc("glVertexBindingDivisor"); + gl3wVertexArrayBindVertexBufferEXT = (PFNGLVERTEXARRAYBINDVERTEXBUFFEREXTPROC) get_proc("glVertexArrayBindVertexBufferEXT"); + gl3wVertexArrayVertexAttribFormatEXT = (PFNGLVERTEXARRAYVERTEXATTRIBFORMATEXTPROC) get_proc("glVertexArrayVertexAttribFormatEXT"); + gl3wVertexArrayVertexAttribIFormatEXT = (PFNGLVERTEXARRAYVERTEXATTRIBIFORMATEXTPROC) get_proc("glVertexArrayVertexAttribIFormatEXT"); + gl3wVertexArrayVertexAttribLFormatEXT = (PFNGLVERTEXARRAYVERTEXATTRIBLFORMATEXTPROC) get_proc("glVertexArrayVertexAttribLFormatEXT"); + gl3wVertexArrayVertexAttribBindingEXT = (PFNGLVERTEXARRAYVERTEXATTRIBBINDINGEXTPROC) get_proc("glVertexArrayVertexAttribBindingEXT"); + gl3wVertexArrayVertexBindingDivisorEXT = (PFNGLVERTEXARRAYVERTEXBINDINGDIVISOREXTPROC) get_proc("glVertexArrayVertexBindingDivisorEXT"); + gl3wFramebufferParameteri = (PFNGLFRAMEBUFFERPARAMETERIPROC) get_proc("glFramebufferParameteri"); + gl3wGetFramebufferParameteriv = (PFNGLGETFRAMEBUFFERPARAMETERIVPROC) get_proc("glGetFramebufferParameteriv"); + gl3wNamedFramebufferParameteriEXT = (PFNGLNAMEDFRAMEBUFFERPARAMETERIEXTPROC) get_proc("glNamedFramebufferParameteriEXT"); + gl3wGetNamedFramebufferParameterivEXT = (PFNGLGETNAMEDFRAMEBUFFERPARAMETERIVEXTPROC) get_proc("glGetNamedFramebufferParameterivEXT"); + gl3wGetInternalformati64v = (PFNGLGETINTERNALFORMATI64VPROC) get_proc("glGetInternalformati64v"); + gl3wInvalidateTexSubImage = (PFNGLINVALIDATETEXSUBIMAGEPROC) get_proc("glInvalidateTexSubImage"); + gl3wInvalidateTexImage = (PFNGLINVALIDATETEXIMAGEPROC) get_proc("glInvalidateTexImage"); + gl3wInvalidateBufferSubData = (PFNGLINVALIDATEBUFFERSUBDATAPROC) get_proc("glInvalidateBufferSubData"); + gl3wInvalidateBufferData = (PFNGLINVALIDATEBUFFERDATAPROC) get_proc("glInvalidateBufferData"); + gl3wInvalidateFramebuffer = (PFNGLINVALIDATEFRAMEBUFFERPROC) get_proc("glInvalidateFramebuffer"); + gl3wInvalidateSubFramebuffer = (PFNGLINVALIDATESUBFRAMEBUFFERPROC) get_proc("glInvalidateSubFramebuffer"); + gl3wMultiDrawArraysIndirect = (PFNGLMULTIDRAWARRAYSINDIRECTPROC) get_proc("glMultiDrawArraysIndirect"); + gl3wMultiDrawElementsIndirect = (PFNGLMULTIDRAWELEMENTSINDIRECTPROC) get_proc("glMultiDrawElementsIndirect"); + gl3wGetProgramInterfaceiv = (PFNGLGETPROGRAMINTERFACEIVPROC) get_proc("glGetProgramInterfaceiv"); + gl3wGetProgramResourceIndex = (PFNGLGETPROGRAMRESOURCEINDEXPROC) get_proc("glGetProgramResourceIndex"); + gl3wGetProgramResourceName = (PFNGLGETPROGRAMRESOURCENAMEPROC) get_proc("glGetProgramResourceName"); + gl3wGetProgramResourceiv = (PFNGLGETPROGRAMRESOURCEIVPROC) get_proc("glGetProgramResourceiv"); + gl3wGetProgramResourceLocation = (PFNGLGETPROGRAMRESOURCELOCATIONPROC) get_proc("glGetProgramResourceLocation"); + gl3wGetProgramResourceLocationIndex = (PFNGLGETPROGRAMRESOURCELOCATIONINDEXPROC) get_proc("glGetProgramResourceLocationIndex"); + gl3wShaderStorageBlockBinding = 
(PFNGLSHADERSTORAGEBLOCKBINDINGPROC) get_proc("glShaderStorageBlockBinding"); + gl3wTexBufferRange = (PFNGLTEXBUFFERRANGEPROC) get_proc("glTexBufferRange"); + gl3wTextureBufferRangeEXT = (PFNGLTEXTUREBUFFERRANGEEXTPROC) get_proc("glTextureBufferRangeEXT"); + gl3wTexStorage2DMultisample = (PFNGLTEXSTORAGE2DMULTISAMPLEPROC) get_proc("glTexStorage2DMultisample"); + gl3wTexStorage3DMultisample = (PFNGLTEXSTORAGE3DMULTISAMPLEPROC) get_proc("glTexStorage3DMultisample"); + gl3wTextureStorage2DMultisampleEXT = (PFNGLTEXTURESTORAGE2DMULTISAMPLEEXTPROC) get_proc("glTextureStorage2DMultisampleEXT"); + gl3wTextureStorage3DMultisampleEXT = (PFNGLTEXTURESTORAGE3DMULTISAMPLEEXTPROC) get_proc("glTextureStorage3DMultisampleEXT"); +} diff --git a/gui/dependencies/gl3w/GL/gl3w.h b/gui/dependencies/gl3w/GL/gl3w.h new file mode 100644 index 0000000000000000000000000000000000000000..ee563f8d401cff2861103aa61bd1f27d16d44708 --- /dev/null +++ b/gui/dependencies/gl3w/GL/gl3w.h @@ -0,0 +1,1234 @@ +#ifndef __gl3w_h_ +#define __gl3w_h_ + +#include <GL/glcorearb.h> + +#ifndef __gl_h_ +#define __gl_h_ +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/* gl3w api */ +int gl3wInit(void); +int gl3wIsSupported(int major, int minor); +void *gl3wGetProcAddress(const char *proc); + +/* OpenGL functions */ +extern PFNGLCULLFACEPROC gl3wCullFace; +extern PFNGLFRONTFACEPROC gl3wFrontFace; +extern PFNGLHINTPROC gl3wHint; +extern PFNGLLINEWIDTHPROC gl3wLineWidth; +extern PFNGLPOINTSIZEPROC gl3wPointSize; +extern PFNGLPOLYGONMODEPROC gl3wPolygonMode; +extern PFNGLSCISSORPROC gl3wScissor; +extern PFNGLTEXPARAMETERFPROC gl3wTexParameterf; +extern PFNGLTEXPARAMETERFVPROC gl3wTexParameterfv; +extern PFNGLTEXPARAMETERIPROC gl3wTexParameteri; +extern PFNGLTEXPARAMETERIVPROC gl3wTexParameteriv; +extern PFNGLTEXIMAGE1DPROC gl3wTexImage1D; +extern PFNGLTEXIMAGE2DPROC gl3wTexImage2D; +extern PFNGLDRAWBUFFERPROC gl3wDrawBuffer; +extern PFNGLCLEARPROC gl3wClear; +extern PFNGLCLEARCOLORPROC gl3wClearColor; +extern PFNGLCLEARSTENCILPROC gl3wClearStencil; +extern PFNGLCLEARDEPTHPROC gl3wClearDepth; +extern PFNGLSTENCILMASKPROC gl3wStencilMask; +extern PFNGLCOLORMASKPROC gl3wColorMask; +extern PFNGLDEPTHMASKPROC gl3wDepthMask; +extern PFNGLDISABLEPROC gl3wDisable; +extern PFNGLENABLEPROC gl3wEnable; +extern PFNGLFINISHPROC gl3wFinish; +extern PFNGLFLUSHPROC gl3wFlush; +extern PFNGLBLENDFUNCPROC gl3wBlendFunc; +extern PFNGLLOGICOPPROC gl3wLogicOp; +extern PFNGLSTENCILFUNCPROC gl3wStencilFunc; +extern PFNGLSTENCILOPPROC gl3wStencilOp; +extern PFNGLDEPTHFUNCPROC gl3wDepthFunc; +extern PFNGLPIXELSTOREFPROC gl3wPixelStoref; +extern PFNGLPIXELSTOREIPROC gl3wPixelStorei; +extern PFNGLREADBUFFERPROC gl3wReadBuffer; +extern PFNGLREADPIXELSPROC gl3wReadPixels; +extern PFNGLGETBOOLEANVPROC gl3wGetBooleanv; +extern PFNGLGETDOUBLEVPROC gl3wGetDoublev; +extern PFNGLGETERRORPROC gl3wGetError; +extern PFNGLGETFLOATVPROC gl3wGetFloatv; +extern PFNGLGETINTEGERVPROC gl3wGetIntegerv; +extern PFNGLGETSTRINGPROC gl3wGetString; +extern PFNGLGETTEXIMAGEPROC gl3wGetTexImage; +extern PFNGLGETTEXPARAMETERFVPROC gl3wGetTexParameterfv; +extern PFNGLGETTEXPARAMETERIVPROC gl3wGetTexParameteriv; +extern PFNGLGETTEXLEVELPARAMETERFVPROC gl3wGetTexLevelParameterfv; +extern PFNGLGETTEXLEVELPARAMETERIVPROC gl3wGetTexLevelParameteriv; +extern PFNGLISENABLEDPROC gl3wIsEnabled; +extern PFNGLDEPTHRANGEPROC gl3wDepthRange; +extern PFNGLVIEWPORTPROC gl3wViewport; +extern PFNGLDRAWARRAYSPROC gl3wDrawArrays; +extern PFNGLDRAWELEMENTSPROC gl3wDrawElements; +extern PFNGLGETPOINTERVPROC gl3wGetPointerv;
+extern PFNGLPOLYGONOFFSETPROC gl3wPolygonOffset; +extern PFNGLCOPYTEXIMAGE1DPROC gl3wCopyTexImage1D; +extern PFNGLCOPYTEXIMAGE2DPROC gl3wCopyTexImage2D; +extern PFNGLCOPYTEXSUBIMAGE1DPROC gl3wCopyTexSubImage1D; +extern PFNGLCOPYTEXSUBIMAGE2DPROC gl3wCopyTexSubImage2D; +extern PFNGLTEXSUBIMAGE1DPROC gl3wTexSubImage1D; +extern PFNGLTEXSUBIMAGE2DPROC gl3wTexSubImage2D; +extern PFNGLBINDTEXTUREPROC gl3wBindTexture; +extern PFNGLDELETETEXTURESPROC gl3wDeleteTextures; +extern PFNGLGENTEXTURESPROC gl3wGenTextures; +extern PFNGLISTEXTUREPROC gl3wIsTexture; +extern PFNGLBLENDCOLORPROC gl3wBlendColor; +extern PFNGLBLENDEQUATIONPROC gl3wBlendEquation; +extern PFNGLDRAWRANGEELEMENTSPROC gl3wDrawRangeElements; +extern PFNGLTEXIMAGE3DPROC gl3wTexImage3D; +extern PFNGLTEXSUBIMAGE3DPROC gl3wTexSubImage3D; +extern PFNGLCOPYTEXSUBIMAGE3DPROC gl3wCopyTexSubImage3D; +extern PFNGLACTIVETEXTUREPROC gl3wActiveTexture; +extern PFNGLSAMPLECOVERAGEPROC gl3wSampleCoverage; +extern PFNGLCOMPRESSEDTEXIMAGE3DPROC gl3wCompressedTexImage3D; +extern PFNGLCOMPRESSEDTEXIMAGE2DPROC gl3wCompressedTexImage2D; +extern PFNGLCOMPRESSEDTEXIMAGE1DPROC gl3wCompressedTexImage1D; +extern PFNGLCOMPRESSEDTEXSUBIMAGE3DPROC gl3wCompressedTexSubImage3D; +extern PFNGLCOMPRESSEDTEXSUBIMAGE2DPROC gl3wCompressedTexSubImage2D; +extern PFNGLCOMPRESSEDTEXSUBIMAGE1DPROC gl3wCompressedTexSubImage1D; +extern PFNGLGETCOMPRESSEDTEXIMAGEPROC gl3wGetCompressedTexImage; +extern PFNGLBLENDFUNCSEPARATEPROC gl3wBlendFuncSeparate; +extern PFNGLMULTIDRAWARRAYSPROC gl3wMultiDrawArrays; +extern PFNGLMULTIDRAWELEMENTSPROC gl3wMultiDrawElements; +extern PFNGLPOINTPARAMETERFPROC gl3wPointParameterf; +extern PFNGLPOINTPARAMETERFVPROC gl3wPointParameterfv; +extern PFNGLPOINTPARAMETERIPROC gl3wPointParameteri; +extern PFNGLPOINTPARAMETERIVPROC gl3wPointParameteriv; +extern PFNGLGENQUERIESPROC gl3wGenQueries; +extern PFNGLDELETEQUERIESPROC gl3wDeleteQueries; +extern PFNGLISQUERYPROC gl3wIsQuery; +extern PFNGLBEGINQUERYPROC gl3wBeginQuery; +extern PFNGLENDQUERYPROC gl3wEndQuery; +extern PFNGLGETQUERYIVPROC gl3wGetQueryiv; +extern PFNGLGETQUERYOBJECTIVPROC gl3wGetQueryObjectiv; +extern PFNGLGETQUERYOBJECTUIVPROC gl3wGetQueryObjectuiv; +extern PFNGLBINDBUFFERPROC gl3wBindBuffer; +extern PFNGLDELETEBUFFERSPROC gl3wDeleteBuffers; +extern PFNGLGENBUFFERSPROC gl3wGenBuffers; +extern PFNGLISBUFFERPROC gl3wIsBuffer; +extern PFNGLBUFFERDATAPROC gl3wBufferData; +extern PFNGLBUFFERSUBDATAPROC gl3wBufferSubData; +extern PFNGLGETBUFFERSUBDATAPROC gl3wGetBufferSubData; +extern PFNGLMAPBUFFERPROC gl3wMapBuffer; +extern PFNGLUNMAPBUFFERPROC gl3wUnmapBuffer; +extern PFNGLGETBUFFERPARAMETERIVPROC gl3wGetBufferParameteriv; +extern PFNGLGETBUFFERPOINTERVPROC gl3wGetBufferPointerv; +extern PFNGLBLENDEQUATIONSEPARATEPROC gl3wBlendEquationSeparate; +extern PFNGLDRAWBUFFERSPROC gl3wDrawBuffers; +extern PFNGLSTENCILOPSEPARATEPROC gl3wStencilOpSeparate; +extern PFNGLSTENCILFUNCSEPARATEPROC gl3wStencilFuncSeparate; +extern PFNGLSTENCILMASKSEPARATEPROC gl3wStencilMaskSeparate; +extern PFNGLATTACHSHADERPROC gl3wAttachShader; +extern PFNGLBINDATTRIBLOCATIONPROC gl3wBindAttribLocation; +extern PFNGLCOMPILESHADERPROC gl3wCompileShader; +extern PFNGLCREATEPROGRAMPROC gl3wCreateProgram; +extern PFNGLCREATESHADERPROC gl3wCreateShader; +extern PFNGLDELETEPROGRAMPROC gl3wDeleteProgram; +extern PFNGLDELETESHADERPROC gl3wDeleteShader; +extern PFNGLDETACHSHADERPROC gl3wDetachShader; +extern PFNGLDISABLEVERTEXATTRIBARRAYPROC gl3wDisableVertexAttribArray; +extern PFNGLENABLEVERTEXATTRIBARRAYPROC 
gl3wEnableVertexAttribArray; +extern PFNGLGETACTIVEATTRIBPROC gl3wGetActiveAttrib; +extern PFNGLGETACTIVEUNIFORMPROC gl3wGetActiveUniform; +extern PFNGLGETATTACHEDSHADERSPROC gl3wGetAttachedShaders; +extern PFNGLGETATTRIBLOCATIONPROC gl3wGetAttribLocation; +extern PFNGLGETPROGRAMIVPROC gl3wGetProgramiv; +extern PFNGLGETPROGRAMINFOLOGPROC gl3wGetProgramInfoLog; +extern PFNGLGETSHADERIVPROC gl3wGetShaderiv; +extern PFNGLGETSHADERINFOLOGPROC gl3wGetShaderInfoLog; +extern PFNGLGETSHADERSOURCEPROC gl3wGetShaderSource; +extern PFNGLGETUNIFORMLOCATIONPROC gl3wGetUniformLocation; +extern PFNGLGETUNIFORMFVPROC gl3wGetUniformfv; +extern PFNGLGETUNIFORMIVPROC gl3wGetUniformiv; +extern PFNGLGETVERTEXATTRIBDVPROC gl3wGetVertexAttribdv; +extern PFNGLGETVERTEXATTRIBFVPROC gl3wGetVertexAttribfv; +extern PFNGLGETVERTEXATTRIBIVPROC gl3wGetVertexAttribiv; +extern PFNGLGETVERTEXATTRIBPOINTERVPROC gl3wGetVertexAttribPointerv; +extern PFNGLISPROGRAMPROC gl3wIsProgram; +extern PFNGLISSHADERPROC gl3wIsShader; +extern PFNGLLINKPROGRAMPROC gl3wLinkProgram; +extern PFNGLSHADERSOURCEPROC gl3wShaderSource; +extern PFNGLUSEPROGRAMPROC gl3wUseProgram; +extern PFNGLUNIFORM1FPROC gl3wUniform1f; +extern PFNGLUNIFORM2FPROC gl3wUniform2f; +extern PFNGLUNIFORM3FPROC gl3wUniform3f; +extern PFNGLUNIFORM4FPROC gl3wUniform4f; +extern PFNGLUNIFORM1IPROC gl3wUniform1i; +extern PFNGLUNIFORM2IPROC gl3wUniform2i; +extern PFNGLUNIFORM3IPROC gl3wUniform3i; +extern PFNGLUNIFORM4IPROC gl3wUniform4i; +extern PFNGLUNIFORM1FVPROC gl3wUniform1fv; +extern PFNGLUNIFORM2FVPROC gl3wUniform2fv; +extern PFNGLUNIFORM3FVPROC gl3wUniform3fv; +extern PFNGLUNIFORM4FVPROC gl3wUniform4fv; +extern PFNGLUNIFORM1IVPROC gl3wUniform1iv; +extern PFNGLUNIFORM2IVPROC gl3wUniform2iv; +extern PFNGLUNIFORM3IVPROC gl3wUniform3iv; +extern PFNGLUNIFORM4IVPROC gl3wUniform4iv; +extern PFNGLUNIFORMMATRIX2FVPROC gl3wUniformMatrix2fv; +extern PFNGLUNIFORMMATRIX3FVPROC gl3wUniformMatrix3fv; +extern PFNGLUNIFORMMATRIX4FVPROC gl3wUniformMatrix4fv; +extern PFNGLVALIDATEPROGRAMPROC gl3wValidateProgram; +extern PFNGLVERTEXATTRIB1DPROC gl3wVertexAttrib1d; +extern PFNGLVERTEXATTRIB1DVPROC gl3wVertexAttrib1dv; +extern PFNGLVERTEXATTRIB1FPROC gl3wVertexAttrib1f; +extern PFNGLVERTEXATTRIB1FVPROC gl3wVertexAttrib1fv; +extern PFNGLVERTEXATTRIB1SPROC gl3wVertexAttrib1s; +extern PFNGLVERTEXATTRIB1SVPROC gl3wVertexAttrib1sv; +extern PFNGLVERTEXATTRIB2DPROC gl3wVertexAttrib2d; +extern PFNGLVERTEXATTRIB2DVPROC gl3wVertexAttrib2dv; +extern PFNGLVERTEXATTRIB2FPROC gl3wVertexAttrib2f; +extern PFNGLVERTEXATTRIB2FVPROC gl3wVertexAttrib2fv; +extern PFNGLVERTEXATTRIB2SPROC gl3wVertexAttrib2s; +extern PFNGLVERTEXATTRIB2SVPROC gl3wVertexAttrib2sv; +extern PFNGLVERTEXATTRIB3DPROC gl3wVertexAttrib3d; +extern PFNGLVERTEXATTRIB3DVPROC gl3wVertexAttrib3dv; +extern PFNGLVERTEXATTRIB3FPROC gl3wVertexAttrib3f; +extern PFNGLVERTEXATTRIB3FVPROC gl3wVertexAttrib3fv; +extern PFNGLVERTEXATTRIB3SPROC gl3wVertexAttrib3s; +extern PFNGLVERTEXATTRIB3SVPROC gl3wVertexAttrib3sv; +extern PFNGLVERTEXATTRIB4NBVPROC gl3wVertexAttrib4Nbv; +extern PFNGLVERTEXATTRIB4NIVPROC gl3wVertexAttrib4Niv; +extern PFNGLVERTEXATTRIB4NSVPROC gl3wVertexAttrib4Nsv; +extern PFNGLVERTEXATTRIB4NUBPROC gl3wVertexAttrib4Nub; +extern PFNGLVERTEXATTRIB4NUBVPROC gl3wVertexAttrib4Nubv; +extern PFNGLVERTEXATTRIB4NUIVPROC gl3wVertexAttrib4Nuiv; +extern PFNGLVERTEXATTRIB4NUSVPROC gl3wVertexAttrib4Nusv; +extern PFNGLVERTEXATTRIB4BVPROC gl3wVertexAttrib4bv; +extern PFNGLVERTEXATTRIB4DPROC gl3wVertexAttrib4d; +extern PFNGLVERTEXATTRIB4DVPROC 
gl3wVertexAttrib4dv; +extern PFNGLVERTEXATTRIB4FPROC gl3wVertexAttrib4f; +extern PFNGLVERTEXATTRIB4FVPROC gl3wVertexAttrib4fv; +extern PFNGLVERTEXATTRIB4IVPROC gl3wVertexAttrib4iv; +extern PFNGLVERTEXATTRIB4SPROC gl3wVertexAttrib4s; +extern PFNGLVERTEXATTRIB4SVPROC gl3wVertexAttrib4sv; +extern PFNGLVERTEXATTRIB4UBVPROC gl3wVertexAttrib4ubv; +extern PFNGLVERTEXATTRIB4UIVPROC gl3wVertexAttrib4uiv; +extern PFNGLVERTEXATTRIB4USVPROC gl3wVertexAttrib4usv; +extern PFNGLVERTEXATTRIBPOINTERPROC gl3wVertexAttribPointer; +extern PFNGLUNIFORMMATRIX2X3FVPROC gl3wUniformMatrix2x3fv; +extern PFNGLUNIFORMMATRIX3X2FVPROC gl3wUniformMatrix3x2fv; +extern PFNGLUNIFORMMATRIX2X4FVPROC gl3wUniformMatrix2x4fv; +extern PFNGLUNIFORMMATRIX4X2FVPROC gl3wUniformMatrix4x2fv; +extern PFNGLUNIFORMMATRIX3X4FVPROC gl3wUniformMatrix3x4fv; +extern PFNGLUNIFORMMATRIX4X3FVPROC gl3wUniformMatrix4x3fv; +extern PFNGLCOLORMASKIPROC gl3wColorMaski; +extern PFNGLGETBOOLEANI_VPROC gl3wGetBooleani_v; +extern PFNGLGETINTEGERI_VPROC gl3wGetIntegeri_v; +extern PFNGLENABLEIPROC gl3wEnablei; +extern PFNGLDISABLEIPROC gl3wDisablei; +extern PFNGLISENABLEDIPROC gl3wIsEnabledi; +extern PFNGLBEGINTRANSFORMFEEDBACKPROC gl3wBeginTransformFeedback; +extern PFNGLENDTRANSFORMFEEDBACKPROC gl3wEndTransformFeedback; +extern PFNGLBINDBUFFERRANGEPROC gl3wBindBufferRange; +extern PFNGLBINDBUFFERBASEPROC gl3wBindBufferBase; +extern PFNGLTRANSFORMFEEDBACKVARYINGSPROC gl3wTransformFeedbackVaryings; +extern PFNGLGETTRANSFORMFEEDBACKVARYINGPROC gl3wGetTransformFeedbackVarying; +extern PFNGLCLAMPCOLORPROC gl3wClampColor; +extern PFNGLBEGINCONDITIONALRENDERPROC gl3wBeginConditionalRender; +extern PFNGLENDCONDITIONALRENDERPROC gl3wEndConditionalRender; +extern PFNGLVERTEXATTRIBIPOINTERPROC gl3wVertexAttribIPointer; +extern PFNGLGETVERTEXATTRIBIIVPROC gl3wGetVertexAttribIiv; +extern PFNGLGETVERTEXATTRIBIUIVPROC gl3wGetVertexAttribIuiv; +extern PFNGLVERTEXATTRIBI1IPROC gl3wVertexAttribI1i; +extern PFNGLVERTEXATTRIBI2IPROC gl3wVertexAttribI2i; +extern PFNGLVERTEXATTRIBI3IPROC gl3wVertexAttribI3i; +extern PFNGLVERTEXATTRIBI4IPROC gl3wVertexAttribI4i; +extern PFNGLVERTEXATTRIBI1UIPROC gl3wVertexAttribI1ui; +extern PFNGLVERTEXATTRIBI2UIPROC gl3wVertexAttribI2ui; +extern PFNGLVERTEXATTRIBI3UIPROC gl3wVertexAttribI3ui; +extern PFNGLVERTEXATTRIBI4UIPROC gl3wVertexAttribI4ui; +extern PFNGLVERTEXATTRIBI1IVPROC gl3wVertexAttribI1iv; +extern PFNGLVERTEXATTRIBI2IVPROC gl3wVertexAttribI2iv; +extern PFNGLVERTEXATTRIBI3IVPROC gl3wVertexAttribI3iv; +extern PFNGLVERTEXATTRIBI4IVPROC gl3wVertexAttribI4iv; +extern PFNGLVERTEXATTRIBI1UIVPROC gl3wVertexAttribI1uiv; +extern PFNGLVERTEXATTRIBI2UIVPROC gl3wVertexAttribI2uiv; +extern PFNGLVERTEXATTRIBI3UIVPROC gl3wVertexAttribI3uiv; +extern PFNGLVERTEXATTRIBI4UIVPROC gl3wVertexAttribI4uiv; +extern PFNGLVERTEXATTRIBI4BVPROC gl3wVertexAttribI4bv; +extern PFNGLVERTEXATTRIBI4SVPROC gl3wVertexAttribI4sv; +extern PFNGLVERTEXATTRIBI4UBVPROC gl3wVertexAttribI4ubv; +extern PFNGLVERTEXATTRIBI4USVPROC gl3wVertexAttribI4usv; +extern PFNGLGETUNIFORMUIVPROC gl3wGetUniformuiv; +extern PFNGLBINDFRAGDATALOCATIONPROC gl3wBindFragDataLocation; +extern PFNGLGETFRAGDATALOCATIONPROC gl3wGetFragDataLocation; +extern PFNGLUNIFORM1UIPROC gl3wUniform1ui; +extern PFNGLUNIFORM2UIPROC gl3wUniform2ui; +extern PFNGLUNIFORM3UIPROC gl3wUniform3ui; +extern PFNGLUNIFORM4UIPROC gl3wUniform4ui; +extern PFNGLUNIFORM1UIVPROC gl3wUniform1uiv; +extern PFNGLUNIFORM2UIVPROC gl3wUniform2uiv; +extern PFNGLUNIFORM3UIVPROC gl3wUniform3uiv; +extern PFNGLUNIFORM4UIVPROC gl3wUniform4uiv; 
+extern PFNGLTEXPARAMETERIIVPROC gl3wTexParameterIiv; +extern PFNGLTEXPARAMETERIUIVPROC gl3wTexParameterIuiv; +extern PFNGLGETTEXPARAMETERIIVPROC gl3wGetTexParameterIiv; +extern PFNGLGETTEXPARAMETERIUIVPROC gl3wGetTexParameterIuiv; +extern PFNGLCLEARBUFFERIVPROC gl3wClearBufferiv; +extern PFNGLCLEARBUFFERUIVPROC gl3wClearBufferuiv; +extern PFNGLCLEARBUFFERFVPROC gl3wClearBufferfv; +extern PFNGLCLEARBUFFERFIPROC gl3wClearBufferfi; +extern PFNGLGETSTRINGIPROC gl3wGetStringi; +extern PFNGLDRAWARRAYSINSTANCEDPROC gl3wDrawArraysInstanced; +extern PFNGLDRAWELEMENTSINSTANCEDPROC gl3wDrawElementsInstanced; +extern PFNGLTEXBUFFERPROC gl3wTexBuffer; +extern PFNGLPRIMITIVERESTARTINDEXPROC gl3wPrimitiveRestartIndex; +extern PFNGLGETINTEGER64I_VPROC gl3wGetInteger64i_v; +extern PFNGLGETBUFFERPARAMETERI64VPROC gl3wGetBufferParameteri64v; +extern PFNGLFRAMEBUFFERTEXTUREPROC gl3wFramebufferTexture; +extern PFNGLVERTEXATTRIBDIVISORPROC gl3wVertexAttribDivisor; +extern PFNGLMINSAMPLESHADINGPROC gl3wMinSampleShading; +extern PFNGLBLENDEQUATIONIPROC gl3wBlendEquationi; +extern PFNGLBLENDEQUATIONSEPARATEIPROC gl3wBlendEquationSeparatei; +extern PFNGLBLENDFUNCIPROC gl3wBlendFunci; +extern PFNGLBLENDFUNCSEPARATEIPROC gl3wBlendFuncSeparatei; +extern PFNGLISRENDERBUFFERPROC gl3wIsRenderbuffer; +extern PFNGLBINDRENDERBUFFERPROC gl3wBindRenderbuffer; +extern PFNGLDELETERENDERBUFFERSPROC gl3wDeleteRenderbuffers; +extern PFNGLGENRENDERBUFFERSPROC gl3wGenRenderbuffers; +extern PFNGLRENDERBUFFERSTORAGEPROC gl3wRenderbufferStorage; +extern PFNGLGETRENDERBUFFERPARAMETERIVPROC gl3wGetRenderbufferParameteriv; +extern PFNGLISFRAMEBUFFERPROC gl3wIsFramebuffer; +extern PFNGLBINDFRAMEBUFFERPROC gl3wBindFramebuffer; +extern PFNGLDELETEFRAMEBUFFERSPROC gl3wDeleteFramebuffers; +extern PFNGLGENFRAMEBUFFERSPROC gl3wGenFramebuffers; +extern PFNGLCHECKFRAMEBUFFERSTATUSPROC gl3wCheckFramebufferStatus; +extern PFNGLFRAMEBUFFERTEXTURE1DPROC gl3wFramebufferTexture1D; +extern PFNGLFRAMEBUFFERTEXTURE2DPROC gl3wFramebufferTexture2D; +extern PFNGLFRAMEBUFFERTEXTURE3DPROC gl3wFramebufferTexture3D; +extern PFNGLFRAMEBUFFERRENDERBUFFERPROC gl3wFramebufferRenderbuffer; +extern PFNGLGETFRAMEBUFFERATTACHMENTPARAMETERIVPROC gl3wGetFramebufferAttachmentParameteriv; +extern PFNGLGENERATEMIPMAPPROC gl3wGenerateMipmap; +extern PFNGLBLITFRAMEBUFFERPROC gl3wBlitFramebuffer; +extern PFNGLRENDERBUFFERSTORAGEMULTISAMPLEPROC gl3wRenderbufferStorageMultisample; +extern PFNGLFRAMEBUFFERTEXTURELAYERPROC gl3wFramebufferTextureLayer; +extern PFNGLMAPBUFFERRANGEPROC gl3wMapBufferRange; +extern PFNGLFLUSHMAPPEDBUFFERRANGEPROC gl3wFlushMappedBufferRange; +extern PFNGLBINDVERTEXARRAYPROC gl3wBindVertexArray; +extern PFNGLDELETEVERTEXARRAYSPROC gl3wDeleteVertexArrays; +extern PFNGLGENVERTEXARRAYSPROC gl3wGenVertexArrays; +extern PFNGLISVERTEXARRAYPROC gl3wIsVertexArray; +extern PFNGLGETUNIFORMINDICESPROC gl3wGetUniformIndices; +extern PFNGLGETACTIVEUNIFORMSIVPROC gl3wGetActiveUniformsiv; +extern PFNGLGETACTIVEUNIFORMNAMEPROC gl3wGetActiveUniformName; +extern PFNGLGETUNIFORMBLOCKINDEXPROC gl3wGetUniformBlockIndex; +extern PFNGLGETACTIVEUNIFORMBLOCKIVPROC gl3wGetActiveUniformBlockiv; +extern PFNGLGETACTIVEUNIFORMBLOCKNAMEPROC gl3wGetActiveUniformBlockName; +extern PFNGLUNIFORMBLOCKBINDINGPROC gl3wUniformBlockBinding; +extern PFNGLCOPYBUFFERSUBDATAPROC gl3wCopyBufferSubData; +extern PFNGLDRAWELEMENTSBASEVERTEXPROC gl3wDrawElementsBaseVertex; +extern PFNGLDRAWRANGEELEMENTSBASEVERTEXPROC gl3wDrawRangeElementsBaseVertex; +extern PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXPROC 
gl3wDrawElementsInstancedBaseVertex; +extern PFNGLMULTIDRAWELEMENTSBASEVERTEXPROC gl3wMultiDrawElementsBaseVertex; +extern PFNGLPROVOKINGVERTEXPROC gl3wProvokingVertex; +extern PFNGLFENCESYNCPROC gl3wFenceSync; +extern PFNGLISSYNCPROC gl3wIsSync; +extern PFNGLDELETESYNCPROC gl3wDeleteSync; +extern PFNGLCLIENTWAITSYNCPROC gl3wClientWaitSync; +extern PFNGLWAITSYNCPROC gl3wWaitSync; +extern PFNGLGETINTEGER64VPROC gl3wGetInteger64v; +extern PFNGLGETSYNCIVPROC gl3wGetSynciv; +extern PFNGLTEXIMAGE2DMULTISAMPLEPROC gl3wTexImage2DMultisample; +extern PFNGLTEXIMAGE3DMULTISAMPLEPROC gl3wTexImage3DMultisample; +extern PFNGLGETMULTISAMPLEFVPROC gl3wGetMultisamplefv; +extern PFNGLSAMPLEMASKIPROC gl3wSampleMaski; +extern PFNGLBLENDEQUATIONIARBPROC gl3wBlendEquationiARB; +extern PFNGLBLENDEQUATIONSEPARATEIARBPROC gl3wBlendEquationSeparateiARB; +extern PFNGLBLENDFUNCIARBPROC gl3wBlendFunciARB; +extern PFNGLBLENDFUNCSEPARATEIARBPROC gl3wBlendFuncSeparateiARB; +extern PFNGLMINSAMPLESHADINGARBPROC gl3wMinSampleShadingARB; +extern PFNGLNAMEDSTRINGARBPROC gl3wNamedStringARB; +extern PFNGLDELETENAMEDSTRINGARBPROC gl3wDeleteNamedStringARB; +extern PFNGLCOMPILESHADERINCLUDEARBPROC gl3wCompileShaderIncludeARB; +extern PFNGLISNAMEDSTRINGARBPROC gl3wIsNamedStringARB; +extern PFNGLGETNAMEDSTRINGARBPROC gl3wGetNamedStringARB; +extern PFNGLGETNAMEDSTRINGIVARBPROC gl3wGetNamedStringivARB; +extern PFNGLBINDFRAGDATALOCATIONINDEXEDPROC gl3wBindFragDataLocationIndexed; +extern PFNGLGETFRAGDATAINDEXPROC gl3wGetFragDataIndex; +extern PFNGLGENSAMPLERSPROC gl3wGenSamplers; +extern PFNGLDELETESAMPLERSPROC gl3wDeleteSamplers; +extern PFNGLISSAMPLERPROC gl3wIsSampler; +extern PFNGLBINDSAMPLERPROC gl3wBindSampler; +extern PFNGLSAMPLERPARAMETERIPROC gl3wSamplerParameteri; +extern PFNGLSAMPLERPARAMETERIVPROC gl3wSamplerParameteriv; +extern PFNGLSAMPLERPARAMETERFPROC gl3wSamplerParameterf; +extern PFNGLSAMPLERPARAMETERFVPROC gl3wSamplerParameterfv; +extern PFNGLSAMPLERPARAMETERIIVPROC gl3wSamplerParameterIiv; +extern PFNGLSAMPLERPARAMETERIUIVPROC gl3wSamplerParameterIuiv; +extern PFNGLGETSAMPLERPARAMETERIVPROC gl3wGetSamplerParameteriv; +extern PFNGLGETSAMPLERPARAMETERIIVPROC gl3wGetSamplerParameterIiv; +extern PFNGLGETSAMPLERPARAMETERFVPROC gl3wGetSamplerParameterfv; +extern PFNGLGETSAMPLERPARAMETERIUIVPROC gl3wGetSamplerParameterIuiv; +extern PFNGLQUERYCOUNTERPROC gl3wQueryCounter; +extern PFNGLGETQUERYOBJECTI64VPROC gl3wGetQueryObjecti64v; +extern PFNGLGETQUERYOBJECTUI64VPROC gl3wGetQueryObjectui64v; +extern PFNGLVERTEXP2UIPROC gl3wVertexP2ui; +extern PFNGLVERTEXP2UIVPROC gl3wVertexP2uiv; +extern PFNGLVERTEXP3UIPROC gl3wVertexP3ui; +extern PFNGLVERTEXP3UIVPROC gl3wVertexP3uiv; +extern PFNGLVERTEXP4UIPROC gl3wVertexP4ui; +extern PFNGLVERTEXP4UIVPROC gl3wVertexP4uiv; +extern PFNGLTEXCOORDP1UIPROC gl3wTexCoordP1ui; +extern PFNGLTEXCOORDP1UIVPROC gl3wTexCoordP1uiv; +extern PFNGLTEXCOORDP2UIPROC gl3wTexCoordP2ui; +extern PFNGLTEXCOORDP2UIVPROC gl3wTexCoordP2uiv; +extern PFNGLTEXCOORDP3UIPROC gl3wTexCoordP3ui; +extern PFNGLTEXCOORDP3UIVPROC gl3wTexCoordP3uiv; +extern PFNGLTEXCOORDP4UIPROC gl3wTexCoordP4ui; +extern PFNGLTEXCOORDP4UIVPROC gl3wTexCoordP4uiv; +extern PFNGLMULTITEXCOORDP1UIPROC gl3wMultiTexCoordP1ui; +extern PFNGLMULTITEXCOORDP1UIVPROC gl3wMultiTexCoordP1uiv; +extern PFNGLMULTITEXCOORDP2UIPROC gl3wMultiTexCoordP2ui; +extern PFNGLMULTITEXCOORDP2UIVPROC gl3wMultiTexCoordP2uiv; +extern PFNGLMULTITEXCOORDP3UIPROC gl3wMultiTexCoordP3ui; +extern PFNGLMULTITEXCOORDP3UIVPROC gl3wMultiTexCoordP3uiv; +extern PFNGLMULTITEXCOORDP4UIPROC 
gl3wMultiTexCoordP4ui; +extern PFNGLMULTITEXCOORDP4UIVPROC gl3wMultiTexCoordP4uiv; +extern PFNGLNORMALP3UIPROC gl3wNormalP3ui; +extern PFNGLNORMALP3UIVPROC gl3wNormalP3uiv; +extern PFNGLCOLORP3UIPROC gl3wColorP3ui; +extern PFNGLCOLORP3UIVPROC gl3wColorP3uiv; +extern PFNGLCOLORP4UIPROC gl3wColorP4ui; +extern PFNGLCOLORP4UIVPROC gl3wColorP4uiv; +extern PFNGLSECONDARYCOLORP3UIPROC gl3wSecondaryColorP3ui; +extern PFNGLSECONDARYCOLORP3UIVPROC gl3wSecondaryColorP3uiv; +extern PFNGLVERTEXATTRIBP1UIPROC gl3wVertexAttribP1ui; +extern PFNGLVERTEXATTRIBP1UIVPROC gl3wVertexAttribP1uiv; +extern PFNGLVERTEXATTRIBP2UIPROC gl3wVertexAttribP2ui; +extern PFNGLVERTEXATTRIBP2UIVPROC gl3wVertexAttribP2uiv; +extern PFNGLVERTEXATTRIBP3UIPROC gl3wVertexAttribP3ui; +extern PFNGLVERTEXATTRIBP3UIVPROC gl3wVertexAttribP3uiv; +extern PFNGLVERTEXATTRIBP4UIPROC gl3wVertexAttribP4ui; +extern PFNGLVERTEXATTRIBP4UIVPROC gl3wVertexAttribP4uiv; +extern PFNGLDRAWARRAYSINDIRECTPROC gl3wDrawArraysIndirect; +extern PFNGLDRAWELEMENTSINDIRECTPROC gl3wDrawElementsIndirect; +extern PFNGLUNIFORM1DPROC gl3wUniform1d; +extern PFNGLUNIFORM2DPROC gl3wUniform2d; +extern PFNGLUNIFORM3DPROC gl3wUniform3d; +extern PFNGLUNIFORM4DPROC gl3wUniform4d; +extern PFNGLUNIFORM1DVPROC gl3wUniform1dv; +extern PFNGLUNIFORM2DVPROC gl3wUniform2dv; +extern PFNGLUNIFORM3DVPROC gl3wUniform3dv; +extern PFNGLUNIFORM4DVPROC gl3wUniform4dv; +extern PFNGLUNIFORMMATRIX2DVPROC gl3wUniformMatrix2dv; +extern PFNGLUNIFORMMATRIX3DVPROC gl3wUniformMatrix3dv; +extern PFNGLUNIFORMMATRIX4DVPROC gl3wUniformMatrix4dv; +extern PFNGLUNIFORMMATRIX2X3DVPROC gl3wUniformMatrix2x3dv; +extern PFNGLUNIFORMMATRIX2X4DVPROC gl3wUniformMatrix2x4dv; +extern PFNGLUNIFORMMATRIX3X2DVPROC gl3wUniformMatrix3x2dv; +extern PFNGLUNIFORMMATRIX3X4DVPROC gl3wUniformMatrix3x4dv; +extern PFNGLUNIFORMMATRIX4X2DVPROC gl3wUniformMatrix4x2dv; +extern PFNGLUNIFORMMATRIX4X3DVPROC gl3wUniformMatrix4x3dv; +extern PFNGLGETUNIFORMDVPROC gl3wGetUniformdv; +extern PFNGLGETSUBROUTINEUNIFORMLOCATIONPROC gl3wGetSubroutineUniformLocation; +extern PFNGLGETSUBROUTINEINDEXPROC gl3wGetSubroutineIndex; +extern PFNGLGETACTIVESUBROUTINEUNIFORMIVPROC gl3wGetActiveSubroutineUniformiv; +extern PFNGLGETACTIVESUBROUTINEUNIFORMNAMEPROC gl3wGetActiveSubroutineUniformName; +extern PFNGLGETACTIVESUBROUTINENAMEPROC gl3wGetActiveSubroutineName; +extern PFNGLUNIFORMSUBROUTINESUIVPROC gl3wUniformSubroutinesuiv; +extern PFNGLGETUNIFORMSUBROUTINEUIVPROC gl3wGetUniformSubroutineuiv; +extern PFNGLGETPROGRAMSTAGEIVPROC gl3wGetProgramStageiv; +extern PFNGLPATCHPARAMETERIPROC gl3wPatchParameteri; +extern PFNGLPATCHPARAMETERFVPROC gl3wPatchParameterfv; +extern PFNGLBINDTRANSFORMFEEDBACKPROC gl3wBindTransformFeedback; +extern PFNGLDELETETRANSFORMFEEDBACKSPROC gl3wDeleteTransformFeedbacks; +extern PFNGLGENTRANSFORMFEEDBACKSPROC gl3wGenTransformFeedbacks; +extern PFNGLISTRANSFORMFEEDBACKPROC gl3wIsTransformFeedback; +extern PFNGLPAUSETRANSFORMFEEDBACKPROC gl3wPauseTransformFeedback; +extern PFNGLRESUMETRANSFORMFEEDBACKPROC gl3wResumeTransformFeedback; +extern PFNGLDRAWTRANSFORMFEEDBACKPROC gl3wDrawTransformFeedback; +extern PFNGLDRAWTRANSFORMFEEDBACKSTREAMPROC gl3wDrawTransformFeedbackStream; +extern PFNGLBEGINQUERYINDEXEDPROC gl3wBeginQueryIndexed; +extern PFNGLENDQUERYINDEXEDPROC gl3wEndQueryIndexed; +extern PFNGLGETQUERYINDEXEDIVPROC gl3wGetQueryIndexediv; +extern PFNGLRELEASESHADERCOMPILERPROC gl3wReleaseShaderCompiler; +extern PFNGLSHADERBINARYPROC gl3wShaderBinary; +extern PFNGLGETSHADERPRECISIONFORMATPROC gl3wGetShaderPrecisionFormat; 
+extern PFNGLDEPTHRANGEFPROC gl3wDepthRangef; +extern PFNGLCLEARDEPTHFPROC gl3wClearDepthf; +extern PFNGLGETPROGRAMBINARYPROC gl3wGetProgramBinary; +extern PFNGLPROGRAMBINARYPROC gl3wProgramBinary; +extern PFNGLPROGRAMPARAMETERIPROC gl3wProgramParameteri; +extern PFNGLUSEPROGRAMSTAGESPROC gl3wUseProgramStages; +extern PFNGLACTIVESHADERPROGRAMPROC gl3wActiveShaderProgram; +extern PFNGLCREATESHADERPROGRAMVPROC gl3wCreateShaderProgramv; +extern PFNGLBINDPROGRAMPIPELINEPROC gl3wBindProgramPipeline; +extern PFNGLDELETEPROGRAMPIPELINESPROC gl3wDeleteProgramPipelines; +extern PFNGLGENPROGRAMPIPELINESPROC gl3wGenProgramPipelines; +extern PFNGLISPROGRAMPIPELINEPROC gl3wIsProgramPipeline; +extern PFNGLGETPROGRAMPIPELINEIVPROC gl3wGetProgramPipelineiv; +extern PFNGLPROGRAMUNIFORM1IPROC gl3wProgramUniform1i; +extern PFNGLPROGRAMUNIFORM1IVPROC gl3wProgramUniform1iv; +extern PFNGLPROGRAMUNIFORM1FPROC gl3wProgramUniform1f; +extern PFNGLPROGRAMUNIFORM1FVPROC gl3wProgramUniform1fv; +extern PFNGLPROGRAMUNIFORM1DPROC gl3wProgramUniform1d; +extern PFNGLPROGRAMUNIFORM1DVPROC gl3wProgramUniform1dv; +extern PFNGLPROGRAMUNIFORM1UIPROC gl3wProgramUniform1ui; +extern PFNGLPROGRAMUNIFORM1UIVPROC gl3wProgramUniform1uiv; +extern PFNGLPROGRAMUNIFORM2IPROC gl3wProgramUniform2i; +extern PFNGLPROGRAMUNIFORM2IVPROC gl3wProgramUniform2iv; +extern PFNGLPROGRAMUNIFORM2FPROC gl3wProgramUniform2f; +extern PFNGLPROGRAMUNIFORM2FVPROC gl3wProgramUniform2fv; +extern PFNGLPROGRAMUNIFORM2DPROC gl3wProgramUniform2d; +extern PFNGLPROGRAMUNIFORM2DVPROC gl3wProgramUniform2dv; +extern PFNGLPROGRAMUNIFORM2UIPROC gl3wProgramUniform2ui; +extern PFNGLPROGRAMUNIFORM2UIVPROC gl3wProgramUniform2uiv; +extern PFNGLPROGRAMUNIFORM3IPROC gl3wProgramUniform3i; +extern PFNGLPROGRAMUNIFORM3IVPROC gl3wProgramUniform3iv; +extern PFNGLPROGRAMUNIFORM3FPROC gl3wProgramUniform3f; +extern PFNGLPROGRAMUNIFORM3FVPROC gl3wProgramUniform3fv; +extern PFNGLPROGRAMUNIFORM3DPROC gl3wProgramUniform3d; +extern PFNGLPROGRAMUNIFORM3DVPROC gl3wProgramUniform3dv; +extern PFNGLPROGRAMUNIFORM3UIPROC gl3wProgramUniform3ui; +extern PFNGLPROGRAMUNIFORM3UIVPROC gl3wProgramUniform3uiv; +extern PFNGLPROGRAMUNIFORM4IPROC gl3wProgramUniform4i; +extern PFNGLPROGRAMUNIFORM4IVPROC gl3wProgramUniform4iv; +extern PFNGLPROGRAMUNIFORM4FPROC gl3wProgramUniform4f; +extern PFNGLPROGRAMUNIFORM4FVPROC gl3wProgramUniform4fv; +extern PFNGLPROGRAMUNIFORM4DPROC gl3wProgramUniform4d; +extern PFNGLPROGRAMUNIFORM4DVPROC gl3wProgramUniform4dv; +extern PFNGLPROGRAMUNIFORM4UIPROC gl3wProgramUniform4ui; +extern PFNGLPROGRAMUNIFORM4UIVPROC gl3wProgramUniform4uiv; +extern PFNGLPROGRAMUNIFORMMATRIX2FVPROC gl3wProgramUniformMatrix2fv; +extern PFNGLPROGRAMUNIFORMMATRIX3FVPROC gl3wProgramUniformMatrix3fv; +extern PFNGLPROGRAMUNIFORMMATRIX4FVPROC gl3wProgramUniformMatrix4fv; +extern PFNGLPROGRAMUNIFORMMATRIX2DVPROC gl3wProgramUniformMatrix2dv; +extern PFNGLPROGRAMUNIFORMMATRIX3DVPROC gl3wProgramUniformMatrix3dv; +extern PFNGLPROGRAMUNIFORMMATRIX4DVPROC gl3wProgramUniformMatrix4dv; +extern PFNGLPROGRAMUNIFORMMATRIX2X3FVPROC gl3wProgramUniformMatrix2x3fv; +extern PFNGLPROGRAMUNIFORMMATRIX3X2FVPROC gl3wProgramUniformMatrix3x2fv; +extern PFNGLPROGRAMUNIFORMMATRIX2X4FVPROC gl3wProgramUniformMatrix2x4fv; +extern PFNGLPROGRAMUNIFORMMATRIX4X2FVPROC gl3wProgramUniformMatrix4x2fv; +extern PFNGLPROGRAMUNIFORMMATRIX3X4FVPROC gl3wProgramUniformMatrix3x4fv; +extern PFNGLPROGRAMUNIFORMMATRIX4X3FVPROC gl3wProgramUniformMatrix4x3fv; +extern PFNGLPROGRAMUNIFORMMATRIX2X3DVPROC gl3wProgramUniformMatrix2x3dv; +extern 
PFNGLPROGRAMUNIFORMMATRIX3X2DVPROC gl3wProgramUniformMatrix3x2dv; +extern PFNGLPROGRAMUNIFORMMATRIX2X4DVPROC gl3wProgramUniformMatrix2x4dv; +extern PFNGLPROGRAMUNIFORMMATRIX4X2DVPROC gl3wProgramUniformMatrix4x2dv; +extern PFNGLPROGRAMUNIFORMMATRIX3X4DVPROC gl3wProgramUniformMatrix3x4dv; +extern PFNGLPROGRAMUNIFORMMATRIX4X3DVPROC gl3wProgramUniformMatrix4x3dv; +extern PFNGLVALIDATEPROGRAMPIPELINEPROC gl3wValidateProgramPipeline; +extern PFNGLGETPROGRAMPIPELINEINFOLOGPROC gl3wGetProgramPipelineInfoLog; +extern PFNGLVERTEXATTRIBL1DPROC gl3wVertexAttribL1d; +extern PFNGLVERTEXATTRIBL2DPROC gl3wVertexAttribL2d; +extern PFNGLVERTEXATTRIBL3DPROC gl3wVertexAttribL3d; +extern PFNGLVERTEXATTRIBL4DPROC gl3wVertexAttribL4d; +extern PFNGLVERTEXATTRIBL1DVPROC gl3wVertexAttribL1dv; +extern PFNGLVERTEXATTRIBL2DVPROC gl3wVertexAttribL2dv; +extern PFNGLVERTEXATTRIBL3DVPROC gl3wVertexAttribL3dv; +extern PFNGLVERTEXATTRIBL4DVPROC gl3wVertexAttribL4dv; +extern PFNGLVERTEXATTRIBLPOINTERPROC gl3wVertexAttribLPointer; +extern PFNGLGETVERTEXATTRIBLDVPROC gl3wGetVertexAttribLdv; +extern PFNGLVIEWPORTARRAYVPROC gl3wViewportArrayv; +extern PFNGLVIEWPORTINDEXEDFPROC gl3wViewportIndexedf; +extern PFNGLVIEWPORTINDEXEDFVPROC gl3wViewportIndexedfv; +extern PFNGLSCISSORARRAYVPROC gl3wScissorArrayv; +extern PFNGLSCISSORINDEXEDPROC gl3wScissorIndexed; +extern PFNGLSCISSORINDEXEDVPROC gl3wScissorIndexedv; +extern PFNGLDEPTHRANGEARRAYVPROC gl3wDepthRangeArrayv; +extern PFNGLDEPTHRANGEINDEXEDPROC gl3wDepthRangeIndexed; +extern PFNGLGETFLOATI_VPROC gl3wGetFloati_v; +extern PFNGLGETDOUBLEI_VPROC gl3wGetDoublei_v; +extern PFNGLCREATESYNCFROMCLEVENTARBPROC gl3wCreateSyncFromCLeventARB; +extern PFNGLDEBUGMESSAGECONTROLARBPROC gl3wDebugMessageControlARB; +extern PFNGLDEBUGMESSAGEINSERTARBPROC gl3wDebugMessageInsertARB; +extern PFNGLDEBUGMESSAGECALLBACKARBPROC gl3wDebugMessageCallbackARB; +extern PFNGLGETDEBUGMESSAGELOGARBPROC gl3wGetDebugMessageLogARB; +extern PFNGLGETGRAPHICSRESETSTATUSARBPROC gl3wGetGraphicsResetStatusARB; +extern PFNGLGETNTEXIMAGEARBPROC gl3wGetnTexImageARB; +extern PFNGLREADNPIXELSARBPROC gl3wReadnPixelsARB; +extern PFNGLGETNCOMPRESSEDTEXIMAGEARBPROC gl3wGetnCompressedTexImageARB; +extern PFNGLGETNUNIFORMFVARBPROC gl3wGetnUniformfvARB; +extern PFNGLGETNUNIFORMIVARBPROC gl3wGetnUniformivARB; +extern PFNGLGETNUNIFORMUIVARBPROC gl3wGetnUniformuivARB; +extern PFNGLGETNUNIFORMDVARBPROC gl3wGetnUniformdvARB; +extern PFNGLDRAWARRAYSINSTANCEDBASEINSTANCEPROC gl3wDrawArraysInstancedBaseInstance; +extern PFNGLDRAWELEMENTSINSTANCEDBASEINSTANCEPROC gl3wDrawElementsInstancedBaseInstance; +extern PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXBASEINSTANCEPROC gl3wDrawElementsInstancedBaseVertexBaseInstance; +extern PFNGLDRAWTRANSFORMFEEDBACKINSTANCEDPROC gl3wDrawTransformFeedbackInstanced; +extern PFNGLDRAWTRANSFORMFEEDBACKSTREAMINSTANCEDPROC gl3wDrawTransformFeedbackStreamInstanced; +extern PFNGLGETINTERNALFORMATIVPROC gl3wGetInternalformativ; +extern PFNGLGETACTIVEATOMICCOUNTERBUFFERIVPROC gl3wGetActiveAtomicCounterBufferiv; +extern PFNGLBINDIMAGETEXTUREPROC gl3wBindImageTexture; +extern PFNGLMEMORYBARRIERPROC gl3wMemoryBarrier; +extern PFNGLTEXSTORAGE1DPROC gl3wTexStorage1D; +extern PFNGLTEXSTORAGE2DPROC gl3wTexStorage2D; +extern PFNGLTEXSTORAGE3DPROC gl3wTexStorage3D; +extern PFNGLTEXTURESTORAGE1DEXTPROC gl3wTextureStorage1DEXT; +extern PFNGLTEXTURESTORAGE2DEXTPROC gl3wTextureStorage2DEXT; +extern PFNGLTEXTURESTORAGE3DEXTPROC gl3wTextureStorage3DEXT; +extern PFNGLDEBUGMESSAGECONTROLPROC gl3wDebugMessageControl; +extern 
PFNGLDEBUGMESSAGEINSERTPROC gl3wDebugMessageInsert; +extern PFNGLDEBUGMESSAGECALLBACKPROC gl3wDebugMessageCallback; +extern PFNGLGETDEBUGMESSAGELOGPROC gl3wGetDebugMessageLog; +extern PFNGLPUSHDEBUGGROUPPROC gl3wPushDebugGroup; +extern PFNGLPOPDEBUGGROUPPROC gl3wPopDebugGroup; +extern PFNGLOBJECTLABELPROC gl3wObjectLabel; +extern PFNGLGETOBJECTLABELPROC gl3wGetObjectLabel; +extern PFNGLOBJECTPTRLABELPROC gl3wObjectPtrLabel; +extern PFNGLGETOBJECTPTRLABELPROC gl3wGetObjectPtrLabel; +extern PFNGLCLEARBUFFERDATAPROC gl3wClearBufferData; +extern PFNGLCLEARBUFFERSUBDATAPROC gl3wClearBufferSubData; +extern PFNGLCLEARNAMEDBUFFERDATAEXTPROC gl3wClearNamedBufferDataEXT; +extern PFNGLCLEARNAMEDBUFFERSUBDATAEXTPROC gl3wClearNamedBufferSubDataEXT; +extern PFNGLDISPATCHCOMPUTEPROC gl3wDispatchCompute; +extern PFNGLDISPATCHCOMPUTEINDIRECTPROC gl3wDispatchComputeIndirect; +extern PFNGLCOPYIMAGESUBDATAPROC gl3wCopyImageSubData; +extern PFNGLTEXTUREVIEWPROC gl3wTextureView; +extern PFNGLBINDVERTEXBUFFERPROC gl3wBindVertexBuffer; +extern PFNGLVERTEXATTRIBFORMATPROC gl3wVertexAttribFormat; +extern PFNGLVERTEXATTRIBIFORMATPROC gl3wVertexAttribIFormat; +extern PFNGLVERTEXATTRIBLFORMATPROC gl3wVertexAttribLFormat; +extern PFNGLVERTEXATTRIBBINDINGPROC gl3wVertexAttribBinding; +extern PFNGLVERTEXBINDINGDIVISORPROC gl3wVertexBindingDivisor; +extern PFNGLVERTEXARRAYBINDVERTEXBUFFEREXTPROC gl3wVertexArrayBindVertexBufferEXT; +extern PFNGLVERTEXARRAYVERTEXATTRIBFORMATEXTPROC gl3wVertexArrayVertexAttribFormatEXT; +extern PFNGLVERTEXARRAYVERTEXATTRIBIFORMATEXTPROC gl3wVertexArrayVertexAttribIFormatEXT; +extern PFNGLVERTEXARRAYVERTEXATTRIBLFORMATEXTPROC gl3wVertexArrayVertexAttribLFormatEXT; +extern PFNGLVERTEXARRAYVERTEXATTRIBBINDINGEXTPROC gl3wVertexArrayVertexAttribBindingEXT; +extern PFNGLVERTEXARRAYVERTEXBINDINGDIVISOREXTPROC gl3wVertexArrayVertexBindingDivisorEXT; +extern PFNGLFRAMEBUFFERPARAMETERIPROC gl3wFramebufferParameteri; +extern PFNGLGETFRAMEBUFFERPARAMETERIVPROC gl3wGetFramebufferParameteriv; +extern PFNGLNAMEDFRAMEBUFFERPARAMETERIEXTPROC gl3wNamedFramebufferParameteriEXT; +extern PFNGLGETNAMEDFRAMEBUFFERPARAMETERIVEXTPROC gl3wGetNamedFramebufferParameterivEXT; +extern PFNGLGETINTERNALFORMATI64VPROC gl3wGetInternalformati64v; +extern PFNGLINVALIDATETEXSUBIMAGEPROC gl3wInvalidateTexSubImage; +extern PFNGLINVALIDATETEXIMAGEPROC gl3wInvalidateTexImage; +extern PFNGLINVALIDATEBUFFERSUBDATAPROC gl3wInvalidateBufferSubData; +extern PFNGLINVALIDATEBUFFERDATAPROC gl3wInvalidateBufferData; +extern PFNGLINVALIDATEFRAMEBUFFERPROC gl3wInvalidateFramebuffer; +extern PFNGLINVALIDATESUBFRAMEBUFFERPROC gl3wInvalidateSubFramebuffer; +extern PFNGLMULTIDRAWARRAYSINDIRECTPROC gl3wMultiDrawArraysIndirect; +extern PFNGLMULTIDRAWELEMENTSINDIRECTPROC gl3wMultiDrawElementsIndirect; +extern PFNGLGETPROGRAMINTERFACEIVPROC gl3wGetProgramInterfaceiv; +extern PFNGLGETPROGRAMRESOURCEINDEXPROC gl3wGetProgramResourceIndex; +extern PFNGLGETPROGRAMRESOURCENAMEPROC gl3wGetProgramResourceName; +extern PFNGLGETPROGRAMRESOURCEIVPROC gl3wGetProgramResourceiv; +extern PFNGLGETPROGRAMRESOURCELOCATIONPROC gl3wGetProgramResourceLocation; +extern PFNGLGETPROGRAMRESOURCELOCATIONINDEXPROC gl3wGetProgramResourceLocationIndex; +extern PFNGLSHADERSTORAGEBLOCKBINDINGPROC gl3wShaderStorageBlockBinding; +extern PFNGLTEXBUFFERRANGEPROC gl3wTexBufferRange; +extern PFNGLTEXTUREBUFFERRANGEEXTPROC gl3wTextureBufferRangeEXT; +extern PFNGLTEXSTORAGE2DMULTISAMPLEPROC gl3wTexStorage2DMultisample; +extern PFNGLTEXSTORAGE3DMULTISAMPLEPROC gl3wTexStorage3DMultisample; 
+extern PFNGLTEXTURESTORAGE2DMULTISAMPLEEXTPROC gl3wTextureStorage2DMultisampleEXT; +extern PFNGLTEXTURESTORAGE3DMULTISAMPLEEXTPROC gl3wTextureStorage3DMultisampleEXT; + +#define glCullFace gl3wCullFace +#define glFrontFace gl3wFrontFace +#define glHint gl3wHint +#define glLineWidth gl3wLineWidth +#define glPointSize gl3wPointSize +#define glPolygonMode gl3wPolygonMode +#define glScissor gl3wScissor +#define glTexParameterf gl3wTexParameterf +#define glTexParameterfv gl3wTexParameterfv +#define glTexParameteri gl3wTexParameteri +#define glTexParameteriv gl3wTexParameteriv +#define glTexImage1D gl3wTexImage1D +#define glTexImage2D gl3wTexImage2D +#define glDrawBuffer gl3wDrawBuffer +#define glClear gl3wClear +#define glClearColor gl3wClearColor +#define glClearStencil gl3wClearStencil +#define glClearDepth gl3wClearDepth +#define glStencilMask gl3wStencilMask +#define glColorMask gl3wColorMask +#define glDepthMask gl3wDepthMask +#define glDisable gl3wDisable +#define glEnable gl3wEnable +#define glFinish gl3wFinish +#define glFlush gl3wFlush +#define glBlendFunc gl3wBlendFunc +#define glLogicOp gl3wLogicOp +#define glStencilFunc gl3wStencilFunc +#define glStencilOp gl3wStencilOp +#define glDepthFunc gl3wDepthFunc +#define glPixelStoref gl3wPixelStoref +#define glPixelStorei gl3wPixelStorei +#define glReadBuffer gl3wReadBuffer +#define glReadPixels gl3wReadPixels +#define glGetBooleanv gl3wGetBooleanv +#define glGetDoublev gl3wGetDoublev +#define glGetError gl3wGetError +#define glGetFloatv gl3wGetFloatv +#define glGetIntegerv gl3wGetIntegerv +#define glGetString gl3wGetString +#define glGetTexImage gl3wGetTexImage +#define glGetTexParameterfv gl3wGetTexParameterfv +#define glGetTexParameteriv gl3wGetTexParameteriv +#define glGetTexLevelParameterfv gl3wGetTexLevelParameterfv +#define glGetTexLevelParameteriv gl3wGetTexLevelParameteriv +#define glIsEnabled gl3wIsEnabled +#define glDepthRange gl3wDepthRange +#define glViewport gl3wViewport +#define glDrawArrays gl3wDrawArrays +#define glDrawElements gl3wDrawElements +#define glGetPointerv gl3wGetPointerv +#define glPolygonOffset gl3wPolygonOffset +#define glCopyTexImage1D gl3wCopyTexImage1D +#define glCopyTexImage2D gl3wCopyTexImage2D +#define glCopyTexSubImage1D gl3wCopyTexSubImage1D +#define glCopyTexSubImage2D gl3wCopyTexSubImage2D +#define glTexSubImage1D gl3wTexSubImage1D +#define glTexSubImage2D gl3wTexSubImage2D +#define glBindTexture gl3wBindTexture +#define glDeleteTextures gl3wDeleteTextures +#define glGenTextures gl3wGenTextures +#define glIsTexture gl3wIsTexture +#define glBlendColor gl3wBlendColor +#define glBlendEquation gl3wBlendEquation +#define glDrawRangeElements gl3wDrawRangeElements +#define glTexImage3D gl3wTexImage3D +#define glTexSubImage3D gl3wTexSubImage3D +#define glCopyTexSubImage3D gl3wCopyTexSubImage3D +#define glActiveTexture gl3wActiveTexture +#define glSampleCoverage gl3wSampleCoverage +#define glCompressedTexImage3D gl3wCompressedTexImage3D +#define glCompressedTexImage2D gl3wCompressedTexImage2D +#define glCompressedTexImage1D gl3wCompressedTexImage1D +#define glCompressedTexSubImage3D gl3wCompressedTexSubImage3D +#define glCompressedTexSubImage2D gl3wCompressedTexSubImage2D +#define glCompressedTexSubImage1D gl3wCompressedTexSubImage1D +#define glGetCompressedTexImage gl3wGetCompressedTexImage +#define glBlendFuncSeparate gl3wBlendFuncSeparate +#define glMultiDrawArrays gl3wMultiDrawArrays +#define glMultiDrawElements gl3wMultiDrawElements +#define glPointParameterf gl3wPointParameterf +#define 
glPointParameterfv gl3wPointParameterfv +#define glPointParameteri gl3wPointParameteri +#define glPointParameteriv gl3wPointParameteriv +#define glGenQueries gl3wGenQueries +#define glDeleteQueries gl3wDeleteQueries +#define glIsQuery gl3wIsQuery +#define glBeginQuery gl3wBeginQuery +#define glEndQuery gl3wEndQuery +#define glGetQueryiv gl3wGetQueryiv +#define glGetQueryObjectiv gl3wGetQueryObjectiv +#define glGetQueryObjectuiv gl3wGetQueryObjectuiv +#define glBindBuffer gl3wBindBuffer +#define glDeleteBuffers gl3wDeleteBuffers +#define glGenBuffers gl3wGenBuffers +#define glIsBuffer gl3wIsBuffer +#define glBufferData gl3wBufferData +#define glBufferSubData gl3wBufferSubData +#define glGetBufferSubData gl3wGetBufferSubData +#define glMapBuffer gl3wMapBuffer +#define glUnmapBuffer gl3wUnmapBuffer +#define glGetBufferParameteriv gl3wGetBufferParameteriv +#define glGetBufferPointerv gl3wGetBufferPointerv +#define glBlendEquationSeparate gl3wBlendEquationSeparate +#define glDrawBuffers gl3wDrawBuffers +#define glStencilOpSeparate gl3wStencilOpSeparate +#define glStencilFuncSeparate gl3wStencilFuncSeparate +#define glStencilMaskSeparate gl3wStencilMaskSeparate +#define glAttachShader gl3wAttachShader +#define glBindAttribLocation gl3wBindAttribLocation +#define glCompileShader gl3wCompileShader +#define glCreateProgram gl3wCreateProgram +#define glCreateShader gl3wCreateShader +#define glDeleteProgram gl3wDeleteProgram +#define glDeleteShader gl3wDeleteShader +#define glDetachShader gl3wDetachShader +#define glDisableVertexAttribArray gl3wDisableVertexAttribArray +#define glEnableVertexAttribArray gl3wEnableVertexAttribArray +#define glGetActiveAttrib gl3wGetActiveAttrib +#define glGetActiveUniform gl3wGetActiveUniform +#define glGetAttachedShaders gl3wGetAttachedShaders +#define glGetAttribLocation gl3wGetAttribLocation +#define glGetProgramiv gl3wGetProgramiv +#define glGetProgramInfoLog gl3wGetProgramInfoLog +#define glGetShaderiv gl3wGetShaderiv +#define glGetShaderInfoLog gl3wGetShaderInfoLog +#define glGetShaderSource gl3wGetShaderSource +#define glGetUniformLocation gl3wGetUniformLocation +#define glGetUniformfv gl3wGetUniformfv +#define glGetUniformiv gl3wGetUniformiv +#define glGetVertexAttribdv gl3wGetVertexAttribdv +#define glGetVertexAttribfv gl3wGetVertexAttribfv +#define glGetVertexAttribiv gl3wGetVertexAttribiv +#define glGetVertexAttribPointerv gl3wGetVertexAttribPointerv +#define glIsProgram gl3wIsProgram +#define glIsShader gl3wIsShader +#define glLinkProgram gl3wLinkProgram +#define glShaderSource gl3wShaderSource +#define glUseProgram gl3wUseProgram +#define glUniform1f gl3wUniform1f +#define glUniform2f gl3wUniform2f +#define glUniform3f gl3wUniform3f +#define glUniform4f gl3wUniform4f +#define glUniform1i gl3wUniform1i +#define glUniform2i gl3wUniform2i +#define glUniform3i gl3wUniform3i +#define glUniform4i gl3wUniform4i +#define glUniform1fv gl3wUniform1fv +#define glUniform2fv gl3wUniform2fv +#define glUniform3fv gl3wUniform3fv +#define glUniform4fv gl3wUniform4fv +#define glUniform1iv gl3wUniform1iv +#define glUniform2iv gl3wUniform2iv +#define glUniform3iv gl3wUniform3iv +#define glUniform4iv gl3wUniform4iv +#define glUniformMatrix2fv gl3wUniformMatrix2fv +#define glUniformMatrix3fv gl3wUniformMatrix3fv +#define glUniformMatrix4fv gl3wUniformMatrix4fv +#define glValidateProgram gl3wValidateProgram +#define glVertexAttrib1d gl3wVertexAttrib1d +#define glVertexAttrib1dv gl3wVertexAttrib1dv +#define glVertexAttrib1f gl3wVertexAttrib1f +#define glVertexAttrib1fv 
gl3wVertexAttrib1fv +#define glVertexAttrib1s gl3wVertexAttrib1s +#define glVertexAttrib1sv gl3wVertexAttrib1sv +#define glVertexAttrib2d gl3wVertexAttrib2d +#define glVertexAttrib2dv gl3wVertexAttrib2dv +#define glVertexAttrib2f gl3wVertexAttrib2f +#define glVertexAttrib2fv gl3wVertexAttrib2fv +#define glVertexAttrib2s gl3wVertexAttrib2s +#define glVertexAttrib2sv gl3wVertexAttrib2sv +#define glVertexAttrib3d gl3wVertexAttrib3d +#define glVertexAttrib3dv gl3wVertexAttrib3dv +#define glVertexAttrib3f gl3wVertexAttrib3f +#define glVertexAttrib3fv gl3wVertexAttrib3fv +#define glVertexAttrib3s gl3wVertexAttrib3s +#define glVertexAttrib3sv gl3wVertexAttrib3sv +#define glVertexAttrib4Nbv gl3wVertexAttrib4Nbv +#define glVertexAttrib4Niv gl3wVertexAttrib4Niv +#define glVertexAttrib4Nsv gl3wVertexAttrib4Nsv +#define glVertexAttrib4Nub gl3wVertexAttrib4Nub +#define glVertexAttrib4Nubv gl3wVertexAttrib4Nubv +#define glVertexAttrib4Nuiv gl3wVertexAttrib4Nuiv +#define glVertexAttrib4Nusv gl3wVertexAttrib4Nusv +#define glVertexAttrib4bv gl3wVertexAttrib4bv +#define glVertexAttrib4d gl3wVertexAttrib4d +#define glVertexAttrib4dv gl3wVertexAttrib4dv +#define glVertexAttrib4f gl3wVertexAttrib4f +#define glVertexAttrib4fv gl3wVertexAttrib4fv +#define glVertexAttrib4iv gl3wVertexAttrib4iv +#define glVertexAttrib4s gl3wVertexAttrib4s +#define glVertexAttrib4sv gl3wVertexAttrib4sv +#define glVertexAttrib4ubv gl3wVertexAttrib4ubv +#define glVertexAttrib4uiv gl3wVertexAttrib4uiv +#define glVertexAttrib4usv gl3wVertexAttrib4usv +#define glVertexAttribPointer gl3wVertexAttribPointer +#define glUniformMatrix2x3fv gl3wUniformMatrix2x3fv +#define glUniformMatrix3x2fv gl3wUniformMatrix3x2fv +#define glUniformMatrix2x4fv gl3wUniformMatrix2x4fv +#define glUniformMatrix4x2fv gl3wUniformMatrix4x2fv +#define glUniformMatrix3x4fv gl3wUniformMatrix3x4fv +#define glUniformMatrix4x3fv gl3wUniformMatrix4x3fv +#define glColorMaski gl3wColorMaski +#define glGetBooleani_v gl3wGetBooleani_v +#define glGetIntegeri_v gl3wGetIntegeri_v +#define glEnablei gl3wEnablei +#define glDisablei gl3wDisablei +#define glIsEnabledi gl3wIsEnabledi +#define glBeginTransformFeedback gl3wBeginTransformFeedback +#define glEndTransformFeedback gl3wEndTransformFeedback +#define glBindBufferRange gl3wBindBufferRange +#define glBindBufferBase gl3wBindBufferBase +#define glTransformFeedbackVaryings gl3wTransformFeedbackVaryings +#define glGetTransformFeedbackVarying gl3wGetTransformFeedbackVarying +#define glClampColor gl3wClampColor +#define glBeginConditionalRender gl3wBeginConditionalRender +#define glEndConditionalRender gl3wEndConditionalRender +#define glVertexAttribIPointer gl3wVertexAttribIPointer +#define glGetVertexAttribIiv gl3wGetVertexAttribIiv +#define glGetVertexAttribIuiv gl3wGetVertexAttribIuiv +#define glVertexAttribI1i gl3wVertexAttribI1i +#define glVertexAttribI2i gl3wVertexAttribI2i +#define glVertexAttribI3i gl3wVertexAttribI3i +#define glVertexAttribI4i gl3wVertexAttribI4i +#define glVertexAttribI1ui gl3wVertexAttribI1ui +#define glVertexAttribI2ui gl3wVertexAttribI2ui +#define glVertexAttribI3ui gl3wVertexAttribI3ui +#define glVertexAttribI4ui gl3wVertexAttribI4ui +#define glVertexAttribI1iv gl3wVertexAttribI1iv +#define glVertexAttribI2iv gl3wVertexAttribI2iv +#define glVertexAttribI3iv gl3wVertexAttribI3iv +#define glVertexAttribI4iv gl3wVertexAttribI4iv +#define glVertexAttribI1uiv gl3wVertexAttribI1uiv +#define glVertexAttribI2uiv gl3wVertexAttribI2uiv +#define glVertexAttribI3uiv gl3wVertexAttribI3uiv +#define 
glVertexAttribI4uiv gl3wVertexAttribI4uiv +#define glVertexAttribI4bv gl3wVertexAttribI4bv +#define glVertexAttribI4sv gl3wVertexAttribI4sv +#define glVertexAttribI4ubv gl3wVertexAttribI4ubv +#define glVertexAttribI4usv gl3wVertexAttribI4usv +#define glGetUniformuiv gl3wGetUniformuiv +#define glBindFragDataLocation gl3wBindFragDataLocation +#define glGetFragDataLocation gl3wGetFragDataLocation +#define glUniform1ui gl3wUniform1ui +#define glUniform2ui gl3wUniform2ui +#define glUniform3ui gl3wUniform3ui +#define glUniform4ui gl3wUniform4ui +#define glUniform1uiv gl3wUniform1uiv +#define glUniform2uiv gl3wUniform2uiv +#define glUniform3uiv gl3wUniform3uiv +#define glUniform4uiv gl3wUniform4uiv +#define glTexParameterIiv gl3wTexParameterIiv +#define glTexParameterIuiv gl3wTexParameterIuiv +#define glGetTexParameterIiv gl3wGetTexParameterIiv +#define glGetTexParameterIuiv gl3wGetTexParameterIuiv +#define glClearBufferiv gl3wClearBufferiv +#define glClearBufferuiv gl3wClearBufferuiv +#define glClearBufferfv gl3wClearBufferfv +#define glClearBufferfi gl3wClearBufferfi +#define glGetStringi gl3wGetStringi +#define glDrawArraysInstanced gl3wDrawArraysInstanced +#define glDrawElementsInstanced gl3wDrawElementsInstanced +#define glTexBuffer gl3wTexBuffer +#define glPrimitiveRestartIndex gl3wPrimitiveRestartIndex +#define glGetInteger64i_v gl3wGetInteger64i_v +#define glGetBufferParameteri64v gl3wGetBufferParameteri64v +#define glFramebufferTexture gl3wFramebufferTexture +#define glVertexAttribDivisor gl3wVertexAttribDivisor +#define glMinSampleShading gl3wMinSampleShading +#define glBlendEquationi gl3wBlendEquationi +#define glBlendEquationSeparatei gl3wBlendEquationSeparatei +#define glBlendFunci gl3wBlendFunci +#define glBlendFuncSeparatei gl3wBlendFuncSeparatei +#define glIsRenderbuffer gl3wIsRenderbuffer +#define glBindRenderbuffer gl3wBindRenderbuffer +#define glDeleteRenderbuffers gl3wDeleteRenderbuffers +#define glGenRenderbuffers gl3wGenRenderbuffers +#define glRenderbufferStorage gl3wRenderbufferStorage +#define glGetRenderbufferParameteriv gl3wGetRenderbufferParameteriv +#define glIsFramebuffer gl3wIsFramebuffer +#define glBindFramebuffer gl3wBindFramebuffer +#define glDeleteFramebuffers gl3wDeleteFramebuffers +#define glGenFramebuffers gl3wGenFramebuffers +#define glCheckFramebufferStatus gl3wCheckFramebufferStatus +#define glFramebufferTexture1D gl3wFramebufferTexture1D +#define glFramebufferTexture2D gl3wFramebufferTexture2D +#define glFramebufferTexture3D gl3wFramebufferTexture3D +#define glFramebufferRenderbuffer gl3wFramebufferRenderbuffer +#define glGetFramebufferAttachmentParameteriv gl3wGetFramebufferAttachmentParameteriv +#define glGenerateMipmap gl3wGenerateMipmap +#define glBlitFramebuffer gl3wBlitFramebuffer +#define glRenderbufferStorageMultisample gl3wRenderbufferStorageMultisample +#define glFramebufferTextureLayer gl3wFramebufferTextureLayer +#define glMapBufferRange gl3wMapBufferRange +#define glFlushMappedBufferRange gl3wFlushMappedBufferRange +#define glBindVertexArray gl3wBindVertexArray +#define glDeleteVertexArrays gl3wDeleteVertexArrays +#define glGenVertexArrays gl3wGenVertexArrays +#define glIsVertexArray gl3wIsVertexArray +#define glGetUniformIndices gl3wGetUniformIndices +#define glGetActiveUniformsiv gl3wGetActiveUniformsiv +#define glGetActiveUniformName gl3wGetActiveUniformName +#define glGetUniformBlockIndex gl3wGetUniformBlockIndex +#define glGetActiveUniformBlockiv gl3wGetActiveUniformBlockiv +#define glGetActiveUniformBlockName 
gl3wGetActiveUniformBlockName +#define glUniformBlockBinding gl3wUniformBlockBinding +#define glCopyBufferSubData gl3wCopyBufferSubData +#define glDrawElementsBaseVertex gl3wDrawElementsBaseVertex +#define glDrawRangeElementsBaseVertex gl3wDrawRangeElementsBaseVertex +#define glDrawElementsInstancedBaseVertex gl3wDrawElementsInstancedBaseVertex +#define glMultiDrawElementsBaseVertex gl3wMultiDrawElementsBaseVertex +#define glProvokingVertex gl3wProvokingVertex +#define glFenceSync gl3wFenceSync +#define glIsSync gl3wIsSync +#define glDeleteSync gl3wDeleteSync +#define glClientWaitSync gl3wClientWaitSync +#define glWaitSync gl3wWaitSync +#define glGetInteger64v gl3wGetInteger64v +#define glGetSynciv gl3wGetSynciv +#define glTexImage2DMultisample gl3wTexImage2DMultisample +#define glTexImage3DMultisample gl3wTexImage3DMultisample +#define glGetMultisamplefv gl3wGetMultisamplefv +#define glSampleMaski gl3wSampleMaski +#define glBlendEquationiARB gl3wBlendEquationiARB +#define glBlendEquationSeparateiARB gl3wBlendEquationSeparateiARB +#define glBlendFunciARB gl3wBlendFunciARB +#define glBlendFuncSeparateiARB gl3wBlendFuncSeparateiARB +#define glMinSampleShadingARB gl3wMinSampleShadingARB +#define glNamedStringARB gl3wNamedStringARB +#define glDeleteNamedStringARB gl3wDeleteNamedStringARB +#define glCompileShaderIncludeARB gl3wCompileShaderIncludeARB +#define glIsNamedStringARB gl3wIsNamedStringARB +#define glGetNamedStringARB gl3wGetNamedStringARB +#define glGetNamedStringivARB gl3wGetNamedStringivARB +#define glBindFragDataLocationIndexed gl3wBindFragDataLocationIndexed +#define glGetFragDataIndex gl3wGetFragDataIndex +#define glGenSamplers gl3wGenSamplers +#define glDeleteSamplers gl3wDeleteSamplers +#define glIsSampler gl3wIsSampler +#define glBindSampler gl3wBindSampler +#define glSamplerParameteri gl3wSamplerParameteri +#define glSamplerParameteriv gl3wSamplerParameteriv +#define glSamplerParameterf gl3wSamplerParameterf +#define glSamplerParameterfv gl3wSamplerParameterfv +#define glSamplerParameterIiv gl3wSamplerParameterIiv +#define glSamplerParameterIuiv gl3wSamplerParameterIuiv +#define glGetSamplerParameteriv gl3wGetSamplerParameteriv +#define glGetSamplerParameterIiv gl3wGetSamplerParameterIiv +#define glGetSamplerParameterfv gl3wGetSamplerParameterfv +#define glGetSamplerParameterIuiv gl3wGetSamplerParameterIuiv +#define glQueryCounter gl3wQueryCounter +#define glGetQueryObjecti64v gl3wGetQueryObjecti64v +#define glGetQueryObjectui64v gl3wGetQueryObjectui64v +#define glVertexP2ui gl3wVertexP2ui +#define glVertexP2uiv gl3wVertexP2uiv +#define glVertexP3ui gl3wVertexP3ui +#define glVertexP3uiv gl3wVertexP3uiv +#define glVertexP4ui gl3wVertexP4ui +#define glVertexP4uiv gl3wVertexP4uiv +#define glTexCoordP1ui gl3wTexCoordP1ui +#define glTexCoordP1uiv gl3wTexCoordP1uiv +#define glTexCoordP2ui gl3wTexCoordP2ui +#define glTexCoordP2uiv gl3wTexCoordP2uiv +#define glTexCoordP3ui gl3wTexCoordP3ui +#define glTexCoordP3uiv gl3wTexCoordP3uiv +#define glTexCoordP4ui gl3wTexCoordP4ui +#define glTexCoordP4uiv gl3wTexCoordP4uiv +#define glMultiTexCoordP1ui gl3wMultiTexCoordP1ui +#define glMultiTexCoordP1uiv gl3wMultiTexCoordP1uiv +#define glMultiTexCoordP2ui gl3wMultiTexCoordP2ui +#define glMultiTexCoordP2uiv gl3wMultiTexCoordP2uiv +#define glMultiTexCoordP3ui gl3wMultiTexCoordP3ui +#define glMultiTexCoordP3uiv gl3wMultiTexCoordP3uiv +#define glMultiTexCoordP4ui gl3wMultiTexCoordP4ui +#define glMultiTexCoordP4uiv gl3wMultiTexCoordP4uiv +#define glNormalP3ui gl3wNormalP3ui +#define glNormalP3uiv 
gl3wNormalP3uiv +#define glColorP3ui gl3wColorP3ui +#define glColorP3uiv gl3wColorP3uiv +#define glColorP4ui gl3wColorP4ui +#define glColorP4uiv gl3wColorP4uiv +#define glSecondaryColorP3ui gl3wSecondaryColorP3ui +#define glSecondaryColorP3uiv gl3wSecondaryColorP3uiv +#define glVertexAttribP1ui gl3wVertexAttribP1ui +#define glVertexAttribP1uiv gl3wVertexAttribP1uiv +#define glVertexAttribP2ui gl3wVertexAttribP2ui +#define glVertexAttribP2uiv gl3wVertexAttribP2uiv +#define glVertexAttribP3ui gl3wVertexAttribP3ui +#define glVertexAttribP3uiv gl3wVertexAttribP3uiv +#define glVertexAttribP4ui gl3wVertexAttribP4ui +#define glVertexAttribP4uiv gl3wVertexAttribP4uiv +#define glDrawArraysIndirect gl3wDrawArraysIndirect +#define glDrawElementsIndirect gl3wDrawElementsIndirect +#define glUniform1d gl3wUniform1d +#define glUniform2d gl3wUniform2d +#define glUniform3d gl3wUniform3d +#define glUniform4d gl3wUniform4d +#define glUniform1dv gl3wUniform1dv +#define glUniform2dv gl3wUniform2dv +#define glUniform3dv gl3wUniform3dv +#define glUniform4dv gl3wUniform4dv +#define glUniformMatrix2dv gl3wUniformMatrix2dv +#define glUniformMatrix3dv gl3wUniformMatrix3dv +#define glUniformMatrix4dv gl3wUniformMatrix4dv +#define glUniformMatrix2x3dv gl3wUniformMatrix2x3dv +#define glUniformMatrix2x4dv gl3wUniformMatrix2x4dv +#define glUniformMatrix3x2dv gl3wUniformMatrix3x2dv +#define glUniformMatrix3x4dv gl3wUniformMatrix3x4dv +#define glUniformMatrix4x2dv gl3wUniformMatrix4x2dv +#define glUniformMatrix4x3dv gl3wUniformMatrix4x3dv +#define glGetUniformdv gl3wGetUniformdv +#define glGetSubroutineUniformLocation gl3wGetSubroutineUniformLocation +#define glGetSubroutineIndex gl3wGetSubroutineIndex +#define glGetActiveSubroutineUniformiv gl3wGetActiveSubroutineUniformiv +#define glGetActiveSubroutineUniformName gl3wGetActiveSubroutineUniformName +#define glGetActiveSubroutineName gl3wGetActiveSubroutineName +#define glUniformSubroutinesuiv gl3wUniformSubroutinesuiv +#define glGetUniformSubroutineuiv gl3wGetUniformSubroutineuiv +#define glGetProgramStageiv gl3wGetProgramStageiv +#define glPatchParameteri gl3wPatchParameteri +#define glPatchParameterfv gl3wPatchParameterfv +#define glBindTransformFeedback gl3wBindTransformFeedback +#define glDeleteTransformFeedbacks gl3wDeleteTransformFeedbacks +#define glGenTransformFeedbacks gl3wGenTransformFeedbacks +#define glIsTransformFeedback gl3wIsTransformFeedback +#define glPauseTransformFeedback gl3wPauseTransformFeedback +#define glResumeTransformFeedback gl3wResumeTransformFeedback +#define glDrawTransformFeedback gl3wDrawTransformFeedback +#define glDrawTransformFeedbackStream gl3wDrawTransformFeedbackStream +#define glBeginQueryIndexed gl3wBeginQueryIndexed +#define glEndQueryIndexed gl3wEndQueryIndexed +#define glGetQueryIndexediv gl3wGetQueryIndexediv +#define glReleaseShaderCompiler gl3wReleaseShaderCompiler +#define glShaderBinary gl3wShaderBinary +#define glGetShaderPrecisionFormat gl3wGetShaderPrecisionFormat +#define glDepthRangef gl3wDepthRangef +#define glClearDepthf gl3wClearDepthf +#define glGetProgramBinary gl3wGetProgramBinary +#define glProgramBinary gl3wProgramBinary +#define glProgramParameteri gl3wProgramParameteri +#define glUseProgramStages gl3wUseProgramStages +#define glActiveShaderProgram gl3wActiveShaderProgram +#define glCreateShaderProgramv gl3wCreateShaderProgramv +#define glBindProgramPipeline gl3wBindProgramPipeline +#define glDeleteProgramPipelines gl3wDeleteProgramPipelines +#define glGenProgramPipelines gl3wGenProgramPipelines +#define 
glIsProgramPipeline gl3wIsProgramPipeline +#define glGetProgramPipelineiv gl3wGetProgramPipelineiv +#define glProgramUniform1i gl3wProgramUniform1i +#define glProgramUniform1iv gl3wProgramUniform1iv +#define glProgramUniform1f gl3wProgramUniform1f +#define glProgramUniform1fv gl3wProgramUniform1fv +#define glProgramUniform1d gl3wProgramUniform1d +#define glProgramUniform1dv gl3wProgramUniform1dv +#define glProgramUniform1ui gl3wProgramUniform1ui +#define glProgramUniform1uiv gl3wProgramUniform1uiv +#define glProgramUniform2i gl3wProgramUniform2i +#define glProgramUniform2iv gl3wProgramUniform2iv +#define glProgramUniform2f gl3wProgramUniform2f +#define glProgramUniform2fv gl3wProgramUniform2fv +#define glProgramUniform2d gl3wProgramUniform2d +#define glProgramUniform2dv gl3wProgramUniform2dv +#define glProgramUniform2ui gl3wProgramUniform2ui +#define glProgramUniform2uiv gl3wProgramUniform2uiv +#define glProgramUniform3i gl3wProgramUniform3i +#define glProgramUniform3iv gl3wProgramUniform3iv +#define glProgramUniform3f gl3wProgramUniform3f +#define glProgramUniform3fv gl3wProgramUniform3fv +#define glProgramUniform3d gl3wProgramUniform3d +#define glProgramUniform3dv gl3wProgramUniform3dv +#define glProgramUniform3ui gl3wProgramUniform3ui +#define glProgramUniform3uiv gl3wProgramUniform3uiv +#define glProgramUniform4i gl3wProgramUniform4i +#define glProgramUniform4iv gl3wProgramUniform4iv +#define glProgramUniform4f gl3wProgramUniform4f +#define glProgramUniform4fv gl3wProgramUniform4fv +#define glProgramUniform4d gl3wProgramUniform4d +#define glProgramUniform4dv gl3wProgramUniform4dv +#define glProgramUniform4ui gl3wProgramUniform4ui +#define glProgramUniform4uiv gl3wProgramUniform4uiv +#define glProgramUniformMatrix2fv gl3wProgramUniformMatrix2fv +#define glProgramUniformMatrix3fv gl3wProgramUniformMatrix3fv +#define glProgramUniformMatrix4fv gl3wProgramUniformMatrix4fv +#define glProgramUniformMatrix2dv gl3wProgramUniformMatrix2dv +#define glProgramUniformMatrix3dv gl3wProgramUniformMatrix3dv +#define glProgramUniformMatrix4dv gl3wProgramUniformMatrix4dv +#define glProgramUniformMatrix2x3fv gl3wProgramUniformMatrix2x3fv +#define glProgramUniformMatrix3x2fv gl3wProgramUniformMatrix3x2fv +#define glProgramUniformMatrix2x4fv gl3wProgramUniformMatrix2x4fv +#define glProgramUniformMatrix4x2fv gl3wProgramUniformMatrix4x2fv +#define glProgramUniformMatrix3x4fv gl3wProgramUniformMatrix3x4fv +#define glProgramUniformMatrix4x3fv gl3wProgramUniformMatrix4x3fv +#define glProgramUniformMatrix2x3dv gl3wProgramUniformMatrix2x3dv +#define glProgramUniformMatrix3x2dv gl3wProgramUniformMatrix3x2dv +#define glProgramUniformMatrix2x4dv gl3wProgramUniformMatrix2x4dv +#define glProgramUniformMatrix4x2dv gl3wProgramUniformMatrix4x2dv +#define glProgramUniformMatrix3x4dv gl3wProgramUniformMatrix3x4dv +#define glProgramUniformMatrix4x3dv gl3wProgramUniformMatrix4x3dv +#define glValidateProgramPipeline gl3wValidateProgramPipeline +#define glGetProgramPipelineInfoLog gl3wGetProgramPipelineInfoLog +#define glVertexAttribL1d gl3wVertexAttribL1d +#define glVertexAttribL2d gl3wVertexAttribL2d +#define glVertexAttribL3d gl3wVertexAttribL3d +#define glVertexAttribL4d gl3wVertexAttribL4d +#define glVertexAttribL1dv gl3wVertexAttribL1dv +#define glVertexAttribL2dv gl3wVertexAttribL2dv +#define glVertexAttribL3dv gl3wVertexAttribL3dv +#define glVertexAttribL4dv gl3wVertexAttribL4dv +#define glVertexAttribLPointer gl3wVertexAttribLPointer +#define glGetVertexAttribLdv gl3wGetVertexAttribLdv +#define glViewportArrayv 
gl3wViewportArrayv +#define glViewportIndexedf gl3wViewportIndexedf +#define glViewportIndexedfv gl3wViewportIndexedfv +#define glScissorArrayv gl3wScissorArrayv +#define glScissorIndexed gl3wScissorIndexed +#define glScissorIndexedv gl3wScissorIndexedv +#define glDepthRangeArrayv gl3wDepthRangeArrayv +#define glDepthRangeIndexed gl3wDepthRangeIndexed +#define glGetFloati_v gl3wGetFloati_v +#define glGetDoublei_v gl3wGetDoublei_v +#define glCreateSyncFromCLeventARB gl3wCreateSyncFromCLeventARB +#define glDebugMessageControlARB gl3wDebugMessageControlARB +#define glDebugMessageInsertARB gl3wDebugMessageInsertARB +#define glDebugMessageCallbackARB gl3wDebugMessageCallbackARB +#define glGetDebugMessageLogARB gl3wGetDebugMessageLogARB +#define glGetGraphicsResetStatusARB gl3wGetGraphicsResetStatusARB +#define glGetnTexImageARB gl3wGetnTexImageARB +#define glReadnPixelsARB gl3wReadnPixelsARB +#define glGetnCompressedTexImageARB gl3wGetnCompressedTexImageARB +#define glGetnUniformfvARB gl3wGetnUniformfvARB +#define glGetnUniformivARB gl3wGetnUniformivARB +#define glGetnUniformuivARB gl3wGetnUniformuivARB +#define glGetnUniformdvARB gl3wGetnUniformdvARB +#define glDrawArraysInstancedBaseInstance gl3wDrawArraysInstancedBaseInstance +#define glDrawElementsInstancedBaseInstance gl3wDrawElementsInstancedBaseInstance +#define glDrawElementsInstancedBaseVertexBaseInstance gl3wDrawElementsInstancedBaseVertexBaseInstance +#define glDrawTransformFeedbackInstanced gl3wDrawTransformFeedbackInstanced +#define glDrawTransformFeedbackStreamInstanced gl3wDrawTransformFeedbackStreamInstanced +#define glGetInternalformativ gl3wGetInternalformativ +#define glGetActiveAtomicCounterBufferiv gl3wGetActiveAtomicCounterBufferiv +#define glBindImageTexture gl3wBindImageTexture +#define glMemoryBarrier gl3wMemoryBarrier +#define glTexStorage1D gl3wTexStorage1D +#define glTexStorage2D gl3wTexStorage2D +#define glTexStorage3D gl3wTexStorage3D +#define glTextureStorage1DEXT gl3wTextureStorage1DEXT +#define glTextureStorage2DEXT gl3wTextureStorage2DEXT +#define glTextureStorage3DEXT gl3wTextureStorage3DEXT +#define glDebugMessageControl gl3wDebugMessageControl +#define glDebugMessageInsert gl3wDebugMessageInsert +#define glDebugMessageCallback gl3wDebugMessageCallback +#define glGetDebugMessageLog gl3wGetDebugMessageLog +#define glPushDebugGroup gl3wPushDebugGroup +#define glPopDebugGroup gl3wPopDebugGroup +#define glObjectLabel gl3wObjectLabel +#define glGetObjectLabel gl3wGetObjectLabel +#define glObjectPtrLabel gl3wObjectPtrLabel +#define glGetObjectPtrLabel gl3wGetObjectPtrLabel +#define glClearBufferData gl3wClearBufferData +#define glClearBufferSubData gl3wClearBufferSubData +#define glClearNamedBufferDataEXT gl3wClearNamedBufferDataEXT +#define glClearNamedBufferSubDataEXT gl3wClearNamedBufferSubDataEXT +#define glDispatchCompute gl3wDispatchCompute +#define glDispatchComputeIndirect gl3wDispatchComputeIndirect +#define glCopyImageSubData gl3wCopyImageSubData +#define glTextureView gl3wTextureView +#define glBindVertexBuffer gl3wBindVertexBuffer +#define glVertexAttribFormat gl3wVertexAttribFormat +#define glVertexAttribIFormat gl3wVertexAttribIFormat +#define glVertexAttribLFormat gl3wVertexAttribLFormat +#define glVertexAttribBinding gl3wVertexAttribBinding +#define glVertexBindingDivisor gl3wVertexBindingDivisor +#define glVertexArrayBindVertexBufferEXT gl3wVertexArrayBindVertexBufferEXT +#define glVertexArrayVertexAttribFormatEXT gl3wVertexArrayVertexAttribFormatEXT +#define glVertexArrayVertexAttribIFormatEXT 
gl3wVertexArrayVertexAttribIFormatEXT +#define glVertexArrayVertexAttribLFormatEXT gl3wVertexArrayVertexAttribLFormatEXT +#define glVertexArrayVertexAttribBindingEXT gl3wVertexArrayVertexAttribBindingEXT +#define glVertexArrayVertexBindingDivisorEXT gl3wVertexArrayVertexBindingDivisorEXT +#define glFramebufferParameteri gl3wFramebufferParameteri +#define glGetFramebufferParameteriv gl3wGetFramebufferParameteriv +#define glNamedFramebufferParameteriEXT gl3wNamedFramebufferParameteriEXT +#define glGetNamedFramebufferParameterivEXT gl3wGetNamedFramebufferParameterivEXT +#define glGetInternalformati64v gl3wGetInternalformati64v +#define glInvalidateTexSubImage gl3wInvalidateTexSubImage +#define glInvalidateTexImage gl3wInvalidateTexImage +#define glInvalidateBufferSubData gl3wInvalidateBufferSubData +#define glInvalidateBufferData gl3wInvalidateBufferData +#define glInvalidateFramebuffer gl3wInvalidateFramebuffer +#define glInvalidateSubFramebuffer gl3wInvalidateSubFramebuffer +#define glMultiDrawArraysIndirect gl3wMultiDrawArraysIndirect +#define glMultiDrawElementsIndirect gl3wMultiDrawElementsIndirect +#define glGetProgramInterfaceiv gl3wGetProgramInterfaceiv +#define glGetProgramResourceIndex gl3wGetProgramResourceIndex +#define glGetProgramResourceName gl3wGetProgramResourceName +#define glGetProgramResourceiv gl3wGetProgramResourceiv +#define glGetProgramResourceLocation gl3wGetProgramResourceLocation +#define glGetProgramResourceLocationIndex gl3wGetProgramResourceLocationIndex +#define glShaderStorageBlockBinding gl3wShaderStorageBlockBinding +#define glTexBufferRange gl3wTexBufferRange +#define glTextureBufferRangeEXT gl3wTextureBufferRangeEXT +#define glTexStorage2DMultisample gl3wTexStorage2DMultisample +#define glTexStorage3DMultisample gl3wTexStorage3DMultisample +#define glTextureStorage2DMultisampleEXT gl3wTextureStorage2DMultisampleEXT +#define glTextureStorage3DMultisampleEXT gl3wTextureStorage3DMultisampleEXT + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/gui/dependencies/gl3w/GL/glcorearb.h b/gui/dependencies/gl3w/GL/glcorearb.h new file mode 100644 index 0000000000000000000000000000000000000000..07cb03e13af4191cfa52435ee9d055977036fad8 --- /dev/null +++ b/gui/dependencies/gl3w/GL/glcorearb.h @@ -0,0 +1,4533 @@ +#ifndef __glcorearb_h_ +#define __glcorearb_h_ + +#ifdef __cplusplus +extern "C" { +#endif + +/* +** Copyright (c) 2007-2012 The Khronos Group Inc. +** +** Permission is hereby granted, free of charge, to any person obtaining a +** copy of this software and/or associated documentation files (the +** "Materials"), to deal in the Materials without restriction, including +** without limitation the rights to use, copy, modify, merge, publish, +** distribute, sublicense, and/or sell copies of the Materials, and to +** permit persons to whom the Materials are furnished to do so, subject to +** the following conditions: +** +** The above copyright notice and this permission notice shall be included +** in all copies or substantial portions of the Materials. +** +** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS. 
+*/ + +/* glcorearb.h replaces gl3.h. It is for use with OpenGL core + * profile implementations. + * + * glcorearb.h last updated on $Date: 2012-09-19 19:02:24 -0700 (Wed, 19 Sep 2012) $ + * + * RELEASE NOTES - 2012/09/19 + * + * glcorearb.h should be placed in the same directory as gl.h and + * included as + * '<GL/glcorearb.h>'. + * + * glcorearb.h includes only APIs in the latest OpenGL core profile + * implementation together with APIs in newer ARB extensions which can be + * can be supported by the core profile. It does not, and never will + * include functionality removed from the core profile, such as + * fixed-function vertex and fragment processing. + * + * It is not possible to #include both <GL/glcorearb.h> and either of + * <GL/gl.h> or <GL/glext.h> in the same source file. + * + * Feedback can be given by registering for the Khronos Bugzilla + * (www.khronos.org/bugzilla) and filing issues there under product + * "OpenGL", category "Registry". + */ + +/* Function declaration macros - to move into glplatform.h */ + +#if defined(_WIN32) && !defined(APIENTRY) && !defined(__CYGWIN__) && !defined(__SCITECH_SNAP__) +#define WIN32_LEAN_AND_MEAN 1 +#include <windows.h> +#endif + +#ifndef APIENTRY +#define APIENTRY +#endif +#ifndef APIENTRYP +#define APIENTRYP APIENTRY * +#endif +#ifndef GLAPI +#define GLAPI extern +#endif + +/* Base GL types */ + +typedef unsigned int GLenum; +typedef unsigned char GLboolean; +typedef unsigned int GLbitfield; +typedef signed char GLbyte; +typedef short GLshort; +typedef int GLint; +typedef int GLsizei; +typedef unsigned char GLubyte; +typedef unsigned short GLushort; +typedef unsigned int GLuint; +typedef unsigned short GLhalf; +typedef float GLfloat; +typedef float GLclampf; +typedef double GLdouble; +typedef double GLclampd; +typedef void GLvoid; + +/*************************************************************/ + +#ifndef GL_VERSION_1_1 +/* AttribMask */ +#define GL_DEPTH_BUFFER_BIT 0x00000100 +#define GL_STENCIL_BUFFER_BIT 0x00000400 +#define GL_COLOR_BUFFER_BIT 0x00004000 +/* Boolean */ +#define GL_FALSE 0 +#define GL_TRUE 1 +/* BeginMode */ +#define GL_POINTS 0x0000 +#define GL_LINES 0x0001 +#define GL_LINE_LOOP 0x0002 +#define GL_LINE_STRIP 0x0003 +#define GL_TRIANGLES 0x0004 +#define GL_TRIANGLE_STRIP 0x0005 +#define GL_TRIANGLE_FAN 0x0006 +#define GL_QUADS 0x0007 +/* AlphaFunction */ +#define GL_NEVER 0x0200 +#define GL_LESS 0x0201 +#define GL_EQUAL 0x0202 +#define GL_LEQUAL 0x0203 +#define GL_GREATER 0x0204 +#define GL_NOTEQUAL 0x0205 +#define GL_GEQUAL 0x0206 +#define GL_ALWAYS 0x0207 +/* BlendingFactorDest */ +#define GL_ZERO 0 +#define GL_ONE 1 +#define GL_SRC_COLOR 0x0300 +#define GL_ONE_MINUS_SRC_COLOR 0x0301 +#define GL_SRC_ALPHA 0x0302 +#define GL_ONE_MINUS_SRC_ALPHA 0x0303 +#define GL_DST_ALPHA 0x0304 +#define GL_ONE_MINUS_DST_ALPHA 0x0305 +/* BlendingFactorSrc */ +#define GL_DST_COLOR 0x0306 +#define GL_ONE_MINUS_DST_COLOR 0x0307 +#define GL_SRC_ALPHA_SATURATE 0x0308 +/* DrawBufferMode */ +#define GL_NONE 0 +#define GL_FRONT_LEFT 0x0400 +#define GL_FRONT_RIGHT 0x0401 +#define GL_BACK_LEFT 0x0402 +#define GL_BACK_RIGHT 0x0403 +#define GL_FRONT 0x0404 +#define GL_BACK 0x0405 +#define GL_LEFT 0x0406 +#define GL_RIGHT 0x0407 +#define GL_FRONT_AND_BACK 0x0408 +/* ErrorCode */ +#define GL_NO_ERROR 0 +#define GL_INVALID_ENUM 0x0500 +#define GL_INVALID_VALUE 0x0501 +#define GL_INVALID_OPERATION 0x0502 +#define GL_OUT_OF_MEMORY 0x0505 +/* FrontFaceDirection */ +#define GL_CW 0x0900 +#define GL_CCW 0x0901 +/* GetPName */ +#define GL_POINT_SIZE 0x0B11 +#define GL_POINT_SIZE_RANGE 0x0B12 +#define 
GL_POINT_SIZE_GRANULARITY 0x0B13 +#define GL_LINE_SMOOTH 0x0B20 +#define GL_LINE_WIDTH 0x0B21 +#define GL_LINE_WIDTH_RANGE 0x0B22 +#define GL_LINE_WIDTH_GRANULARITY 0x0B23 +#define GL_POLYGON_MODE 0x0B40 +#define GL_POLYGON_SMOOTH 0x0B41 +#define GL_CULL_FACE 0x0B44 +#define GL_CULL_FACE_MODE 0x0B45 +#define GL_FRONT_FACE 0x0B46 +#define GL_DEPTH_RANGE 0x0B70 +#define GL_DEPTH_TEST 0x0B71 +#define GL_DEPTH_WRITEMASK 0x0B72 +#define GL_DEPTH_CLEAR_VALUE 0x0B73 +#define GL_DEPTH_FUNC 0x0B74 +#define GL_STENCIL_TEST 0x0B90 +#define GL_STENCIL_CLEAR_VALUE 0x0B91 +#define GL_STENCIL_FUNC 0x0B92 +#define GL_STENCIL_VALUE_MASK 0x0B93 +#define GL_STENCIL_FAIL 0x0B94 +#define GL_STENCIL_PASS_DEPTH_FAIL 0x0B95 +#define GL_STENCIL_PASS_DEPTH_PASS 0x0B96 +#define GL_STENCIL_REF 0x0B97 +#define GL_STENCIL_WRITEMASK 0x0B98 +#define GL_VIEWPORT 0x0BA2 +#define GL_DITHER 0x0BD0 +#define GL_BLEND_DST 0x0BE0 +#define GL_BLEND_SRC 0x0BE1 +#define GL_BLEND 0x0BE2 +#define GL_LOGIC_OP_MODE 0x0BF0 +#define GL_COLOR_LOGIC_OP 0x0BF2 +#define GL_DRAW_BUFFER 0x0C01 +#define GL_READ_BUFFER 0x0C02 +#define GL_SCISSOR_BOX 0x0C10 +#define GL_SCISSOR_TEST 0x0C11 +#define GL_COLOR_CLEAR_VALUE 0x0C22 +#define GL_COLOR_WRITEMASK 0x0C23 +#define GL_DOUBLEBUFFER 0x0C32 +#define GL_STEREO 0x0C33 +#define GL_LINE_SMOOTH_HINT 0x0C52 +#define GL_POLYGON_SMOOTH_HINT 0x0C53 +#define GL_UNPACK_SWAP_BYTES 0x0CF0 +#define GL_UNPACK_LSB_FIRST 0x0CF1 +#define GL_UNPACK_ROW_LENGTH 0x0CF2 +#define GL_UNPACK_SKIP_ROWS 0x0CF3 +#define GL_UNPACK_SKIP_PIXELS 0x0CF4 +#define GL_UNPACK_ALIGNMENT 0x0CF5 +#define GL_PACK_SWAP_BYTES 0x0D00 +#define GL_PACK_LSB_FIRST 0x0D01 +#define GL_PACK_ROW_LENGTH 0x0D02 +#define GL_PACK_SKIP_ROWS 0x0D03 +#define GL_PACK_SKIP_PIXELS 0x0D04 +#define GL_PACK_ALIGNMENT 0x0D05 +#define GL_MAX_TEXTURE_SIZE 0x0D33 +#define GL_MAX_VIEWPORT_DIMS 0x0D3A +#define GL_SUBPIXEL_BITS 0x0D50 +#define GL_TEXTURE_1D 0x0DE0 +#define GL_TEXTURE_2D 0x0DE1 +#define GL_POLYGON_OFFSET_UNITS 0x2A00 +#define GL_POLYGON_OFFSET_POINT 0x2A01 +#define GL_POLYGON_OFFSET_LINE 0x2A02 +#define GL_POLYGON_OFFSET_FILL 0x8037 +#define GL_POLYGON_OFFSET_FACTOR 0x8038 +#define GL_TEXTURE_BINDING_1D 0x8068 +#define GL_TEXTURE_BINDING_2D 0x8069 +/* GetTextureParameter */ +#define GL_TEXTURE_WIDTH 0x1000 +#define GL_TEXTURE_HEIGHT 0x1001 +#define GL_TEXTURE_INTERNAL_FORMAT 0x1003 +#define GL_TEXTURE_BORDER_COLOR 0x1004 +#define GL_TEXTURE_RED_SIZE 0x805C +#define GL_TEXTURE_GREEN_SIZE 0x805D +#define GL_TEXTURE_BLUE_SIZE 0x805E +#define GL_TEXTURE_ALPHA_SIZE 0x805F +/* HintMode */ +#define GL_DONT_CARE 0x1100 +#define GL_FASTEST 0x1101 +#define GL_NICEST 0x1102 +/* DataType */ +#define GL_BYTE 0x1400 +#define GL_UNSIGNED_BYTE 0x1401 +#define GL_SHORT 0x1402 +#define GL_UNSIGNED_SHORT 0x1403 +#define GL_INT 0x1404 +#define GL_UNSIGNED_INT 0x1405 +#define GL_FLOAT 0x1406 +#define GL_DOUBLE 0x140A +/* ErrorCode */ +#define GL_STACK_OVERFLOW 0x0503 +#define GL_STACK_UNDERFLOW 0x0504 +/* LogicOp */ +#define GL_CLEAR 0x1500 +#define GL_AND 0x1501 +#define GL_AND_REVERSE 0x1502 +#define GL_COPY 0x1503 +#define GL_AND_INVERTED 0x1504 +#define GL_NOOP 0x1505 +#define GL_XOR 0x1506 +#define GL_OR 0x1507 +#define GL_NOR 0x1508 +#define GL_EQUIV 0x1509 +#define GL_INVERT 0x150A +#define GL_OR_REVERSE 0x150B +#define GL_COPY_INVERTED 0x150C +#define GL_OR_INVERTED 0x150D +#define GL_NAND 0x150E +#define GL_SET 0x150F +/* MatrixMode (for gl3.h, FBO attachment type) */ +#define GL_TEXTURE 0x1702 +/* PixelCopyType */ +#define GL_COLOR 0x1800 +#define GL_DEPTH 0x1801 
+#define GL_STENCIL 0x1802 +/* PixelFormat */ +#define GL_STENCIL_INDEX 0x1901 +#define GL_DEPTH_COMPONENT 0x1902 +#define GL_RED 0x1903 +#define GL_GREEN 0x1904 +#define GL_BLUE 0x1905 +#define GL_ALPHA 0x1906 +#define GL_RGB 0x1907 +#define GL_RGBA 0x1908 +/* PolygonMode */ +#define GL_POINT 0x1B00 +#define GL_LINE 0x1B01 +#define GL_FILL 0x1B02 +/* StencilOp */ +#define GL_KEEP 0x1E00 +#define GL_REPLACE 0x1E01 +#define GL_INCR 0x1E02 +#define GL_DECR 0x1E03 +/* StringName */ +#define GL_VENDOR 0x1F00 +#define GL_RENDERER 0x1F01 +#define GL_VERSION 0x1F02 +#define GL_EXTENSIONS 0x1F03 +/* TextureMagFilter */ +#define GL_NEAREST 0x2600 +#define GL_LINEAR 0x2601 +/* TextureMinFilter */ +#define GL_NEAREST_MIPMAP_NEAREST 0x2700 +#define GL_LINEAR_MIPMAP_NEAREST 0x2701 +#define GL_NEAREST_MIPMAP_LINEAR 0x2702 +#define GL_LINEAR_MIPMAP_LINEAR 0x2703 +/* TextureParameterName */ +#define GL_TEXTURE_MAG_FILTER 0x2800 +#define GL_TEXTURE_MIN_FILTER 0x2801 +#define GL_TEXTURE_WRAP_S 0x2802 +#define GL_TEXTURE_WRAP_T 0x2803 +/* TextureTarget */ +#define GL_PROXY_TEXTURE_1D 0x8063 +#define GL_PROXY_TEXTURE_2D 0x8064 +/* TextureWrapMode */ +#define GL_REPEAT 0x2901 +/* PixelInternalFormat */ +#define GL_R3_G3_B2 0x2A10 +#define GL_RGB4 0x804F +#define GL_RGB5 0x8050 +#define GL_RGB8 0x8051 +#define GL_RGB10 0x8052 +#define GL_RGB12 0x8053 +#define GL_RGB16 0x8054 +#define GL_RGBA2 0x8055 +#define GL_RGBA4 0x8056 +#define GL_RGB5_A1 0x8057 +#define GL_RGBA8 0x8058 +#define GL_RGB10_A2 0x8059 +#define GL_RGBA12 0x805A +#define GL_RGBA16 0x805B +#endif + +#ifndef GL_VERSION_1_2 +#define GL_UNSIGNED_BYTE_3_3_2 0x8032 +#define GL_UNSIGNED_SHORT_4_4_4_4 0x8033 +#define GL_UNSIGNED_SHORT_5_5_5_1 0x8034 +#define GL_UNSIGNED_INT_8_8_8_8 0x8035 +#define GL_UNSIGNED_INT_10_10_10_2 0x8036 +#define GL_TEXTURE_BINDING_3D 0x806A +#define GL_PACK_SKIP_IMAGES 0x806B +#define GL_PACK_IMAGE_HEIGHT 0x806C +#define GL_UNPACK_SKIP_IMAGES 0x806D +#define GL_UNPACK_IMAGE_HEIGHT 0x806E +#define GL_TEXTURE_3D 0x806F +#define GL_PROXY_TEXTURE_3D 0x8070 +#define GL_TEXTURE_DEPTH 0x8071 +#define GL_TEXTURE_WRAP_R 0x8072 +#define GL_MAX_3D_TEXTURE_SIZE 0x8073 +#define GL_UNSIGNED_BYTE_2_3_3_REV 0x8362 +#define GL_UNSIGNED_SHORT_5_6_5 0x8363 +#define GL_UNSIGNED_SHORT_5_6_5_REV 0x8364 +#define GL_UNSIGNED_SHORT_4_4_4_4_REV 0x8365 +#define GL_UNSIGNED_SHORT_1_5_5_5_REV 0x8366 +#define GL_UNSIGNED_INT_8_8_8_8_REV 0x8367 +#define GL_UNSIGNED_INT_2_10_10_10_REV 0x8368 +#define GL_BGR 0x80E0 +#define GL_BGRA 0x80E1 +#define GL_MAX_ELEMENTS_VERTICES 0x80E8 +#define GL_MAX_ELEMENTS_INDICES 0x80E9 +#define GL_CLAMP_TO_EDGE 0x812F +#define GL_TEXTURE_MIN_LOD 0x813A +#define GL_TEXTURE_MAX_LOD 0x813B +#define GL_TEXTURE_BASE_LEVEL 0x813C +#define GL_TEXTURE_MAX_LEVEL 0x813D +#define GL_SMOOTH_POINT_SIZE_RANGE 0x0B12 +#define GL_SMOOTH_POINT_SIZE_GRANULARITY 0x0B13 +#define GL_SMOOTH_LINE_WIDTH_RANGE 0x0B22 +#define GL_SMOOTH_LINE_WIDTH_GRANULARITY 0x0B23 +#define GL_ALIASED_LINE_WIDTH_RANGE 0x846E +#endif + +#ifndef GL_ARB_imaging +#define GL_CONSTANT_COLOR 0x8001 +#define GL_ONE_MINUS_CONSTANT_COLOR 0x8002 +#define GL_CONSTANT_ALPHA 0x8003 +#define GL_ONE_MINUS_CONSTANT_ALPHA 0x8004 +#define GL_BLEND_COLOR 0x8005 +#define GL_FUNC_ADD 0x8006 +#define GL_MIN 0x8007 +#define GL_MAX 0x8008 +#define GL_BLEND_EQUATION 0x8009 +#define GL_FUNC_SUBTRACT 0x800A +#define GL_FUNC_REVERSE_SUBTRACT 0x800B +#endif + +#ifndef GL_VERSION_1_3 +#define GL_TEXTURE0 0x84C0 +#define GL_TEXTURE1 0x84C1 +#define GL_TEXTURE2 0x84C2 +#define GL_TEXTURE3 0x84C3 
+#define GL_TEXTURE4 0x84C4 +#define GL_TEXTURE5 0x84C5 +#define GL_TEXTURE6 0x84C6 +#define GL_TEXTURE7 0x84C7 +#define GL_TEXTURE8 0x84C8 +#define GL_TEXTURE9 0x84C9 +#define GL_TEXTURE10 0x84CA +#define GL_TEXTURE11 0x84CB +#define GL_TEXTURE12 0x84CC +#define GL_TEXTURE13 0x84CD +#define GL_TEXTURE14 0x84CE +#define GL_TEXTURE15 0x84CF +#define GL_TEXTURE16 0x84D0 +#define GL_TEXTURE17 0x84D1 +#define GL_TEXTURE18 0x84D2 +#define GL_TEXTURE19 0x84D3 +#define GL_TEXTURE20 0x84D4 +#define GL_TEXTURE21 0x84D5 +#define GL_TEXTURE22 0x84D6 +#define GL_TEXTURE23 0x84D7 +#define GL_TEXTURE24 0x84D8 +#define GL_TEXTURE25 0x84D9 +#define GL_TEXTURE26 0x84DA +#define GL_TEXTURE27 0x84DB +#define GL_TEXTURE28 0x84DC +#define GL_TEXTURE29 0x84DD +#define GL_TEXTURE30 0x84DE +#define GL_TEXTURE31 0x84DF +#define GL_ACTIVE_TEXTURE 0x84E0 +#define GL_MULTISAMPLE 0x809D +#define GL_SAMPLE_ALPHA_TO_COVERAGE 0x809E +#define GL_SAMPLE_ALPHA_TO_ONE 0x809F +#define GL_SAMPLE_COVERAGE 0x80A0 +#define GL_SAMPLE_BUFFERS 0x80A8 +#define GL_SAMPLES 0x80A9 +#define GL_SAMPLE_COVERAGE_VALUE 0x80AA +#define GL_SAMPLE_COVERAGE_INVERT 0x80AB +#define GL_TEXTURE_CUBE_MAP 0x8513 +#define GL_TEXTURE_BINDING_CUBE_MAP 0x8514 +#define GL_TEXTURE_CUBE_MAP_POSITIVE_X 0x8515 +#define GL_TEXTURE_CUBE_MAP_NEGATIVE_X 0x8516 +#define GL_TEXTURE_CUBE_MAP_POSITIVE_Y 0x8517 +#define GL_TEXTURE_CUBE_MAP_NEGATIVE_Y 0x8518 +#define GL_TEXTURE_CUBE_MAP_POSITIVE_Z 0x8519 +#define GL_TEXTURE_CUBE_MAP_NEGATIVE_Z 0x851A +#define GL_PROXY_TEXTURE_CUBE_MAP 0x851B +#define GL_MAX_CUBE_MAP_TEXTURE_SIZE 0x851C +#define GL_COMPRESSED_RGB 0x84ED +#define GL_COMPRESSED_RGBA 0x84EE +#define GL_TEXTURE_COMPRESSION_HINT 0x84EF +#define GL_TEXTURE_COMPRESSED_IMAGE_SIZE 0x86A0 +#define GL_TEXTURE_COMPRESSED 0x86A1 +#define GL_NUM_COMPRESSED_TEXTURE_FORMATS 0x86A2 +#define GL_COMPRESSED_TEXTURE_FORMATS 0x86A3 +#define GL_CLAMP_TO_BORDER 0x812D +#endif + +#ifndef GL_VERSION_1_4 +#define GL_BLEND_DST_RGB 0x80C8 +#define GL_BLEND_SRC_RGB 0x80C9 +#define GL_BLEND_DST_ALPHA 0x80CA +#define GL_BLEND_SRC_ALPHA 0x80CB +#define GL_POINT_FADE_THRESHOLD_SIZE 0x8128 +#define GL_DEPTH_COMPONENT16 0x81A5 +#define GL_DEPTH_COMPONENT24 0x81A6 +#define GL_DEPTH_COMPONENT32 0x81A7 +#define GL_MIRRORED_REPEAT 0x8370 +#define GL_MAX_TEXTURE_LOD_BIAS 0x84FD +#define GL_TEXTURE_LOD_BIAS 0x8501 +#define GL_INCR_WRAP 0x8507 +#define GL_DECR_WRAP 0x8508 +#define GL_TEXTURE_DEPTH_SIZE 0x884A +#define GL_TEXTURE_COMPARE_MODE 0x884C +#define GL_TEXTURE_COMPARE_FUNC 0x884D +#endif + +#ifndef GL_VERSION_1_5 +#define GL_BUFFER_SIZE 0x8764 +#define GL_BUFFER_USAGE 0x8765 +#define GL_QUERY_COUNTER_BITS 0x8864 +#define GL_CURRENT_QUERY 0x8865 +#define GL_QUERY_RESULT 0x8866 +#define GL_QUERY_RESULT_AVAILABLE 0x8867 +#define GL_ARRAY_BUFFER 0x8892 +#define GL_ELEMENT_ARRAY_BUFFER 0x8893 +#define GL_ARRAY_BUFFER_BINDING 0x8894 +#define GL_ELEMENT_ARRAY_BUFFER_BINDING 0x8895 +#define GL_VERTEX_ATTRIB_ARRAY_BUFFER_BINDING 0x889F +#define GL_READ_ONLY 0x88B8 +#define GL_WRITE_ONLY 0x88B9 +#define GL_READ_WRITE 0x88BA +#define GL_BUFFER_ACCESS 0x88BB +#define GL_BUFFER_MAPPED 0x88BC +#define GL_BUFFER_MAP_POINTER 0x88BD +#define GL_STREAM_DRAW 0x88E0 +#define GL_STREAM_READ 0x88E1 +#define GL_STREAM_COPY 0x88E2 +#define GL_STATIC_DRAW 0x88E4 +#define GL_STATIC_READ 0x88E5 +#define GL_STATIC_COPY 0x88E6 +#define GL_DYNAMIC_DRAW 0x88E8 +#define GL_DYNAMIC_READ 0x88E9 +#define GL_DYNAMIC_COPY 0x88EA +#define GL_SAMPLES_PASSED 0x8914 +#define GL_SRC1_ALPHA 0x8589 +#endif + +#ifndef GL_VERSION_2_0 
+#define GL_BLEND_EQUATION_RGB 0x8009 +#define GL_VERTEX_ATTRIB_ARRAY_ENABLED 0x8622 +#define GL_VERTEX_ATTRIB_ARRAY_SIZE 0x8623 +#define GL_VERTEX_ATTRIB_ARRAY_STRIDE 0x8624 +#define GL_VERTEX_ATTRIB_ARRAY_TYPE 0x8625 +#define GL_CURRENT_VERTEX_ATTRIB 0x8626 +#define GL_VERTEX_PROGRAM_POINT_SIZE 0x8642 +#define GL_VERTEX_ATTRIB_ARRAY_POINTER 0x8645 +#define GL_STENCIL_BACK_FUNC 0x8800 +#define GL_STENCIL_BACK_FAIL 0x8801 +#define GL_STENCIL_BACK_PASS_DEPTH_FAIL 0x8802 +#define GL_STENCIL_BACK_PASS_DEPTH_PASS 0x8803 +#define GL_MAX_DRAW_BUFFERS 0x8824 +#define GL_DRAW_BUFFER0 0x8825 +#define GL_DRAW_BUFFER1 0x8826 +#define GL_DRAW_BUFFER2 0x8827 +#define GL_DRAW_BUFFER3 0x8828 +#define GL_DRAW_BUFFER4 0x8829 +#define GL_DRAW_BUFFER5 0x882A +#define GL_DRAW_BUFFER6 0x882B +#define GL_DRAW_BUFFER7 0x882C +#define GL_DRAW_BUFFER8 0x882D +#define GL_DRAW_BUFFER9 0x882E +#define GL_DRAW_BUFFER10 0x882F +#define GL_DRAW_BUFFER11 0x8830 +#define GL_DRAW_BUFFER12 0x8831 +#define GL_DRAW_BUFFER13 0x8832 +#define GL_DRAW_BUFFER14 0x8833 +#define GL_DRAW_BUFFER15 0x8834 +#define GL_BLEND_EQUATION_ALPHA 0x883D +#define GL_MAX_VERTEX_ATTRIBS 0x8869 +#define GL_VERTEX_ATTRIB_ARRAY_NORMALIZED 0x886A +#define GL_MAX_TEXTURE_IMAGE_UNITS 0x8872 +#define GL_FRAGMENT_SHADER 0x8B30 +#define GL_VERTEX_SHADER 0x8B31 +#define GL_MAX_FRAGMENT_UNIFORM_COMPONENTS 0x8B49 +#define GL_MAX_VERTEX_UNIFORM_COMPONENTS 0x8B4A +#define GL_MAX_VARYING_FLOATS 0x8B4B +#define GL_MAX_VERTEX_TEXTURE_IMAGE_UNITS 0x8B4C +#define GL_MAX_COMBINED_TEXTURE_IMAGE_UNITS 0x8B4D +#define GL_SHADER_TYPE 0x8B4F +#define GL_FLOAT_VEC2 0x8B50 +#define GL_FLOAT_VEC3 0x8B51 +#define GL_FLOAT_VEC4 0x8B52 +#define GL_INT_VEC2 0x8B53 +#define GL_INT_VEC3 0x8B54 +#define GL_INT_VEC4 0x8B55 +#define GL_BOOL 0x8B56 +#define GL_BOOL_VEC2 0x8B57 +#define GL_BOOL_VEC3 0x8B58 +#define GL_BOOL_VEC4 0x8B59 +#define GL_FLOAT_MAT2 0x8B5A +#define GL_FLOAT_MAT3 0x8B5B +#define GL_FLOAT_MAT4 0x8B5C +#define GL_SAMPLER_1D 0x8B5D +#define GL_SAMPLER_2D 0x8B5E +#define GL_SAMPLER_3D 0x8B5F +#define GL_SAMPLER_CUBE 0x8B60 +#define GL_SAMPLER_1D_SHADOW 0x8B61 +#define GL_SAMPLER_2D_SHADOW 0x8B62 +#define GL_DELETE_STATUS 0x8B80 +#define GL_COMPILE_STATUS 0x8B81 +#define GL_LINK_STATUS 0x8B82 +#define GL_VALIDATE_STATUS 0x8B83 +#define GL_INFO_LOG_LENGTH 0x8B84 +#define GL_ATTACHED_SHADERS 0x8B85 +#define GL_ACTIVE_UNIFORMS 0x8B86 +#define GL_ACTIVE_UNIFORM_MAX_LENGTH 0x8B87 +#define GL_SHADER_SOURCE_LENGTH 0x8B88 +#define GL_ACTIVE_ATTRIBUTES 0x8B89 +#define GL_ACTIVE_ATTRIBUTE_MAX_LENGTH 0x8B8A +#define GL_FRAGMENT_SHADER_DERIVATIVE_HINT 0x8B8B +#define GL_SHADING_LANGUAGE_VERSION 0x8B8C +#define GL_CURRENT_PROGRAM 0x8B8D +#define GL_POINT_SPRITE_COORD_ORIGIN 0x8CA0 +#define GL_LOWER_LEFT 0x8CA1 +#define GL_UPPER_LEFT 0x8CA2 +#define GL_STENCIL_BACK_REF 0x8CA3 +#define GL_STENCIL_BACK_VALUE_MASK 0x8CA4 +#define GL_STENCIL_BACK_WRITEMASK 0x8CA5 +#endif + +#ifndef GL_VERSION_2_1 +#define GL_PIXEL_PACK_BUFFER 0x88EB +#define GL_PIXEL_UNPACK_BUFFER 0x88EC +#define GL_PIXEL_PACK_BUFFER_BINDING 0x88ED +#define GL_PIXEL_UNPACK_BUFFER_BINDING 0x88EF +#define GL_FLOAT_MAT2x3 0x8B65 +#define GL_FLOAT_MAT2x4 0x8B66 +#define GL_FLOAT_MAT3x2 0x8B67 +#define GL_FLOAT_MAT3x4 0x8B68 +#define GL_FLOAT_MAT4x2 0x8B69 +#define GL_FLOAT_MAT4x3 0x8B6A +#define GL_SRGB 0x8C40 +#define GL_SRGB8 0x8C41 +#define GL_SRGB_ALPHA 0x8C42 +#define GL_SRGB8_ALPHA8 0x8C43 +#define GL_COMPRESSED_SRGB 0x8C48 +#define GL_COMPRESSED_SRGB_ALPHA 0x8C49 +#endif + +#ifndef GL_VERSION_3_0 +#define 
GL_COMPARE_REF_TO_TEXTURE 0x884E +#define GL_CLIP_DISTANCE0 0x3000 +#define GL_CLIP_DISTANCE1 0x3001 +#define GL_CLIP_DISTANCE2 0x3002 +#define GL_CLIP_DISTANCE3 0x3003 +#define GL_CLIP_DISTANCE4 0x3004 +#define GL_CLIP_DISTANCE5 0x3005 +#define GL_CLIP_DISTANCE6 0x3006 +#define GL_CLIP_DISTANCE7 0x3007 +#define GL_MAX_CLIP_DISTANCES 0x0D32 +#define GL_MAJOR_VERSION 0x821B +#define GL_MINOR_VERSION 0x821C +#define GL_NUM_EXTENSIONS 0x821D +#define GL_CONTEXT_FLAGS 0x821E +#define GL_COMPRESSED_RED 0x8225 +#define GL_COMPRESSED_RG 0x8226 +#define GL_CONTEXT_FLAG_FORWARD_COMPATIBLE_BIT 0x0001 +#define GL_RGBA32F 0x8814 +#define GL_RGB32F 0x8815 +#define GL_RGBA16F 0x881A +#define GL_RGB16F 0x881B +#define GL_VERTEX_ATTRIB_ARRAY_INTEGER 0x88FD +#define GL_MAX_ARRAY_TEXTURE_LAYERS 0x88FF +#define GL_MIN_PROGRAM_TEXEL_OFFSET 0x8904 +#define GL_MAX_PROGRAM_TEXEL_OFFSET 0x8905 +#define GL_CLAMP_READ_COLOR 0x891C +#define GL_FIXED_ONLY 0x891D +#define GL_MAX_VARYING_COMPONENTS 0x8B4B +#define GL_TEXTURE_1D_ARRAY 0x8C18 +#define GL_PROXY_TEXTURE_1D_ARRAY 0x8C19 +#define GL_TEXTURE_2D_ARRAY 0x8C1A +#define GL_PROXY_TEXTURE_2D_ARRAY 0x8C1B +#define GL_TEXTURE_BINDING_1D_ARRAY 0x8C1C +#define GL_TEXTURE_BINDING_2D_ARRAY 0x8C1D +#define GL_R11F_G11F_B10F 0x8C3A +#define GL_UNSIGNED_INT_10F_11F_11F_REV 0x8C3B +#define GL_RGB9_E5 0x8C3D +#define GL_UNSIGNED_INT_5_9_9_9_REV 0x8C3E +#define GL_TEXTURE_SHARED_SIZE 0x8C3F +#define GL_TRANSFORM_FEEDBACK_VARYING_MAX_LENGTH 0x8C76 +#define GL_TRANSFORM_FEEDBACK_BUFFER_MODE 0x8C7F +#define GL_MAX_TRANSFORM_FEEDBACK_SEPARATE_COMPONENTS 0x8C80 +#define GL_TRANSFORM_FEEDBACK_VARYINGS 0x8C83 +#define GL_TRANSFORM_FEEDBACK_BUFFER_START 0x8C84 +#define GL_TRANSFORM_FEEDBACK_BUFFER_SIZE 0x8C85 +#define GL_PRIMITIVES_GENERATED 0x8C87 +#define GL_TRANSFORM_FEEDBACK_PRIMITIVES_WRITTEN 0x8C88 +#define GL_RASTERIZER_DISCARD 0x8C89 +#define GL_MAX_TRANSFORM_FEEDBACK_INTERLEAVED_COMPONENTS 0x8C8A +#define GL_MAX_TRANSFORM_FEEDBACK_SEPARATE_ATTRIBS 0x8C8B +#define GL_INTERLEAVED_ATTRIBS 0x8C8C +#define GL_SEPARATE_ATTRIBS 0x8C8D +#define GL_TRANSFORM_FEEDBACK_BUFFER 0x8C8E +#define GL_TRANSFORM_FEEDBACK_BUFFER_BINDING 0x8C8F +#define GL_RGBA32UI 0x8D70 +#define GL_RGB32UI 0x8D71 +#define GL_RGBA16UI 0x8D76 +#define GL_RGB16UI 0x8D77 +#define GL_RGBA8UI 0x8D7C +#define GL_RGB8UI 0x8D7D +#define GL_RGBA32I 0x8D82 +#define GL_RGB32I 0x8D83 +#define GL_RGBA16I 0x8D88 +#define GL_RGB16I 0x8D89 +#define GL_RGBA8I 0x8D8E +#define GL_RGB8I 0x8D8F +#define GL_RED_INTEGER 0x8D94 +#define GL_GREEN_INTEGER 0x8D95 +#define GL_BLUE_INTEGER 0x8D96 +#define GL_RGB_INTEGER 0x8D98 +#define GL_RGBA_INTEGER 0x8D99 +#define GL_BGR_INTEGER 0x8D9A +#define GL_BGRA_INTEGER 0x8D9B +#define GL_SAMPLER_1D_ARRAY 0x8DC0 +#define GL_SAMPLER_2D_ARRAY 0x8DC1 +#define GL_SAMPLER_1D_ARRAY_SHADOW 0x8DC3 +#define GL_SAMPLER_2D_ARRAY_SHADOW 0x8DC4 +#define GL_SAMPLER_CUBE_SHADOW 0x8DC5 +#define GL_UNSIGNED_INT_VEC2 0x8DC6 +#define GL_UNSIGNED_INT_VEC3 0x8DC7 +#define GL_UNSIGNED_INT_VEC4 0x8DC8 +#define GL_INT_SAMPLER_1D 0x8DC9 +#define GL_INT_SAMPLER_2D 0x8DCA +#define GL_INT_SAMPLER_3D 0x8DCB +#define GL_INT_SAMPLER_CUBE 0x8DCC +#define GL_INT_SAMPLER_1D_ARRAY 0x8DCE +#define GL_INT_SAMPLER_2D_ARRAY 0x8DCF +#define GL_UNSIGNED_INT_SAMPLER_1D 0x8DD1 +#define GL_UNSIGNED_INT_SAMPLER_2D 0x8DD2 +#define GL_UNSIGNED_INT_SAMPLER_3D 0x8DD3 +#define GL_UNSIGNED_INT_SAMPLER_CUBE 0x8DD4 +#define GL_UNSIGNED_INT_SAMPLER_1D_ARRAY 0x8DD6 +#define GL_UNSIGNED_INT_SAMPLER_2D_ARRAY 0x8DD7 +#define GL_QUERY_WAIT 0x8E13 +#define 
GL_QUERY_NO_WAIT 0x8E14 +#define GL_QUERY_BY_REGION_WAIT 0x8E15 +#define GL_QUERY_BY_REGION_NO_WAIT 0x8E16 +#define GL_BUFFER_ACCESS_FLAGS 0x911F +#define GL_BUFFER_MAP_LENGTH 0x9120 +#define GL_BUFFER_MAP_OFFSET 0x9121 +/* Reuse tokens from ARB_depth_buffer_float */ +/* reuse GL_DEPTH_COMPONENT32F */ +/* reuse GL_DEPTH32F_STENCIL8 */ +/* reuse GL_FLOAT_32_UNSIGNED_INT_24_8_REV */ +/* Reuse tokens from ARB_framebuffer_object */ +/* reuse GL_INVALID_FRAMEBUFFER_OPERATION */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_COLOR_ENCODING */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_COMPONENT_TYPE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_RED_SIZE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_GREEN_SIZE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_BLUE_SIZE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_ALPHA_SIZE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_DEPTH_SIZE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_STENCIL_SIZE */ +/* reuse GL_FRAMEBUFFER_DEFAULT */ +/* reuse GL_FRAMEBUFFER_UNDEFINED */ +/* reuse GL_DEPTH_STENCIL_ATTACHMENT */ +/* reuse GL_INDEX */ +/* reuse GL_MAX_RENDERBUFFER_SIZE */ +/* reuse GL_DEPTH_STENCIL */ +/* reuse GL_UNSIGNED_INT_24_8 */ +/* reuse GL_DEPTH24_STENCIL8 */ +/* reuse GL_TEXTURE_STENCIL_SIZE */ +/* reuse GL_TEXTURE_RED_TYPE */ +/* reuse GL_TEXTURE_GREEN_TYPE */ +/* reuse GL_TEXTURE_BLUE_TYPE */ +/* reuse GL_TEXTURE_ALPHA_TYPE */ +/* reuse GL_TEXTURE_DEPTH_TYPE */ +/* reuse GL_UNSIGNED_NORMALIZED */ +/* reuse GL_FRAMEBUFFER_BINDING */ +/* reuse GL_DRAW_FRAMEBUFFER_BINDING */ +/* reuse GL_RENDERBUFFER_BINDING */ +/* reuse GL_READ_FRAMEBUFFER */ +/* reuse GL_DRAW_FRAMEBUFFER */ +/* reuse GL_READ_FRAMEBUFFER_BINDING */ +/* reuse GL_RENDERBUFFER_SAMPLES */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_LEVEL */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_CUBE_MAP_FACE */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_LAYER */ +/* reuse GL_FRAMEBUFFER_COMPLETE */ +/* reuse GL_FRAMEBUFFER_INCOMPLETE_ATTACHMENT */ +/* reuse GL_FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT */ +/* reuse GL_FRAMEBUFFER_INCOMPLETE_DRAW_BUFFER */ +/* reuse GL_FRAMEBUFFER_INCOMPLETE_READ_BUFFER */ +/* reuse GL_FRAMEBUFFER_UNSUPPORTED */ +/* reuse GL_MAX_COLOR_ATTACHMENTS */ +/* reuse GL_COLOR_ATTACHMENT0 */ +/* reuse GL_COLOR_ATTACHMENT1 */ +/* reuse GL_COLOR_ATTACHMENT2 */ +/* reuse GL_COLOR_ATTACHMENT3 */ +/* reuse GL_COLOR_ATTACHMENT4 */ +/* reuse GL_COLOR_ATTACHMENT5 */ +/* reuse GL_COLOR_ATTACHMENT6 */ +/* reuse GL_COLOR_ATTACHMENT7 */ +/* reuse GL_COLOR_ATTACHMENT8 */ +/* reuse GL_COLOR_ATTACHMENT9 */ +/* reuse GL_COLOR_ATTACHMENT10 */ +/* reuse GL_COLOR_ATTACHMENT11 */ +/* reuse GL_COLOR_ATTACHMENT12 */ +/* reuse GL_COLOR_ATTACHMENT13 */ +/* reuse GL_COLOR_ATTACHMENT14 */ +/* reuse GL_COLOR_ATTACHMENT15 */ +/* reuse GL_DEPTH_ATTACHMENT */ +/* reuse GL_STENCIL_ATTACHMENT */ +/* reuse GL_FRAMEBUFFER */ +/* reuse GL_RENDERBUFFER */ +/* reuse GL_RENDERBUFFER_WIDTH */ +/* reuse GL_RENDERBUFFER_HEIGHT */ +/* reuse GL_RENDERBUFFER_INTERNAL_FORMAT */ +/* reuse GL_STENCIL_INDEX1 */ +/* reuse GL_STENCIL_INDEX4 */ +/* reuse GL_STENCIL_INDEX8 */ +/* reuse GL_STENCIL_INDEX16 */ +/* reuse GL_RENDERBUFFER_RED_SIZE */ +/* reuse GL_RENDERBUFFER_GREEN_SIZE */ +/* reuse GL_RENDERBUFFER_BLUE_SIZE */ +/* reuse GL_RENDERBUFFER_ALPHA_SIZE */ +/* reuse GL_RENDERBUFFER_DEPTH_SIZE */ +/* reuse GL_RENDERBUFFER_STENCIL_SIZE */ +/* reuse GL_FRAMEBUFFER_INCOMPLETE_MULTISAMPLE */ +/* reuse GL_MAX_SAMPLES */ +/* Reuse tokens from ARB_framebuffer_sRGB */ +/* reuse 
GL_FRAMEBUFFER_SRGB */ +/* Reuse tokens from ARB_half_float_vertex */ +/* reuse GL_HALF_FLOAT */ +/* Reuse tokens from ARB_map_buffer_range */ +/* reuse GL_MAP_READ_BIT */ +/* reuse GL_MAP_WRITE_BIT */ +/* reuse GL_MAP_INVALIDATE_RANGE_BIT */ +/* reuse GL_MAP_INVALIDATE_BUFFER_BIT */ +/* reuse GL_MAP_FLUSH_EXPLICIT_BIT */ +/* reuse GL_MAP_UNSYNCHRONIZED_BIT */ +/* Reuse tokens from ARB_texture_compression_rgtc */ +/* reuse GL_COMPRESSED_RED_RGTC1 */ +/* reuse GL_COMPRESSED_SIGNED_RED_RGTC1 */ +/* reuse GL_COMPRESSED_RG_RGTC2 */ +/* reuse GL_COMPRESSED_SIGNED_RG_RGTC2 */ +/* Reuse tokens from ARB_texture_rg */ +/* reuse GL_RG */ +/* reuse GL_RG_INTEGER */ +/* reuse GL_R8 */ +/* reuse GL_R16 */ +/* reuse GL_RG8 */ +/* reuse GL_RG16 */ +/* reuse GL_R16F */ +/* reuse GL_R32F */ +/* reuse GL_RG16F */ +/* reuse GL_RG32F */ +/* reuse GL_R8I */ +/* reuse GL_R8UI */ +/* reuse GL_R16I */ +/* reuse GL_R16UI */ +/* reuse GL_R32I */ +/* reuse GL_R32UI */ +/* reuse GL_RG8I */ +/* reuse GL_RG8UI */ +/* reuse GL_RG16I */ +/* reuse GL_RG16UI */ +/* reuse GL_RG32I */ +/* reuse GL_RG32UI */ +/* Reuse tokens from ARB_vertex_array_object */ +/* reuse GL_VERTEX_ARRAY_BINDING */ +#endif + +#ifndef GL_VERSION_3_1 +#define GL_SAMPLER_2D_RECT 0x8B63 +#define GL_SAMPLER_2D_RECT_SHADOW 0x8B64 +#define GL_SAMPLER_BUFFER 0x8DC2 +#define GL_INT_SAMPLER_2D_RECT 0x8DCD +#define GL_INT_SAMPLER_BUFFER 0x8DD0 +#define GL_UNSIGNED_INT_SAMPLER_2D_RECT 0x8DD5 +#define GL_UNSIGNED_INT_SAMPLER_BUFFER 0x8DD8 +#define GL_TEXTURE_BUFFER 0x8C2A +#define GL_MAX_TEXTURE_BUFFER_SIZE 0x8C2B +#define GL_TEXTURE_BINDING_BUFFER 0x8C2C +#define GL_TEXTURE_BUFFER_DATA_STORE_BINDING 0x8C2D +#define GL_TEXTURE_RECTANGLE 0x84F5 +#define GL_TEXTURE_BINDING_RECTANGLE 0x84F6 +#define GL_PROXY_TEXTURE_RECTANGLE 0x84F7 +#define GL_MAX_RECTANGLE_TEXTURE_SIZE 0x84F8 +#define GL_RED_SNORM 0x8F90 +#define GL_RG_SNORM 0x8F91 +#define GL_RGB_SNORM 0x8F92 +#define GL_RGBA_SNORM 0x8F93 +#define GL_R8_SNORM 0x8F94 +#define GL_RG8_SNORM 0x8F95 +#define GL_RGB8_SNORM 0x8F96 +#define GL_RGBA8_SNORM 0x8F97 +#define GL_R16_SNORM 0x8F98 +#define GL_RG16_SNORM 0x8F99 +#define GL_RGB16_SNORM 0x8F9A +#define GL_RGBA16_SNORM 0x8F9B +#define GL_SIGNED_NORMALIZED 0x8F9C +#define GL_PRIMITIVE_RESTART 0x8F9D +#define GL_PRIMITIVE_RESTART_INDEX 0x8F9E +/* Reuse tokens from ARB_copy_buffer */ +/* reuse GL_COPY_READ_BUFFER */ +/* reuse GL_COPY_WRITE_BUFFER */ +/* Reuse tokens from ARB_draw_instanced (none) */ +/* Reuse tokens from ARB_uniform_buffer_object */ +/* reuse GL_UNIFORM_BUFFER */ +/* reuse GL_UNIFORM_BUFFER_BINDING */ +/* reuse GL_UNIFORM_BUFFER_START */ +/* reuse GL_UNIFORM_BUFFER_SIZE */ +/* reuse GL_MAX_VERTEX_UNIFORM_BLOCKS */ +/* reuse GL_MAX_FRAGMENT_UNIFORM_BLOCKS */ +/* reuse GL_MAX_COMBINED_UNIFORM_BLOCKS */ +/* reuse GL_MAX_UNIFORM_BUFFER_BINDINGS */ +/* reuse GL_MAX_UNIFORM_BLOCK_SIZE */ +/* reuse GL_MAX_COMBINED_VERTEX_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_COMBINED_FRAGMENT_UNIFORM_COMPONENTS */ +/* reuse GL_UNIFORM_BUFFER_OFFSET_ALIGNMENT */ +/* reuse GL_ACTIVE_UNIFORM_BLOCK_MAX_NAME_LENGTH */ +/* reuse GL_ACTIVE_UNIFORM_BLOCKS */ +/* reuse GL_UNIFORM_TYPE */ +/* reuse GL_UNIFORM_SIZE */ +/* reuse GL_UNIFORM_NAME_LENGTH */ +/* reuse GL_UNIFORM_BLOCK_INDEX */ +/* reuse GL_UNIFORM_OFFSET */ +/* reuse GL_UNIFORM_ARRAY_STRIDE */ +/* reuse GL_UNIFORM_MATRIX_STRIDE */ +/* reuse GL_UNIFORM_IS_ROW_MAJOR */ +/* reuse GL_UNIFORM_BLOCK_BINDING */ +/* reuse GL_UNIFORM_BLOCK_DATA_SIZE */ +/* reuse GL_UNIFORM_BLOCK_NAME_LENGTH */ +/* reuse 
GL_UNIFORM_BLOCK_ACTIVE_UNIFORMS */ +/* reuse GL_UNIFORM_BLOCK_ACTIVE_UNIFORM_INDICES */ +/* reuse GL_UNIFORM_BLOCK_REFERENCED_BY_VERTEX_SHADER */ +/* reuse GL_UNIFORM_BLOCK_REFERENCED_BY_FRAGMENT_SHADER */ +/* reuse GL_INVALID_INDEX */ +#endif + +#ifndef GL_VERSION_3_2 +#define GL_CONTEXT_CORE_PROFILE_BIT 0x00000001 +#define GL_CONTEXT_COMPATIBILITY_PROFILE_BIT 0x00000002 +#define GL_LINES_ADJACENCY 0x000A +#define GL_LINE_STRIP_ADJACENCY 0x000B +#define GL_TRIANGLES_ADJACENCY 0x000C +#define GL_TRIANGLE_STRIP_ADJACENCY 0x000D +#define GL_PROGRAM_POINT_SIZE 0x8642 +#define GL_MAX_GEOMETRY_TEXTURE_IMAGE_UNITS 0x8C29 +#define GL_FRAMEBUFFER_ATTACHMENT_LAYERED 0x8DA7 +#define GL_FRAMEBUFFER_INCOMPLETE_LAYER_TARGETS 0x8DA8 +#define GL_GEOMETRY_SHADER 0x8DD9 +#define GL_GEOMETRY_VERTICES_OUT 0x8916 +#define GL_GEOMETRY_INPUT_TYPE 0x8917 +#define GL_GEOMETRY_OUTPUT_TYPE 0x8918 +#define GL_MAX_GEOMETRY_UNIFORM_COMPONENTS 0x8DDF +#define GL_MAX_GEOMETRY_OUTPUT_VERTICES 0x8DE0 +#define GL_MAX_GEOMETRY_TOTAL_OUTPUT_COMPONENTS 0x8DE1 +#define GL_MAX_VERTEX_OUTPUT_COMPONENTS 0x9122 +#define GL_MAX_GEOMETRY_INPUT_COMPONENTS 0x9123 +#define GL_MAX_GEOMETRY_OUTPUT_COMPONENTS 0x9124 +#define GL_MAX_FRAGMENT_INPUT_COMPONENTS 0x9125 +#define GL_CONTEXT_PROFILE_MASK 0x9126 +/* reuse GL_MAX_VARYING_COMPONENTS */ +/* reuse GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_LAYER */ +/* Reuse tokens from ARB_depth_clamp */ +/* reuse GL_DEPTH_CLAMP */ +/* Reuse tokens from ARB_draw_elements_base_vertex (none) */ +/* Reuse tokens from ARB_fragment_coord_conventions (none) */ +/* Reuse tokens from ARB_provoking_vertex */ +/* reuse GL_QUADS_FOLLOW_PROVOKING_VERTEX_CONVENTION */ +/* reuse GL_FIRST_VERTEX_CONVENTION */ +/* reuse GL_LAST_VERTEX_CONVENTION */ +/* reuse GL_PROVOKING_VERTEX */ +/* Reuse tokens from ARB_seamless_cube_map */ +/* reuse GL_TEXTURE_CUBE_MAP_SEAMLESS */ +/* Reuse tokens from ARB_sync */ +/* reuse GL_MAX_SERVER_WAIT_TIMEOUT */ +/* reuse GL_OBJECT_TYPE */ +/* reuse GL_SYNC_CONDITION */ +/* reuse GL_SYNC_STATUS */ +/* reuse GL_SYNC_FLAGS */ +/* reuse GL_SYNC_FENCE */ +/* reuse GL_SYNC_GPU_COMMANDS_COMPLETE */ +/* reuse GL_UNSIGNALED */ +/* reuse GL_SIGNALED */ +/* reuse GL_ALREADY_SIGNALED */ +/* reuse GL_TIMEOUT_EXPIRED */ +/* reuse GL_CONDITION_SATISFIED */ +/* reuse GL_WAIT_FAILED */ +/* reuse GL_TIMEOUT_IGNORED */ +/* reuse GL_SYNC_FLUSH_COMMANDS_BIT */ +/* reuse GL_TIMEOUT_IGNORED */ +/* Reuse tokens from ARB_texture_multisample */ +/* reuse GL_SAMPLE_POSITION */ +/* reuse GL_SAMPLE_MASK */ +/* reuse GL_SAMPLE_MASK_VALUE */ +/* reuse GL_MAX_SAMPLE_MASK_WORDS */ +/* reuse GL_TEXTURE_2D_MULTISAMPLE */ +/* reuse GL_PROXY_TEXTURE_2D_MULTISAMPLE */ +/* reuse GL_TEXTURE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_PROXY_TEXTURE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_TEXTURE_BINDING_2D_MULTISAMPLE */ +/* reuse GL_TEXTURE_BINDING_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_TEXTURE_SAMPLES */ +/* reuse GL_TEXTURE_FIXED_SAMPLE_LOCATIONS */ +/* reuse GL_SAMPLER_2D_MULTISAMPLE */ +/* reuse GL_INT_SAMPLER_2D_MULTISAMPLE */ +/* reuse GL_UNSIGNED_INT_SAMPLER_2D_MULTISAMPLE */ +/* reuse GL_SAMPLER_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_INT_SAMPLER_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_UNSIGNED_INT_SAMPLER_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_MAX_COLOR_TEXTURE_SAMPLES */ +/* reuse GL_MAX_DEPTH_TEXTURE_SAMPLES */ +/* reuse GL_MAX_INTEGER_SAMPLES */ +/* Don't need to reuse tokens from ARB_vertex_array_bgra since they're already in 1.2 core */ +#endif + +#ifndef GL_VERSION_3_3 +#define GL_VERTEX_ATTRIB_ARRAY_DIVISOR 0x88FE +/* Reuse tokens from 
ARB_blend_func_extended */ +/* reuse GL_SRC1_COLOR */ +/* reuse GL_ONE_MINUS_SRC1_COLOR */ +/* reuse GL_ONE_MINUS_SRC1_ALPHA */ +/* reuse GL_MAX_DUAL_SOURCE_DRAW_BUFFERS */ +/* Reuse tokens from ARB_explicit_attrib_location (none) */ +/* Reuse tokens from ARB_occlusion_query2 */ +/* reuse GL_ANY_SAMPLES_PASSED */ +/* Reuse tokens from ARB_sampler_objects */ +/* reuse GL_SAMPLER_BINDING */ +/* Reuse tokens from ARB_shader_bit_encoding (none) */ +/* Reuse tokens from ARB_texture_rgb10_a2ui */ +/* reuse GL_RGB10_A2UI */ +/* Reuse tokens from ARB_texture_swizzle */ +/* reuse GL_TEXTURE_SWIZZLE_R */ +/* reuse GL_TEXTURE_SWIZZLE_G */ +/* reuse GL_TEXTURE_SWIZZLE_B */ +/* reuse GL_TEXTURE_SWIZZLE_A */ +/* reuse GL_TEXTURE_SWIZZLE_RGBA */ +/* Reuse tokens from ARB_timer_query */ +/* reuse GL_TIME_ELAPSED */ +/* reuse GL_TIMESTAMP */ +/* Reuse tokens from ARB_vertex_type_2_10_10_10_rev */ +/* reuse GL_INT_2_10_10_10_REV */ +#endif + +#ifndef GL_VERSION_4_0 +#define GL_SAMPLE_SHADING 0x8C36 +#define GL_MIN_SAMPLE_SHADING_VALUE 0x8C37 +#define GL_MIN_PROGRAM_TEXTURE_GATHER_OFFSET 0x8E5E +#define GL_MAX_PROGRAM_TEXTURE_GATHER_OFFSET 0x8E5F +#define GL_TEXTURE_CUBE_MAP_ARRAY 0x9009 +#define GL_TEXTURE_BINDING_CUBE_MAP_ARRAY 0x900A +#define GL_PROXY_TEXTURE_CUBE_MAP_ARRAY 0x900B +#define GL_SAMPLER_CUBE_MAP_ARRAY 0x900C +#define GL_SAMPLER_CUBE_MAP_ARRAY_SHADOW 0x900D +#define GL_INT_SAMPLER_CUBE_MAP_ARRAY 0x900E +#define GL_UNSIGNED_INT_SAMPLER_CUBE_MAP_ARRAY 0x900F +/* Reuse tokens from ARB_texture_query_lod (none) */ +/* Reuse tokens from ARB_draw_buffers_blend (none) */ +/* Reuse tokens from ARB_draw_indirect */ +/* reuse GL_DRAW_INDIRECT_BUFFER */ +/* reuse GL_DRAW_INDIRECT_BUFFER_BINDING */ +/* Reuse tokens from ARB_gpu_shader5 */ +/* reuse GL_GEOMETRY_SHADER_INVOCATIONS */ +/* reuse GL_MAX_GEOMETRY_SHADER_INVOCATIONS */ +/* reuse GL_MIN_FRAGMENT_INTERPOLATION_OFFSET */ +/* reuse GL_MAX_FRAGMENT_INTERPOLATION_OFFSET */ +/* reuse GL_FRAGMENT_INTERPOLATION_OFFSET_BITS */ +/* reuse GL_MAX_VERTEX_STREAMS */ +/* Reuse tokens from ARB_gpu_shader_fp64 */ +/* reuse GL_DOUBLE_VEC2 */ +/* reuse GL_DOUBLE_VEC3 */ +/* reuse GL_DOUBLE_VEC4 */ +/* reuse GL_DOUBLE_MAT2 */ +/* reuse GL_DOUBLE_MAT3 */ +/* reuse GL_DOUBLE_MAT4 */ +/* reuse GL_DOUBLE_MAT2x3 */ +/* reuse GL_DOUBLE_MAT2x4 */ +/* reuse GL_DOUBLE_MAT3x2 */ +/* reuse GL_DOUBLE_MAT3x4 */ +/* reuse GL_DOUBLE_MAT4x2 */ +/* reuse GL_DOUBLE_MAT4x3 */ +/* Reuse tokens from ARB_shader_subroutine */ +/* reuse GL_ACTIVE_SUBROUTINES */ +/* reuse GL_ACTIVE_SUBROUTINE_UNIFORMS */ +/* reuse GL_ACTIVE_SUBROUTINE_UNIFORM_LOCATIONS */ +/* reuse GL_ACTIVE_SUBROUTINE_MAX_LENGTH */ +/* reuse GL_ACTIVE_SUBROUTINE_UNIFORM_MAX_LENGTH */ +/* reuse GL_MAX_SUBROUTINES */ +/* reuse GL_MAX_SUBROUTINE_UNIFORM_LOCATIONS */ +/* reuse GL_NUM_COMPATIBLE_SUBROUTINES */ +/* reuse GL_COMPATIBLE_SUBROUTINES */ +/* Reuse tokens from ARB_tessellation_shader */ +/* reuse GL_PATCHES */ +/* reuse GL_PATCH_VERTICES */ +/* reuse GL_PATCH_DEFAULT_INNER_LEVEL */ +/* reuse GL_PATCH_DEFAULT_OUTER_LEVEL */ +/* reuse GL_TESS_CONTROL_OUTPUT_VERTICES */ +/* reuse GL_TESS_GEN_MODE */ +/* reuse GL_TESS_GEN_SPACING */ +/* reuse GL_TESS_GEN_VERTEX_ORDER */ +/* reuse GL_TESS_GEN_POINT_MODE */ +/* reuse GL_ISOLINES */ +/* reuse GL_FRACTIONAL_ODD */ +/* reuse GL_FRACTIONAL_EVEN */ +/* reuse GL_MAX_PATCH_VERTICES */ +/* reuse GL_MAX_TESS_GEN_LEVEL */ +/* reuse GL_MAX_TESS_CONTROL_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_TESS_EVALUATION_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_TESS_CONTROL_TEXTURE_IMAGE_UNITS */ +/* 
reuse GL_MAX_TESS_EVALUATION_TEXTURE_IMAGE_UNITS */ +/* reuse GL_MAX_TESS_CONTROL_OUTPUT_COMPONENTS */ +/* reuse GL_MAX_TESS_PATCH_COMPONENTS */ +/* reuse GL_MAX_TESS_CONTROL_TOTAL_OUTPUT_COMPONENTS */ +/* reuse GL_MAX_TESS_EVALUATION_OUTPUT_COMPONENTS */ +/* reuse GL_MAX_TESS_CONTROL_UNIFORM_BLOCKS */ +/* reuse GL_MAX_TESS_EVALUATION_UNIFORM_BLOCKS */ +/* reuse GL_MAX_TESS_CONTROL_INPUT_COMPONENTS */ +/* reuse GL_MAX_TESS_EVALUATION_INPUT_COMPONENTS */ +/* reuse GL_MAX_COMBINED_TESS_CONTROL_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_COMBINED_TESS_EVALUATION_UNIFORM_COMPONENTS */ +/* reuse GL_UNIFORM_BLOCK_REFERENCED_BY_TESS_CONTROL_SHADER */ +/* reuse GL_UNIFORM_BLOCK_REFERENCED_BY_TESS_EVALUATION_SHADER */ +/* reuse GL_TESS_EVALUATION_SHADER */ +/* reuse GL_TESS_CONTROL_SHADER */ +/* Reuse tokens from ARB_texture_buffer_object_rgb32 (none) */ +/* Reuse tokens from ARB_transform_feedback2 */ +/* reuse GL_TRANSFORM_FEEDBACK */ +/* reuse GL_TRANSFORM_FEEDBACK_BUFFER_PAUSED */ +/* reuse GL_TRANSFORM_FEEDBACK_BUFFER_ACTIVE */ +/* reuse GL_TRANSFORM_FEEDBACK_BINDING */ +/* Reuse tokens from ARB_transform_feedback3 */ +/* reuse GL_MAX_TRANSFORM_FEEDBACK_BUFFERS */ +/* reuse GL_MAX_VERTEX_STREAMS */ +#endif + +#ifndef GL_VERSION_4_1 +/* Reuse tokens from ARB_ES2_compatibility */ +/* reuse GL_FIXED */ +/* reuse GL_IMPLEMENTATION_COLOR_READ_TYPE */ +/* reuse GL_IMPLEMENTATION_COLOR_READ_FORMAT */ +/* reuse GL_LOW_FLOAT */ +/* reuse GL_MEDIUM_FLOAT */ +/* reuse GL_HIGH_FLOAT */ +/* reuse GL_LOW_INT */ +/* reuse GL_MEDIUM_INT */ +/* reuse GL_HIGH_INT */ +/* reuse GL_SHADER_COMPILER */ +/* reuse GL_SHADER_BINARY_FORMATS */ +/* reuse GL_NUM_SHADER_BINARY_FORMATS */ +/* reuse GL_MAX_VERTEX_UNIFORM_VECTORS */ +/* reuse GL_MAX_VARYING_VECTORS */ +/* reuse GL_MAX_FRAGMENT_UNIFORM_VECTORS */ +/* reuse GL_RGB565 */ +/* Reuse tokens from ARB_get_program_binary */ +/* reuse GL_PROGRAM_BINARY_RETRIEVABLE_HINT */ +/* reuse GL_PROGRAM_BINARY_LENGTH */ +/* reuse GL_NUM_PROGRAM_BINARY_FORMATS */ +/* reuse GL_PROGRAM_BINARY_FORMATS */ +/* Reuse tokens from ARB_separate_shader_objects */ +/* reuse GL_VERTEX_SHADER_BIT */ +/* reuse GL_FRAGMENT_SHADER_BIT */ +/* reuse GL_GEOMETRY_SHADER_BIT */ +/* reuse GL_TESS_CONTROL_SHADER_BIT */ +/* reuse GL_TESS_EVALUATION_SHADER_BIT */ +/* reuse GL_ALL_SHADER_BITS */ +/* reuse GL_PROGRAM_SEPARABLE */ +/* reuse GL_ACTIVE_PROGRAM */ +/* reuse GL_PROGRAM_PIPELINE_BINDING */ +/* Reuse tokens from ARB_shader_precision (none) */ +/* Reuse tokens from ARB_vertex_attrib_64bit - all are in GL 3.0 and 4.0 already */ +/* Reuse tokens from ARB_viewport_array - some are in GL 1.1 and ARB_provoking_vertex already */ +/* reuse GL_MAX_VIEWPORTS */ +/* reuse GL_VIEWPORT_SUBPIXEL_BITS */ +/* reuse GL_VIEWPORT_BOUNDS_RANGE */ +/* reuse GL_LAYER_PROVOKING_VERTEX */ +/* reuse GL_VIEWPORT_INDEX_PROVOKING_VERTEX */ +/* reuse GL_UNDEFINED_VERTEX */ +#endif + +#ifndef GL_VERSION_4_2 +/* Reuse tokens from ARB_base_instance (none) */ +/* Reuse tokens from ARB_shading_language_420pack (none) */ +/* Reuse tokens from ARB_transform_feedback_instanced (none) */ +/* Reuse tokens from ARB_compressed_texture_pixel_storage */ +/* reuse GL_UNPACK_COMPRESSED_BLOCK_WIDTH */ +/* reuse GL_UNPACK_COMPRESSED_BLOCK_HEIGHT */ +/* reuse GL_UNPACK_COMPRESSED_BLOCK_DEPTH */ +/* reuse GL_UNPACK_COMPRESSED_BLOCK_SIZE */ +/* reuse GL_PACK_COMPRESSED_BLOCK_WIDTH */ +/* reuse GL_PACK_COMPRESSED_BLOCK_HEIGHT */ +/* reuse GL_PACK_COMPRESSED_BLOCK_DEPTH */ +/* reuse GL_PACK_COMPRESSED_BLOCK_SIZE */ +/* Reuse tokens from 
ARB_conservative_depth (none) */ +/* Reuse tokens from ARB_internalformat_query */ +/* reuse GL_NUM_SAMPLE_COUNTS */ +/* Reuse tokens from ARB_map_buffer_alignment */ +/* reuse GL_MIN_MAP_BUFFER_ALIGNMENT */ +/* Reuse tokens from ARB_shader_atomic_counters */ +/* reuse GL_ATOMIC_COUNTER_BUFFER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_BINDING */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_START */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_SIZE */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_DATA_SIZE */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_ACTIVE_ATOMIC_COUNTERS */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_ACTIVE_ATOMIC_COUNTER_INDICES */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_VERTEX_SHADER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_TESS_CONTROL_SHADER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_TESS_EVALUATION_SHADER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_GEOMETRY_SHADER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_FRAGMENT_SHADER */ +/* reuse GL_MAX_VERTEX_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_TESS_CONTROL_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_TESS_EVALUATION_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_GEOMETRY_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_FRAGMENT_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_COMBINED_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_VERTEX_ATOMIC_COUNTERS */ +/* reuse GL_MAX_TESS_CONTROL_ATOMIC_COUNTERS */ +/* reuse GL_MAX_TESS_EVALUATION_ATOMIC_COUNTERS */ +/* reuse GL_MAX_GEOMETRY_ATOMIC_COUNTERS */ +/* reuse GL_MAX_FRAGMENT_ATOMIC_COUNTERS */ +/* reuse GL_MAX_COMBINED_ATOMIC_COUNTERS */ +/* reuse GL_MAX_ATOMIC_COUNTER_BUFFER_SIZE */ +/* reuse GL_MAX_ATOMIC_COUNTER_BUFFER_BINDINGS */ +/* reuse GL_ACTIVE_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_UNIFORM_ATOMIC_COUNTER_BUFFER_INDEX */ +/* reuse GL_UNSIGNED_INT_ATOMIC_COUNTER */ +/* Reuse tokens from ARB_shader_image_load_store */ +/* reuse GL_VERTEX_ATTRIB_ARRAY_BARRIER_BIT */ +/* reuse GL_ELEMENT_ARRAY_BARRIER_BIT */ +/* reuse GL_UNIFORM_BARRIER_BIT */ +/* reuse GL_TEXTURE_FETCH_BARRIER_BIT */ +/* reuse GL_SHADER_IMAGE_ACCESS_BARRIER_BIT */ +/* reuse GL_COMMAND_BARRIER_BIT */ +/* reuse GL_PIXEL_BUFFER_BARRIER_BIT */ +/* reuse GL_TEXTURE_UPDATE_BARRIER_BIT */ +/* reuse GL_BUFFER_UPDATE_BARRIER_BIT */ +/* reuse GL_FRAMEBUFFER_BARRIER_BIT */ +/* reuse GL_TRANSFORM_FEEDBACK_BARRIER_BIT */ +/* reuse GL_ATOMIC_COUNTER_BARRIER_BIT */ +/* reuse GL_ALL_BARRIER_BITS */ +/* reuse GL_MAX_IMAGE_UNITS */ +/* reuse GL_MAX_COMBINED_IMAGE_UNITS_AND_FRAGMENT_OUTPUTS */ +/* reuse GL_IMAGE_BINDING_NAME */ +/* reuse GL_IMAGE_BINDING_LEVEL */ +/* reuse GL_IMAGE_BINDING_LAYERED */ +/* reuse GL_IMAGE_BINDING_LAYER */ +/* reuse GL_IMAGE_BINDING_ACCESS */ +/* reuse GL_IMAGE_1D */ +/* reuse GL_IMAGE_2D */ +/* reuse GL_IMAGE_3D */ +/* reuse GL_IMAGE_2D_RECT */ +/* reuse GL_IMAGE_CUBE */ +/* reuse GL_IMAGE_BUFFER */ +/* reuse GL_IMAGE_1D_ARRAY */ +/* reuse GL_IMAGE_2D_ARRAY */ +/* reuse GL_IMAGE_CUBE_MAP_ARRAY */ +/* reuse GL_IMAGE_2D_MULTISAMPLE */ +/* reuse GL_IMAGE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_INT_IMAGE_1D */ +/* reuse GL_INT_IMAGE_2D */ +/* reuse GL_INT_IMAGE_3D */ +/* reuse GL_INT_IMAGE_2D_RECT */ +/* reuse GL_INT_IMAGE_CUBE */ +/* reuse GL_INT_IMAGE_BUFFER */ +/* reuse GL_INT_IMAGE_1D_ARRAY */ +/* reuse GL_INT_IMAGE_2D_ARRAY */ +/* reuse GL_INT_IMAGE_CUBE_MAP_ARRAY */ +/* reuse GL_INT_IMAGE_2D_MULTISAMPLE */ +/* reuse GL_INT_IMAGE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_UNSIGNED_INT_IMAGE_1D */ +/* reuse GL_UNSIGNED_INT_IMAGE_2D */ +/* reuse GL_UNSIGNED_INT_IMAGE_3D */ +/* reuse GL_UNSIGNED_INT_IMAGE_2D_RECT */ +/* 
reuse GL_UNSIGNED_INT_IMAGE_CUBE */ +/* reuse GL_UNSIGNED_INT_IMAGE_BUFFER */ +/* reuse GL_UNSIGNED_INT_IMAGE_1D_ARRAY */ +/* reuse GL_UNSIGNED_INT_IMAGE_2D_ARRAY */ +/* reuse GL_UNSIGNED_INT_IMAGE_CUBE_MAP_ARRAY */ +/* reuse GL_UNSIGNED_INT_IMAGE_2D_MULTISAMPLE */ +/* reuse GL_UNSIGNED_INT_IMAGE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_MAX_IMAGE_SAMPLES */ +/* reuse GL_IMAGE_BINDING_FORMAT */ +/* reuse GL_IMAGE_FORMAT_COMPATIBILITY_TYPE */ +/* reuse GL_IMAGE_FORMAT_COMPATIBILITY_BY_SIZE */ +/* reuse GL_IMAGE_FORMAT_COMPATIBILITY_BY_CLASS */ +/* reuse GL_MAX_VERTEX_IMAGE_UNIFORMS */ +/* reuse GL_MAX_TESS_CONTROL_IMAGE_UNIFORMS */ +/* reuse GL_MAX_TESS_EVALUATION_IMAGE_UNIFORMS */ +/* reuse GL_MAX_GEOMETRY_IMAGE_UNIFORMS */ +/* reuse GL_MAX_FRAGMENT_IMAGE_UNIFORMS */ +/* reuse GL_MAX_COMBINED_IMAGE_UNIFORMS */ +/* Reuse tokens from ARB_shading_language_packing (none) */ +/* Reuse tokens from ARB_texture_storage */ +/* reuse GL_TEXTURE_IMMUTABLE_FORMAT */ +#endif + +#ifndef GL_VERSION_4_3 +#define GL_NUM_SHADING_LANGUAGE_VERSIONS 0x82E9 +#define GL_VERTEX_ATTRIB_ARRAY_LONG 0x874E +/* Reuse tokens from ARB_arrays_of_arrays (none, GLSL only) */ +/* Reuse tokens from ARB_fragment_layer_viewport (none, GLSL only) */ +/* Reuse tokens from ARB_shader_image_size (none, GLSL only) */ +/* Reuse tokens from ARB_ES3_compatibility */ +/* reuse GL_COMPRESSED_RGB8_ETC2 */ +/* reuse GL_COMPRESSED_SRGB8_ETC2 */ +/* reuse GL_COMPRESSED_RGB8_PUNCHTHROUGH_ALPHA1_ETC2 */ +/* reuse GL_COMPRESSED_SRGB8_PUNCHTHROUGH_ALPHA1_ETC2 */ +/* reuse GL_COMPRESSED_RGBA8_ETC2_EAC */ +/* reuse GL_COMPRESSED_SRGB8_ALPHA8_ETC2_EAC */ +/* reuse GL_COMPRESSED_R11_EAC */ +/* reuse GL_COMPRESSED_SIGNED_R11_EAC */ +/* reuse GL_COMPRESSED_RG11_EAC */ +/* reuse GL_COMPRESSED_SIGNED_RG11_EAC */ +/* reuse GL_PRIMITIVE_RESTART_FIXED_INDEX */ +/* reuse GL_ANY_SAMPLES_PASSED_CONSERVATIVE */ +/* reuse GL_MAX_ELEMENT_INDEX */ +/* Reuse tokens from ARB_clear_buffer_object (none) */ +/* Reuse tokens from ARB_compute_shader */ +/* reuse GL_COMPUTE_SHADER */ +/* reuse GL_MAX_COMPUTE_UNIFORM_BLOCKS */ +/* reuse GL_MAX_COMPUTE_TEXTURE_IMAGE_UNITS */ +/* reuse GL_MAX_COMPUTE_IMAGE_UNIFORMS */ +/* reuse GL_MAX_COMPUTE_SHARED_MEMORY_SIZE */ +/* reuse GL_MAX_COMPUTE_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_COMPUTE_ATOMIC_COUNTER_BUFFERS */ +/* reuse GL_MAX_COMPUTE_ATOMIC_COUNTERS */ +/* reuse GL_MAX_COMBINED_COMPUTE_UNIFORM_COMPONENTS */ +/* reuse GL_MAX_COMPUTE_LOCAL_INVOCATIONS */ +/* reuse GL_MAX_COMPUTE_WORK_GROUP_COUNT */ +/* reuse GL_MAX_COMPUTE_WORK_GROUP_SIZE */ +/* reuse GL_COMPUTE_LOCAL_WORK_SIZE */ +/* reuse GL_UNIFORM_BLOCK_REFERENCED_BY_COMPUTE_SHADER */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_COMPUTE_SHADER */ +/* reuse GL_DISPATCH_INDIRECT_BUFFER */ +/* reuse GL_DISPATCH_INDIRECT_BUFFER_BINDING */ +/* Reuse tokens from ARB_copy_image (none) */ +/* Reuse tokens from KHR_debug */ +/* reuse GL_DEBUG_OUTPUT_SYNCHRONOUS */ +/* reuse GL_DEBUG_NEXT_LOGGED_MESSAGE_LENGTH */ +/* reuse GL_DEBUG_CALLBACK_FUNCTION */ +/* reuse GL_DEBUG_CALLBACK_USER_PARAM */ +/* reuse GL_DEBUG_SOURCE_API */ +/* reuse GL_DEBUG_SOURCE_WINDOW_SYSTEM */ +/* reuse GL_DEBUG_SOURCE_SHADER_COMPILER */ +/* reuse GL_DEBUG_SOURCE_THIRD_PARTY */ +/* reuse GL_DEBUG_SOURCE_APPLICATION */ +/* reuse GL_DEBUG_SOURCE_OTHER */ +/* reuse GL_DEBUG_TYPE_ERROR */ +/* reuse GL_DEBUG_TYPE_DEPRECATED_BEHAVIOR */ +/* reuse GL_DEBUG_TYPE_UNDEFINED_BEHAVIOR */ +/* reuse GL_DEBUG_TYPE_PORTABILITY */ +/* reuse GL_DEBUG_TYPE_PERFORMANCE */ +/* reuse GL_DEBUG_TYPE_OTHER */ +/* reuse 
GL_MAX_DEBUG_MESSAGE_LENGTH */ +/* reuse GL_MAX_DEBUG_LOGGED_MESSAGES */ +/* reuse GL_DEBUG_LOGGED_MESSAGES */ +/* reuse GL_DEBUG_SEVERITY_HIGH */ +/* reuse GL_DEBUG_SEVERITY_MEDIUM */ +/* reuse GL_DEBUG_SEVERITY_LOW */ +/* reuse GL_DEBUG_TYPE_MARKER */ +/* reuse GL_DEBUG_TYPE_PUSH_GROUP */ +/* reuse GL_DEBUG_TYPE_POP_GROUP */ +/* reuse GL_DEBUG_SEVERITY_NOTIFICATION */ +/* reuse GL_MAX_DEBUG_GROUP_STACK_DEPTH */ +/* reuse GL_DEBUG_GROUP_STACK_DEPTH */ +/* reuse GL_BUFFER */ +/* reuse GL_SHADER */ +/* reuse GL_PROGRAM */ +/* reuse GL_QUERY */ +/* reuse GL_PROGRAM_PIPELINE */ +/* reuse GL_SAMPLER */ +/* reuse GL_DISPLAY_LIST */ +/* reuse GL_MAX_LABEL_LENGTH */ +/* reuse GL_DEBUG_OUTPUT */ +/* reuse GL_CONTEXT_FLAG_DEBUG_BIT */ +/* reuse GL_STACK_UNDERFLOW */ +/* reuse GL_STACK_OVERFLOW */ +/* Reuse tokens from ARB_explicit_uniform_location */ +/* reuse GL_MAX_UNIFORM_LOCATIONS */ +/* Reuse tokens from ARB_framebuffer_no_attachments */ +/* reuse GL_FRAMEBUFFER_DEFAULT_WIDTH */ +/* reuse GL_FRAMEBUFFER_DEFAULT_HEIGHT */ +/* reuse GL_FRAMEBUFFER_DEFAULT_LAYERS */ +/* reuse GL_FRAMEBUFFER_DEFAULT_SAMPLES */ +/* reuse GL_FRAMEBUFFER_DEFAULT_FIXED_SAMPLE_LOCATIONS */ +/* reuse GL_MAX_FRAMEBUFFER_WIDTH */ +/* reuse GL_MAX_FRAMEBUFFER_HEIGHT */ +/* reuse GL_MAX_FRAMEBUFFER_LAYERS */ +/* reuse GL_MAX_FRAMEBUFFER_SAMPLES */ +/* Reuse tokens from ARB_internalformat_query2 */ +/* reuse GL_INTERNALFORMAT_SUPPORTED */ +/* reuse GL_INTERNALFORMAT_PREFERRED */ +/* reuse GL_INTERNALFORMAT_RED_SIZE */ +/* reuse GL_INTERNALFORMAT_GREEN_SIZE */ +/* reuse GL_INTERNALFORMAT_BLUE_SIZE */ +/* reuse GL_INTERNALFORMAT_ALPHA_SIZE */ +/* reuse GL_INTERNALFORMAT_DEPTH_SIZE */ +/* reuse GL_INTERNALFORMAT_STENCIL_SIZE */ +/* reuse GL_INTERNALFORMAT_SHARED_SIZE */ +/* reuse GL_INTERNALFORMAT_RED_TYPE */ +/* reuse GL_INTERNALFORMAT_GREEN_TYPE */ +/* reuse GL_INTERNALFORMAT_BLUE_TYPE */ +/* reuse GL_INTERNALFORMAT_ALPHA_TYPE */ +/* reuse GL_INTERNALFORMAT_DEPTH_TYPE */ +/* reuse GL_INTERNALFORMAT_STENCIL_TYPE */ +/* reuse GL_MAX_WIDTH */ +/* reuse GL_MAX_HEIGHT */ +/* reuse GL_MAX_DEPTH */ +/* reuse GL_MAX_LAYERS */ +/* reuse GL_MAX_COMBINED_DIMENSIONS */ +/* reuse GL_COLOR_COMPONENTS */ +/* reuse GL_DEPTH_COMPONENTS */ +/* reuse GL_STENCIL_COMPONENTS */ +/* reuse GL_COLOR_RENDERABLE */ +/* reuse GL_DEPTH_RENDERABLE */ +/* reuse GL_STENCIL_RENDERABLE */ +/* reuse GL_FRAMEBUFFER_RENDERABLE */ +/* reuse GL_FRAMEBUFFER_RENDERABLE_LAYERED */ +/* reuse GL_FRAMEBUFFER_BLEND */ +/* reuse GL_READ_PIXELS */ +/* reuse GL_READ_PIXELS_FORMAT */ +/* reuse GL_READ_PIXELS_TYPE */ +/* reuse GL_TEXTURE_IMAGE_FORMAT */ +/* reuse GL_TEXTURE_IMAGE_TYPE */ +/* reuse GL_GET_TEXTURE_IMAGE_FORMAT */ +/* reuse GL_GET_TEXTURE_IMAGE_TYPE */ +/* reuse GL_MIPMAP */ +/* reuse GL_MANUAL_GENERATE_MIPMAP */ +/* reuse GL_AUTO_GENERATE_MIPMAP */ +/* reuse GL_COLOR_ENCODING */ +/* reuse GL_SRGB_READ */ +/* reuse GL_SRGB_WRITE */ +/* reuse GL_FILTER */ +/* reuse GL_VERTEX_TEXTURE */ +/* reuse GL_TESS_CONTROL_TEXTURE */ +/* reuse GL_TESS_EVALUATION_TEXTURE */ +/* reuse GL_GEOMETRY_TEXTURE */ +/* reuse GL_FRAGMENT_TEXTURE */ +/* reuse GL_COMPUTE_TEXTURE */ +/* reuse GL_TEXTURE_SHADOW */ +/* reuse GL_TEXTURE_GATHER */ +/* reuse GL_TEXTURE_GATHER_SHADOW */ +/* reuse GL_SHADER_IMAGE_LOAD */ +/* reuse GL_SHADER_IMAGE_STORE */ +/* reuse GL_SHADER_IMAGE_ATOMIC */ +/* reuse GL_IMAGE_TEXEL_SIZE */ +/* reuse GL_IMAGE_COMPATIBILITY_CLASS */ +/* reuse GL_IMAGE_PIXEL_FORMAT */ +/* reuse GL_IMAGE_PIXEL_TYPE */ +/* reuse GL_SIMULTANEOUS_TEXTURE_AND_DEPTH_TEST */ +/* reuse 
GL_SIMULTANEOUS_TEXTURE_AND_STENCIL_TEST */ +/* reuse GL_SIMULTANEOUS_TEXTURE_AND_DEPTH_WRITE */ +/* reuse GL_SIMULTANEOUS_TEXTURE_AND_STENCIL_WRITE */ +/* reuse GL_TEXTURE_COMPRESSED_BLOCK_WIDTH */ +/* reuse GL_TEXTURE_COMPRESSED_BLOCK_HEIGHT */ +/* reuse GL_TEXTURE_COMPRESSED_BLOCK_SIZE */ +/* reuse GL_CLEAR_BUFFER */ +/* reuse GL_TEXTURE_VIEW */ +/* reuse GL_VIEW_COMPATIBILITY_CLASS */ +/* reuse GL_FULL_SUPPORT */ +/* reuse GL_CAVEAT_SUPPORT */ +/* reuse GL_IMAGE_CLASS_4_X_32 */ +/* reuse GL_IMAGE_CLASS_2_X_32 */ +/* reuse GL_IMAGE_CLASS_1_X_32 */ +/* reuse GL_IMAGE_CLASS_4_X_16 */ +/* reuse GL_IMAGE_CLASS_2_X_16 */ +/* reuse GL_IMAGE_CLASS_1_X_16 */ +/* reuse GL_IMAGE_CLASS_4_X_8 */ +/* reuse GL_IMAGE_CLASS_2_X_8 */ +/* reuse GL_IMAGE_CLASS_1_X_8 */ +/* reuse GL_IMAGE_CLASS_11_11_10 */ +/* reuse GL_IMAGE_CLASS_10_10_10_2 */ +/* reuse GL_VIEW_CLASS_128_BITS */ +/* reuse GL_VIEW_CLASS_96_BITS */ +/* reuse GL_VIEW_CLASS_64_BITS */ +/* reuse GL_VIEW_CLASS_48_BITS */ +/* reuse GL_VIEW_CLASS_32_BITS */ +/* reuse GL_VIEW_CLASS_24_BITS */ +/* reuse GL_VIEW_CLASS_16_BITS */ +/* reuse GL_VIEW_CLASS_8_BITS */ +/* reuse GL_VIEW_CLASS_S3TC_DXT1_RGB */ +/* reuse GL_VIEW_CLASS_S3TC_DXT1_RGBA */ +/* reuse GL_VIEW_CLASS_S3TC_DXT3_RGBA */ +/* reuse GL_VIEW_CLASS_S3TC_DXT5_RGBA */ +/* reuse GL_VIEW_CLASS_RGTC1_RED */ +/* reuse GL_VIEW_CLASS_RGTC2_RG */ +/* reuse GL_VIEW_CLASS_BPTC_UNORM */ +/* reuse GL_VIEW_CLASS_BPTC_FLOAT */ +/* Reuse tokens from ARB_invalidate_subdata (none) */ +/* Reuse tokens from ARB_multi_draw_indirect (none) */ +/* Reuse tokens from ARB_program_interface_query */ +/* reuse GL_UNIFORM */ +/* reuse GL_UNIFORM_BLOCK */ +/* reuse GL_PROGRAM_INPUT */ +/* reuse GL_PROGRAM_OUTPUT */ +/* reuse GL_BUFFER_VARIABLE */ +/* reuse GL_SHADER_STORAGE_BLOCK */ +/* reuse GL_VERTEX_SUBROUTINE */ +/* reuse GL_TESS_CONTROL_SUBROUTINE */ +/* reuse GL_TESS_EVALUATION_SUBROUTINE */ +/* reuse GL_GEOMETRY_SUBROUTINE */ +/* reuse GL_FRAGMENT_SUBROUTINE */ +/* reuse GL_COMPUTE_SUBROUTINE */ +/* reuse GL_VERTEX_SUBROUTINE_UNIFORM */ +/* reuse GL_TESS_CONTROL_SUBROUTINE_UNIFORM */ +/* reuse GL_TESS_EVALUATION_SUBROUTINE_UNIFORM */ +/* reuse GL_GEOMETRY_SUBROUTINE_UNIFORM */ +/* reuse GL_FRAGMENT_SUBROUTINE_UNIFORM */ +/* reuse GL_COMPUTE_SUBROUTINE_UNIFORM */ +/* reuse GL_TRANSFORM_FEEDBACK_VARYING */ +/* reuse GL_ACTIVE_RESOURCES */ +/* reuse GL_MAX_NAME_LENGTH */ +/* reuse GL_MAX_NUM_ACTIVE_VARIABLES */ +/* reuse GL_MAX_NUM_COMPATIBLE_SUBROUTINES */ +/* reuse GL_NAME_LENGTH */ +/* reuse GL_TYPE */ +/* reuse GL_ARRAY_SIZE */ +/* reuse GL_OFFSET */ +/* reuse GL_BLOCK_INDEX */ +/* reuse GL_ARRAY_STRIDE */ +/* reuse GL_MATRIX_STRIDE */ +/* reuse GL_IS_ROW_MAJOR */ +/* reuse GL_ATOMIC_COUNTER_BUFFER_INDEX */ +/* reuse GL_BUFFER_BINDING */ +/* reuse GL_BUFFER_DATA_SIZE */ +/* reuse GL_NUM_ACTIVE_VARIABLES */ +/* reuse GL_ACTIVE_VARIABLES */ +/* reuse GL_REFERENCED_BY_VERTEX_SHADER */ +/* reuse GL_REFERENCED_BY_TESS_CONTROL_SHADER */ +/* reuse GL_REFERENCED_BY_TESS_EVALUATION_SHADER */ +/* reuse GL_REFERENCED_BY_GEOMETRY_SHADER */ +/* reuse GL_REFERENCED_BY_FRAGMENT_SHADER */ +/* reuse GL_REFERENCED_BY_COMPUTE_SHADER */ +/* reuse GL_TOP_LEVEL_ARRAY_SIZE */ +/* reuse GL_TOP_LEVEL_ARRAY_STRIDE */ +/* reuse GL_LOCATION */ +/* reuse GL_LOCATION_INDEX */ +/* reuse GL_IS_PER_PATCH */ +/* Reuse tokens from ARB_robust_buffer_access_behavior (none) */ +/* Reuse tokens from ARB_shader_storage_buffer_object */ +/* reuse GL_SHADER_STORAGE_BUFFER */ +/* reuse GL_SHADER_STORAGE_BUFFER_BINDING */ +/* reuse 
GL_SHADER_STORAGE_BUFFER_START */ +/* reuse GL_SHADER_STORAGE_BUFFER_SIZE */ +/* reuse GL_MAX_VERTEX_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_GEOMETRY_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_TESS_CONTROL_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_TESS_EVALUATION_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_FRAGMENT_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_COMBINED_SHADER_STORAGE_BLOCKS */ +/* reuse GL_MAX_SHADER_STORAGE_BUFFER_BINDINGS */ +/* reuse GL_MAX_SHADER_STORAGE_BLOCK_SIZE */ +/* reuse GL_SHADER_STORAGE_BUFFER_OFFSET_ALIGNMENT */ +/* reuse GL_SHADER_STORAGE_BARRIER_BIT */ +/* reuse GL_MAX_COMBINED_SHADER_OUTPUT_RESOURCES */ +/* Reuse tokens from ARB_stencil_texturing */ +/* reuse GL_DEPTH_STENCIL_TEXTURE_MODE */ +/* Reuse tokens from ARB_texture_buffer_range */ +/* reuse GL_TEXTURE_BUFFER_OFFSET */ +/* reuse GL_TEXTURE_BUFFER_SIZE */ +/* reuse GL_TEXTURE_BUFFER_OFFSET_ALIGNMENT */ +/* Reuse tokens from ARB_texture_query_levels (none) */ +/* Reuse tokens from ARB_texture_storage_multisample (none) */ +/* Reuse tokens from ARB_texture_view */ +/* reuse GL_TEXTURE_VIEW_MIN_LEVEL */ +/* reuse GL_TEXTURE_VIEW_NUM_LEVELS */ +/* reuse GL_TEXTURE_VIEW_MIN_LAYER */ +/* reuse GL_TEXTURE_VIEW_NUM_LAYERS */ +/* reuse GL_TEXTURE_IMMUTABLE_LEVELS */ +/* Reuse tokens from ARB_vertex_attrib_binding */ +/* reuse GL_VERTEX_ATTRIB_BINDING */ +/* reuse GL_VERTEX_ATTRIB_RELATIVE_OFFSET */ +/* reuse GL_VERTEX_BINDING_DIVISOR */ +/* reuse GL_VERTEX_BINDING_OFFSET */ +/* reuse GL_VERTEX_BINDING_STRIDE */ +/* reuse GL_MAX_VERTEX_ATTRIB_RELATIVE_OFFSET */ +/* reuse GL_MAX_VERTEX_ATTRIB_BINDINGS */ +#endif + +#ifndef GL_ARB_depth_buffer_float +#define GL_DEPTH_COMPONENT32F 0x8CAC +#define GL_DEPTH32F_STENCIL8 0x8CAD +#define GL_FLOAT_32_UNSIGNED_INT_24_8_REV 0x8DAD +#endif + +#ifndef GL_ARB_framebuffer_object +#define GL_INVALID_FRAMEBUFFER_OPERATION 0x0506 +#define GL_FRAMEBUFFER_ATTACHMENT_COLOR_ENCODING 0x8210 +#define GL_FRAMEBUFFER_ATTACHMENT_COMPONENT_TYPE 0x8211 +#define GL_FRAMEBUFFER_ATTACHMENT_RED_SIZE 0x8212 +#define GL_FRAMEBUFFER_ATTACHMENT_GREEN_SIZE 0x8213 +#define GL_FRAMEBUFFER_ATTACHMENT_BLUE_SIZE 0x8214 +#define GL_FRAMEBUFFER_ATTACHMENT_ALPHA_SIZE 0x8215 +#define GL_FRAMEBUFFER_ATTACHMENT_DEPTH_SIZE 0x8216 +#define GL_FRAMEBUFFER_ATTACHMENT_STENCIL_SIZE 0x8217 +#define GL_FRAMEBUFFER_DEFAULT 0x8218 +#define GL_FRAMEBUFFER_UNDEFINED 0x8219 +#define GL_DEPTH_STENCIL_ATTACHMENT 0x821A +#define GL_MAX_RENDERBUFFER_SIZE 0x84E8 +#define GL_DEPTH_STENCIL 0x84F9 +#define GL_UNSIGNED_INT_24_8 0x84FA +#define GL_DEPTH24_STENCIL8 0x88F0 +#define GL_TEXTURE_STENCIL_SIZE 0x88F1 +#define GL_TEXTURE_RED_TYPE 0x8C10 +#define GL_TEXTURE_GREEN_TYPE 0x8C11 +#define GL_TEXTURE_BLUE_TYPE 0x8C12 +#define GL_TEXTURE_ALPHA_TYPE 0x8C13 +#define GL_TEXTURE_DEPTH_TYPE 0x8C16 +#define GL_UNSIGNED_NORMALIZED 0x8C17 +#define GL_FRAMEBUFFER_BINDING 0x8CA6 +#define GL_DRAW_FRAMEBUFFER_BINDING GL_FRAMEBUFFER_BINDING +#define GL_RENDERBUFFER_BINDING 0x8CA7 +#define GL_READ_FRAMEBUFFER 0x8CA8 +#define GL_DRAW_FRAMEBUFFER 0x8CA9 +#define GL_READ_FRAMEBUFFER_BINDING 0x8CAA +#define GL_RENDERBUFFER_SAMPLES 0x8CAB +#define GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE 0x8CD0 +#define GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME 0x8CD1 +#define GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_LEVEL 0x8CD2 +#define GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_CUBE_MAP_FACE 0x8CD3 +#define GL_FRAMEBUFFER_ATTACHMENT_TEXTURE_LAYER 0x8CD4 +#define GL_FRAMEBUFFER_COMPLETE 0x8CD5 +#define GL_FRAMEBUFFER_INCOMPLETE_ATTACHMENT 
0x8CD6 +#define GL_FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT 0x8CD7 +#define GL_FRAMEBUFFER_INCOMPLETE_DRAW_BUFFER 0x8CDB +#define GL_FRAMEBUFFER_INCOMPLETE_READ_BUFFER 0x8CDC +#define GL_FRAMEBUFFER_UNSUPPORTED 0x8CDD +#define GL_MAX_COLOR_ATTACHMENTS 0x8CDF +#define GL_COLOR_ATTACHMENT0 0x8CE0 +#define GL_COLOR_ATTACHMENT1 0x8CE1 +#define GL_COLOR_ATTACHMENT2 0x8CE2 +#define GL_COLOR_ATTACHMENT3 0x8CE3 +#define GL_COLOR_ATTACHMENT4 0x8CE4 +#define GL_COLOR_ATTACHMENT5 0x8CE5 +#define GL_COLOR_ATTACHMENT6 0x8CE6 +#define GL_COLOR_ATTACHMENT7 0x8CE7 +#define GL_COLOR_ATTACHMENT8 0x8CE8 +#define GL_COLOR_ATTACHMENT9 0x8CE9 +#define GL_COLOR_ATTACHMENT10 0x8CEA +#define GL_COLOR_ATTACHMENT11 0x8CEB +#define GL_COLOR_ATTACHMENT12 0x8CEC +#define GL_COLOR_ATTACHMENT13 0x8CED +#define GL_COLOR_ATTACHMENT14 0x8CEE +#define GL_COLOR_ATTACHMENT15 0x8CEF +#define GL_DEPTH_ATTACHMENT 0x8D00 +#define GL_STENCIL_ATTACHMENT 0x8D20 +#define GL_FRAMEBUFFER 0x8D40 +#define GL_RENDERBUFFER 0x8D41 +#define GL_RENDERBUFFER_WIDTH 0x8D42 +#define GL_RENDERBUFFER_HEIGHT 0x8D43 +#define GL_RENDERBUFFER_INTERNAL_FORMAT 0x8D44 +#define GL_STENCIL_INDEX1 0x8D46 +#define GL_STENCIL_INDEX4 0x8D47 +#define GL_STENCIL_INDEX8 0x8D48 +#define GL_STENCIL_INDEX16 0x8D49 +#define GL_RENDERBUFFER_RED_SIZE 0x8D50 +#define GL_RENDERBUFFER_GREEN_SIZE 0x8D51 +#define GL_RENDERBUFFER_BLUE_SIZE 0x8D52 +#define GL_RENDERBUFFER_ALPHA_SIZE 0x8D53 +#define GL_RENDERBUFFER_DEPTH_SIZE 0x8D54 +#define GL_RENDERBUFFER_STENCIL_SIZE 0x8D55 +#define GL_FRAMEBUFFER_INCOMPLETE_MULTISAMPLE 0x8D56 +#define GL_MAX_SAMPLES 0x8D57 +#endif + +#ifndef GL_ARB_framebuffer_sRGB +#define GL_FRAMEBUFFER_SRGB 0x8DB9 +#endif + +#ifndef GL_ARB_half_float_vertex +#define GL_HALF_FLOAT 0x140B +#endif + +#ifndef GL_ARB_map_buffer_range +#define GL_MAP_READ_BIT 0x0001 +#define GL_MAP_WRITE_BIT 0x0002 +#define GL_MAP_INVALIDATE_RANGE_BIT 0x0004 +#define GL_MAP_INVALIDATE_BUFFER_BIT 0x0008 +#define GL_MAP_FLUSH_EXPLICIT_BIT 0x0010 +#define GL_MAP_UNSYNCHRONIZED_BIT 0x0020 +#endif + +#ifndef GL_ARB_texture_compression_rgtc +#define GL_COMPRESSED_RED_RGTC1 0x8DBB +#define GL_COMPRESSED_SIGNED_RED_RGTC1 0x8DBC +#define GL_COMPRESSED_RG_RGTC2 0x8DBD +#define GL_COMPRESSED_SIGNED_RG_RGTC2 0x8DBE +#endif + +#ifndef GL_ARB_texture_rg +#define GL_RG 0x8227 +#define GL_RG_INTEGER 0x8228 +#define GL_R8 0x8229 +#define GL_R16 0x822A +#define GL_RG8 0x822B +#define GL_RG16 0x822C +#define GL_R16F 0x822D +#define GL_R32F 0x822E +#define GL_RG16F 0x822F +#define GL_RG32F 0x8230 +#define GL_R8I 0x8231 +#define GL_R8UI 0x8232 +#define GL_R16I 0x8233 +#define GL_R16UI 0x8234 +#define GL_R32I 0x8235 +#define GL_R32UI 0x8236 +#define GL_RG8I 0x8237 +#define GL_RG8UI 0x8238 +#define GL_RG16I 0x8239 +#define GL_RG16UI 0x823A +#define GL_RG32I 0x823B +#define GL_RG32UI 0x823C +#endif + +#ifndef GL_ARB_vertex_array_object +#define GL_VERTEX_ARRAY_BINDING 0x85B5 +#endif + +#ifndef GL_ARB_uniform_buffer_object +#define GL_UNIFORM_BUFFER 0x8A11 +#define GL_UNIFORM_BUFFER_BINDING 0x8A28 +#define GL_UNIFORM_BUFFER_START 0x8A29 +#define GL_UNIFORM_BUFFER_SIZE 0x8A2A +#define GL_MAX_VERTEX_UNIFORM_BLOCKS 0x8A2B +#define GL_MAX_GEOMETRY_UNIFORM_BLOCKS 0x8A2C +#define GL_MAX_FRAGMENT_UNIFORM_BLOCKS 0x8A2D +#define GL_MAX_COMBINED_UNIFORM_BLOCKS 0x8A2E +#define GL_MAX_UNIFORM_BUFFER_BINDINGS 0x8A2F +#define GL_MAX_UNIFORM_BLOCK_SIZE 0x8A30 +#define GL_MAX_COMBINED_VERTEX_UNIFORM_COMPONENTS 0x8A31 +#define GL_MAX_COMBINED_GEOMETRY_UNIFORM_COMPONENTS 0x8A32 +#define 
GL_MAX_COMBINED_FRAGMENT_UNIFORM_COMPONENTS 0x8A33 +#define GL_UNIFORM_BUFFER_OFFSET_ALIGNMENT 0x8A34 +#define GL_ACTIVE_UNIFORM_BLOCK_MAX_NAME_LENGTH 0x8A35 +#define GL_ACTIVE_UNIFORM_BLOCKS 0x8A36 +#define GL_UNIFORM_TYPE 0x8A37 +#define GL_UNIFORM_SIZE 0x8A38 +#define GL_UNIFORM_NAME_LENGTH 0x8A39 +#define GL_UNIFORM_BLOCK_INDEX 0x8A3A +#define GL_UNIFORM_OFFSET 0x8A3B +#define GL_UNIFORM_ARRAY_STRIDE 0x8A3C +#define GL_UNIFORM_MATRIX_STRIDE 0x8A3D +#define GL_UNIFORM_IS_ROW_MAJOR 0x8A3E +#define GL_UNIFORM_BLOCK_BINDING 0x8A3F +#define GL_UNIFORM_BLOCK_DATA_SIZE 0x8A40 +#define GL_UNIFORM_BLOCK_NAME_LENGTH 0x8A41 +#define GL_UNIFORM_BLOCK_ACTIVE_UNIFORMS 0x8A42 +#define GL_UNIFORM_BLOCK_ACTIVE_UNIFORM_INDICES 0x8A43 +#define GL_UNIFORM_BLOCK_REFERENCED_BY_VERTEX_SHADER 0x8A44 +#define GL_UNIFORM_BLOCK_REFERENCED_BY_GEOMETRY_SHADER 0x8A45 +#define GL_UNIFORM_BLOCK_REFERENCED_BY_FRAGMENT_SHADER 0x8A46 +#define GL_INVALID_INDEX 0xFFFFFFFFu +#endif + +#ifndef GL_ARB_copy_buffer +#define GL_COPY_READ_BUFFER_BINDING 0x8F36 +#define GL_COPY_READ_BUFFER GL_COPY_READ_BUFFER_BINDING +#define GL_COPY_WRITE_BUFFER_BINDING 0x8F37 +#define GL_COPY_WRITE_BUFFER GL_COPY_WRITE_BUFFER_BINDING +#endif + +#ifndef GL_ARB_depth_clamp +#define GL_DEPTH_CLAMP 0x864F +#endif + +#ifndef GL_ARB_draw_elements_base_vertex +#endif + +#ifndef GL_ARB_fragment_coord_conventions +#endif + +#ifndef GL_ARB_provoking_vertex +#define GL_QUADS_FOLLOW_PROVOKING_VERTEX_CONVENTION 0x8E4C +#define GL_FIRST_VERTEX_CONVENTION 0x8E4D +#define GL_LAST_VERTEX_CONVENTION 0x8E4E +#define GL_PROVOKING_VERTEX 0x8E4F +#endif + +#ifndef GL_ARB_seamless_cube_map +#define GL_TEXTURE_CUBE_MAP_SEAMLESS 0x884F +#endif + +#ifndef GL_ARB_sync +#define GL_MAX_SERVER_WAIT_TIMEOUT 0x9111 +#define GL_OBJECT_TYPE 0x9112 +#define GL_SYNC_CONDITION 0x9113 +#define GL_SYNC_STATUS 0x9114 +#define GL_SYNC_FLAGS 0x9115 +#define GL_SYNC_FENCE 0x9116 +#define GL_SYNC_GPU_COMMANDS_COMPLETE 0x9117 +#define GL_UNSIGNALED 0x9118 +#define GL_SIGNALED 0x9119 +#define GL_ALREADY_SIGNALED 0x911A +#define GL_TIMEOUT_EXPIRED 0x911B +#define GL_CONDITION_SATISFIED 0x911C +#define GL_WAIT_FAILED 0x911D +#define GL_SYNC_FLUSH_COMMANDS_BIT 0x00000001 +#define GL_TIMEOUT_IGNORED 0xFFFFFFFFFFFFFFFFull +#endif + +#ifndef GL_ARB_texture_multisample +#define GL_SAMPLE_POSITION 0x8E50 +#define GL_SAMPLE_MASK 0x8E51 +#define GL_SAMPLE_MASK_VALUE 0x8E52 +#define GL_MAX_SAMPLE_MASK_WORDS 0x8E59 +#define GL_TEXTURE_2D_MULTISAMPLE 0x9100 +#define GL_PROXY_TEXTURE_2D_MULTISAMPLE 0x9101 +#define GL_TEXTURE_2D_MULTISAMPLE_ARRAY 0x9102 +#define GL_PROXY_TEXTURE_2D_MULTISAMPLE_ARRAY 0x9103 +#define GL_TEXTURE_BINDING_2D_MULTISAMPLE 0x9104 +#define GL_TEXTURE_BINDING_2D_MULTISAMPLE_ARRAY 0x9105 +#define GL_TEXTURE_SAMPLES 0x9106 +#define GL_TEXTURE_FIXED_SAMPLE_LOCATIONS 0x9107 +#define GL_SAMPLER_2D_MULTISAMPLE 0x9108 +#define GL_INT_SAMPLER_2D_MULTISAMPLE 0x9109 +#define GL_UNSIGNED_INT_SAMPLER_2D_MULTISAMPLE 0x910A +#define GL_SAMPLER_2D_MULTISAMPLE_ARRAY 0x910B +#define GL_INT_SAMPLER_2D_MULTISAMPLE_ARRAY 0x910C +#define GL_UNSIGNED_INT_SAMPLER_2D_MULTISAMPLE_ARRAY 0x910D +#define GL_MAX_COLOR_TEXTURE_SAMPLES 0x910E +#define GL_MAX_DEPTH_TEXTURE_SAMPLES 0x910F +#define GL_MAX_INTEGER_SAMPLES 0x9110 +#endif + +#ifndef GL_ARB_vertex_array_bgra +/* reuse GL_BGRA */ +#endif + +#ifndef GL_ARB_draw_buffers_blend +#endif + +#ifndef GL_ARB_sample_shading +#define GL_SAMPLE_SHADING_ARB 0x8C36 +#define GL_MIN_SAMPLE_SHADING_VALUE_ARB 0x8C37 +#endif + +#ifndef GL_ARB_texture_cube_map_array 
+#define GL_TEXTURE_CUBE_MAP_ARRAY_ARB 0x9009 +#define GL_TEXTURE_BINDING_CUBE_MAP_ARRAY_ARB 0x900A +#define GL_PROXY_TEXTURE_CUBE_MAP_ARRAY_ARB 0x900B +#define GL_SAMPLER_CUBE_MAP_ARRAY_ARB 0x900C +#define GL_SAMPLER_CUBE_MAP_ARRAY_SHADOW_ARB 0x900D +#define GL_INT_SAMPLER_CUBE_MAP_ARRAY_ARB 0x900E +#define GL_UNSIGNED_INT_SAMPLER_CUBE_MAP_ARRAY_ARB 0x900F +#endif + +#ifndef GL_ARB_texture_gather +#define GL_MIN_PROGRAM_TEXTURE_GATHER_OFFSET_ARB 0x8E5E +#define GL_MAX_PROGRAM_TEXTURE_GATHER_OFFSET_ARB 0x8E5F +#define GL_MAX_PROGRAM_TEXTURE_GATHER_COMPONENTS_ARB 0x8F9F +#endif + +#ifndef GL_ARB_texture_query_lod +#endif + +#ifndef GL_ARB_shading_language_include +#define GL_SHADER_INCLUDE_ARB 0x8DAE +#define GL_NAMED_STRING_LENGTH_ARB 0x8DE9 +#define GL_NAMED_STRING_TYPE_ARB 0x8DEA +#endif + +#ifndef GL_ARB_texture_compression_bptc +#define GL_COMPRESSED_RGBA_BPTC_UNORM_ARB 0x8E8C +#define GL_COMPRESSED_SRGB_ALPHA_BPTC_UNORM_ARB 0x8E8D +#define GL_COMPRESSED_RGB_BPTC_SIGNED_FLOAT_ARB 0x8E8E +#define GL_COMPRESSED_RGB_BPTC_UNSIGNED_FLOAT_ARB 0x8E8F +#endif + +#ifndef GL_ARB_blend_func_extended +#define GL_SRC1_COLOR 0x88F9 +/* reuse GL_SRC1_ALPHA */ +#define GL_ONE_MINUS_SRC1_COLOR 0x88FA +#define GL_ONE_MINUS_SRC1_ALPHA 0x88FB +#define GL_MAX_DUAL_SOURCE_DRAW_BUFFERS 0x88FC +#endif + +#ifndef GL_ARB_explicit_attrib_location +#endif + +#ifndef GL_ARB_occlusion_query2 +#define GL_ANY_SAMPLES_PASSED 0x8C2F +#endif + +#ifndef GL_ARB_sampler_objects +#define GL_SAMPLER_BINDING 0x8919 +#endif + +#ifndef GL_ARB_shader_bit_encoding +#endif + +#ifndef GL_ARB_texture_rgb10_a2ui +#define GL_RGB10_A2UI 0x906F +#endif + +#ifndef GL_ARB_texture_swizzle +#define GL_TEXTURE_SWIZZLE_R 0x8E42 +#define GL_TEXTURE_SWIZZLE_G 0x8E43 +#define GL_TEXTURE_SWIZZLE_B 0x8E44 +#define GL_TEXTURE_SWIZZLE_A 0x8E45 +#define GL_TEXTURE_SWIZZLE_RGBA 0x8E46 +#endif + +#ifndef GL_ARB_timer_query +#define GL_TIME_ELAPSED 0x88BF +#define GL_TIMESTAMP 0x8E28 +#endif + +#ifndef GL_ARB_vertex_type_2_10_10_10_rev +/* reuse GL_UNSIGNED_INT_2_10_10_10_REV */ +#define GL_INT_2_10_10_10_REV 0x8D9F +#endif + +#ifndef GL_ARB_draw_indirect +#define GL_DRAW_INDIRECT_BUFFER 0x8F3F +#define GL_DRAW_INDIRECT_BUFFER_BINDING 0x8F43 +#endif + +#ifndef GL_ARB_gpu_shader5 +#define GL_GEOMETRY_SHADER_INVOCATIONS 0x887F +#define GL_MAX_GEOMETRY_SHADER_INVOCATIONS 0x8E5A +#define GL_MIN_FRAGMENT_INTERPOLATION_OFFSET 0x8E5B +#define GL_MAX_FRAGMENT_INTERPOLATION_OFFSET 0x8E5C +#define GL_FRAGMENT_INTERPOLATION_OFFSET_BITS 0x8E5D +/* reuse GL_MAX_VERTEX_STREAMS */ +#endif + +#ifndef GL_ARB_gpu_shader_fp64 +/* reuse GL_DOUBLE */ +#define GL_DOUBLE_VEC2 0x8FFC +#define GL_DOUBLE_VEC3 0x8FFD +#define GL_DOUBLE_VEC4 0x8FFE +#define GL_DOUBLE_MAT2 0x8F46 +#define GL_DOUBLE_MAT3 0x8F47 +#define GL_DOUBLE_MAT4 0x8F48 +#define GL_DOUBLE_MAT2x3 0x8F49 +#define GL_DOUBLE_MAT2x4 0x8F4A +#define GL_DOUBLE_MAT3x2 0x8F4B +#define GL_DOUBLE_MAT3x4 0x8F4C +#define GL_DOUBLE_MAT4x2 0x8F4D +#define GL_DOUBLE_MAT4x3 0x8F4E +#endif + +#ifndef GL_ARB_shader_subroutine +#define GL_ACTIVE_SUBROUTINES 0x8DE5 +#define GL_ACTIVE_SUBROUTINE_UNIFORMS 0x8DE6 +#define GL_ACTIVE_SUBROUTINE_UNIFORM_LOCATIONS 0x8E47 +#define GL_ACTIVE_SUBROUTINE_MAX_LENGTH 0x8E48 +#define GL_ACTIVE_SUBROUTINE_UNIFORM_MAX_LENGTH 0x8E49 +#define GL_MAX_SUBROUTINES 0x8DE7 +#define GL_MAX_SUBROUTINE_UNIFORM_LOCATIONS 0x8DE8 +#define GL_NUM_COMPATIBLE_SUBROUTINES 0x8E4A +#define GL_COMPATIBLE_SUBROUTINES 0x8E4B +/* reuse GL_UNIFORM_SIZE */ +/* reuse GL_UNIFORM_NAME_LENGTH */ +#endif + +#ifndef 
GL_ARB_tessellation_shader +#define GL_PATCHES 0x000E +#define GL_PATCH_VERTICES 0x8E72 +#define GL_PATCH_DEFAULT_INNER_LEVEL 0x8E73 +#define GL_PATCH_DEFAULT_OUTER_LEVEL 0x8E74 +#define GL_TESS_CONTROL_OUTPUT_VERTICES 0x8E75 +#define GL_TESS_GEN_MODE 0x8E76 +#define GL_TESS_GEN_SPACING 0x8E77 +#define GL_TESS_GEN_VERTEX_ORDER 0x8E78 +#define GL_TESS_GEN_POINT_MODE 0x8E79 +/* reuse GL_TRIANGLES */ +/* reuse GL_QUADS */ +#define GL_ISOLINES 0x8E7A +/* reuse GL_EQUAL */ +#define GL_FRACTIONAL_ODD 0x8E7B +#define GL_FRACTIONAL_EVEN 0x8E7C +/* reuse GL_CCW */ +/* reuse GL_CW */ +#define GL_MAX_PATCH_VERTICES 0x8E7D +#define GL_MAX_TESS_GEN_LEVEL 0x8E7E +#define GL_MAX_TESS_CONTROL_UNIFORM_COMPONENTS 0x8E7F +#define GL_MAX_TESS_EVALUATION_UNIFORM_COMPONENTS 0x8E80 +#define GL_MAX_TESS_CONTROL_TEXTURE_IMAGE_UNITS 0x8E81 +#define GL_MAX_TESS_EVALUATION_TEXTURE_IMAGE_UNITS 0x8E82 +#define GL_MAX_TESS_CONTROL_OUTPUT_COMPONENTS 0x8E83 +#define GL_MAX_TESS_PATCH_COMPONENTS 0x8E84 +#define GL_MAX_TESS_CONTROL_TOTAL_OUTPUT_COMPONENTS 0x8E85 +#define GL_MAX_TESS_EVALUATION_OUTPUT_COMPONENTS 0x8E86 +#define GL_MAX_TESS_CONTROL_UNIFORM_BLOCKS 0x8E89 +#define GL_MAX_TESS_EVALUATION_UNIFORM_BLOCKS 0x8E8A +#define GL_MAX_TESS_CONTROL_INPUT_COMPONENTS 0x886C +#define GL_MAX_TESS_EVALUATION_INPUT_COMPONENTS 0x886D +#define GL_MAX_COMBINED_TESS_CONTROL_UNIFORM_COMPONENTS 0x8E1E +#define GL_MAX_COMBINED_TESS_EVALUATION_UNIFORM_COMPONENTS 0x8E1F +#define GL_UNIFORM_BLOCK_REFERENCED_BY_TESS_CONTROL_SHADER 0x84F0 +#define GL_UNIFORM_BLOCK_REFERENCED_BY_TESS_EVALUATION_SHADER 0x84F1 +#define GL_TESS_EVALUATION_SHADER 0x8E87 +#define GL_TESS_CONTROL_SHADER 0x8E88 +#endif + +#ifndef GL_ARB_texture_buffer_object_rgb32 +/* reuse GL_RGB32F */ +/* reuse GL_RGB32UI */ +/* reuse GL_RGB32I */ +#endif + +#ifndef GL_ARB_transform_feedback2 +#define GL_TRANSFORM_FEEDBACK 0x8E22 +#define GL_TRANSFORM_FEEDBACK_PAUSED 0x8E23 +#define GL_TRANSFORM_FEEDBACK_BUFFER_PAUSED GL_TRANSFORM_FEEDBACK_PAUSED +#define GL_TRANSFORM_FEEDBACK_ACTIVE 0x8E24 +#define GL_TRANSFORM_FEEDBACK_BUFFER_ACTIVE GL_TRANSFORM_FEEDBACK_ACTIVE +#define GL_TRANSFORM_FEEDBACK_BINDING 0x8E25 +#endif + +#ifndef GL_ARB_transform_feedback3 +#define GL_MAX_TRANSFORM_FEEDBACK_BUFFERS 0x8E70 +#define GL_MAX_VERTEX_STREAMS 0x8E71 +#endif + +#ifndef GL_ARB_ES2_compatibility +#define GL_FIXED 0x140C +#define GL_IMPLEMENTATION_COLOR_READ_TYPE 0x8B9A +#define GL_IMPLEMENTATION_COLOR_READ_FORMAT 0x8B9B +#define GL_LOW_FLOAT 0x8DF0 +#define GL_MEDIUM_FLOAT 0x8DF1 +#define GL_HIGH_FLOAT 0x8DF2 +#define GL_LOW_INT 0x8DF3 +#define GL_MEDIUM_INT 0x8DF4 +#define GL_HIGH_INT 0x8DF5 +#define GL_SHADER_COMPILER 0x8DFA +#define GL_SHADER_BINARY_FORMATS 0x8DF8 +#define GL_NUM_SHADER_BINARY_FORMATS 0x8DF9 +#define GL_MAX_VERTEX_UNIFORM_VECTORS 0x8DFB +#define GL_MAX_VARYING_VECTORS 0x8DFC +#define GL_MAX_FRAGMENT_UNIFORM_VECTORS 0x8DFD +#define GL_RGB565 0x8D62 +#endif + +#ifndef GL_ARB_get_program_binary +#define GL_PROGRAM_BINARY_RETRIEVABLE_HINT 0x8257 +#define GL_PROGRAM_BINARY_LENGTH 0x8741 +#define GL_NUM_PROGRAM_BINARY_FORMATS 0x87FE +#define GL_PROGRAM_BINARY_FORMATS 0x87FF +#endif + +#ifndef GL_ARB_separate_shader_objects +#define GL_VERTEX_SHADER_BIT 0x00000001 +#define GL_FRAGMENT_SHADER_BIT 0x00000002 +#define GL_GEOMETRY_SHADER_BIT 0x00000004 +#define GL_TESS_CONTROL_SHADER_BIT 0x00000008 +#define GL_TESS_EVALUATION_SHADER_BIT 0x00000010 +#define GL_ALL_SHADER_BITS 0xFFFFFFFF +#define GL_PROGRAM_SEPARABLE 0x8258 +#define GL_ACTIVE_PROGRAM 0x8259 +#define 
GL_PROGRAM_PIPELINE_BINDING 0x825A +#endif + +#ifndef GL_ARB_shader_precision +#endif + +#ifndef GL_ARB_vertex_attrib_64bit +/* reuse GL_RGB32I */ +/* reuse GL_DOUBLE_VEC2 */ +/* reuse GL_DOUBLE_VEC3 */ +/* reuse GL_DOUBLE_VEC4 */ +/* reuse GL_DOUBLE_MAT2 */ +/* reuse GL_DOUBLE_MAT3 */ +/* reuse GL_DOUBLE_MAT4 */ +/* reuse GL_DOUBLE_MAT2x3 */ +/* reuse GL_DOUBLE_MAT2x4 */ +/* reuse GL_DOUBLE_MAT3x2 */ +/* reuse GL_DOUBLE_MAT3x4 */ +/* reuse GL_DOUBLE_MAT4x2 */ +/* reuse GL_DOUBLE_MAT4x3 */ +#endif + +#ifndef GL_ARB_viewport_array +/* reuse GL_SCISSOR_BOX */ +/* reuse GL_VIEWPORT */ +/* reuse GL_DEPTH_RANGE */ +/* reuse GL_SCISSOR_TEST */ +#define GL_MAX_VIEWPORTS 0x825B +#define GL_VIEWPORT_SUBPIXEL_BITS 0x825C +#define GL_VIEWPORT_BOUNDS_RANGE 0x825D +#define GL_LAYER_PROVOKING_VERTEX 0x825E +#define GL_VIEWPORT_INDEX_PROVOKING_VERTEX 0x825F +#define GL_UNDEFINED_VERTEX 0x8260 +/* reuse GL_FIRST_VERTEX_CONVENTION */ +/* reuse GL_LAST_VERTEX_CONVENTION */ +/* reuse GL_PROVOKING_VERTEX */ +#endif + +#ifndef GL_ARB_cl_event +#define GL_SYNC_CL_EVENT_ARB 0x8240 +#define GL_SYNC_CL_EVENT_COMPLETE_ARB 0x8241 +#endif + +#ifndef GL_ARB_debug_output +#define GL_DEBUG_OUTPUT_SYNCHRONOUS_ARB 0x8242 +#define GL_DEBUG_NEXT_LOGGED_MESSAGE_LENGTH_ARB 0x8243 +#define GL_DEBUG_CALLBACK_FUNCTION_ARB 0x8244 +#define GL_DEBUG_CALLBACK_USER_PARAM_ARB 0x8245 +#define GL_DEBUG_SOURCE_API_ARB 0x8246 +#define GL_DEBUG_SOURCE_WINDOW_SYSTEM_ARB 0x8247 +#define GL_DEBUG_SOURCE_SHADER_COMPILER_ARB 0x8248 +#define GL_DEBUG_SOURCE_THIRD_PARTY_ARB 0x8249 +#define GL_DEBUG_SOURCE_APPLICATION_ARB 0x824A +#define GL_DEBUG_SOURCE_OTHER_ARB 0x824B +#define GL_DEBUG_TYPE_ERROR_ARB 0x824C +#define GL_DEBUG_TYPE_DEPRECATED_BEHAVIOR_ARB 0x824D +#define GL_DEBUG_TYPE_UNDEFINED_BEHAVIOR_ARB 0x824E +#define GL_DEBUG_TYPE_PORTABILITY_ARB 0x824F +#define GL_DEBUG_TYPE_PERFORMANCE_ARB 0x8250 +#define GL_DEBUG_TYPE_OTHER_ARB 0x8251 +#define GL_MAX_DEBUG_MESSAGE_LENGTH_ARB 0x9143 +#define GL_MAX_DEBUG_LOGGED_MESSAGES_ARB 0x9144 +#define GL_DEBUG_LOGGED_MESSAGES_ARB 0x9145 +#define GL_DEBUG_SEVERITY_HIGH_ARB 0x9146 +#define GL_DEBUG_SEVERITY_MEDIUM_ARB 0x9147 +#define GL_DEBUG_SEVERITY_LOW_ARB 0x9148 +#endif + +#ifndef GL_ARB_robustness +/* reuse GL_NO_ERROR */ +#define GL_CONTEXT_FLAG_ROBUST_ACCESS_BIT_ARB 0x00000004 +#define GL_LOSE_CONTEXT_ON_RESET_ARB 0x8252 +#define GL_GUILTY_CONTEXT_RESET_ARB 0x8253 +#define GL_INNOCENT_CONTEXT_RESET_ARB 0x8254 +#define GL_UNKNOWN_CONTEXT_RESET_ARB 0x8255 +#define GL_RESET_NOTIFICATION_STRATEGY_ARB 0x8256 +#define GL_NO_RESET_NOTIFICATION_ARB 0x8261 +#endif + +#ifndef GL_ARB_shader_stencil_export +#endif + +#ifndef GL_ARB_base_instance +#endif + +#ifndef GL_ARB_shading_language_420pack +#endif + +#ifndef GL_ARB_transform_feedback_instanced +#endif + +#ifndef GL_ARB_compressed_texture_pixel_storage +#define GL_UNPACK_COMPRESSED_BLOCK_WIDTH 0x9127 +#define GL_UNPACK_COMPRESSED_BLOCK_HEIGHT 0x9128 +#define GL_UNPACK_COMPRESSED_BLOCK_DEPTH 0x9129 +#define GL_UNPACK_COMPRESSED_BLOCK_SIZE 0x912A +#define GL_PACK_COMPRESSED_BLOCK_WIDTH 0x912B +#define GL_PACK_COMPRESSED_BLOCK_HEIGHT 0x912C +#define GL_PACK_COMPRESSED_BLOCK_DEPTH 0x912D +#define GL_PACK_COMPRESSED_BLOCK_SIZE 0x912E +#endif + +#ifndef GL_ARB_conservative_depth +#endif + +#ifndef GL_ARB_internalformat_query +#define GL_NUM_SAMPLE_COUNTS 0x9380 +#endif + +#ifndef GL_ARB_map_buffer_alignment +#define GL_MIN_MAP_BUFFER_ALIGNMENT 0x90BC +#endif + +#ifndef GL_ARB_shader_atomic_counters +#define GL_ATOMIC_COUNTER_BUFFER 0x92C0 +#define 
GL_ATOMIC_COUNTER_BUFFER_BINDING 0x92C1 +#define GL_ATOMIC_COUNTER_BUFFER_START 0x92C2 +#define GL_ATOMIC_COUNTER_BUFFER_SIZE 0x92C3 +#define GL_ATOMIC_COUNTER_BUFFER_DATA_SIZE 0x92C4 +#define GL_ATOMIC_COUNTER_BUFFER_ACTIVE_ATOMIC_COUNTERS 0x92C5 +#define GL_ATOMIC_COUNTER_BUFFER_ACTIVE_ATOMIC_COUNTER_INDICES 0x92C6 +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_VERTEX_SHADER 0x92C7 +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_TESS_CONTROL_SHADER 0x92C8 +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_TESS_EVALUATION_SHADER 0x92C9 +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_GEOMETRY_SHADER 0x92CA +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_FRAGMENT_SHADER 0x92CB +#define GL_MAX_VERTEX_ATOMIC_COUNTER_BUFFERS 0x92CC +#define GL_MAX_TESS_CONTROL_ATOMIC_COUNTER_BUFFERS 0x92CD +#define GL_MAX_TESS_EVALUATION_ATOMIC_COUNTER_BUFFERS 0x92CE +#define GL_MAX_GEOMETRY_ATOMIC_COUNTER_BUFFERS 0x92CF +#define GL_MAX_FRAGMENT_ATOMIC_COUNTER_BUFFERS 0x92D0 +#define GL_MAX_COMBINED_ATOMIC_COUNTER_BUFFERS 0x92D1 +#define GL_MAX_VERTEX_ATOMIC_COUNTERS 0x92D2 +#define GL_MAX_TESS_CONTROL_ATOMIC_COUNTERS 0x92D3 +#define GL_MAX_TESS_EVALUATION_ATOMIC_COUNTERS 0x92D4 +#define GL_MAX_GEOMETRY_ATOMIC_COUNTERS 0x92D5 +#define GL_MAX_FRAGMENT_ATOMIC_COUNTERS 0x92D6 +#define GL_MAX_COMBINED_ATOMIC_COUNTERS 0x92D7 +#define GL_MAX_ATOMIC_COUNTER_BUFFER_SIZE 0x92D8 +#define GL_MAX_ATOMIC_COUNTER_BUFFER_BINDINGS 0x92DC +#define GL_ACTIVE_ATOMIC_COUNTER_BUFFERS 0x92D9 +#define GL_UNIFORM_ATOMIC_COUNTER_BUFFER_INDEX 0x92DA +#define GL_UNSIGNED_INT_ATOMIC_COUNTER 0x92DB +#endif + +#ifndef GL_ARB_shader_image_load_store +#define GL_VERTEX_ATTRIB_ARRAY_BARRIER_BIT 0x00000001 +#define GL_ELEMENT_ARRAY_BARRIER_BIT 0x00000002 +#define GL_UNIFORM_BARRIER_BIT 0x00000004 +#define GL_TEXTURE_FETCH_BARRIER_BIT 0x00000008 +#define GL_SHADER_IMAGE_ACCESS_BARRIER_BIT 0x00000020 +#define GL_COMMAND_BARRIER_BIT 0x00000040 +#define GL_PIXEL_BUFFER_BARRIER_BIT 0x00000080 +#define GL_TEXTURE_UPDATE_BARRIER_BIT 0x00000100 +#define GL_BUFFER_UPDATE_BARRIER_BIT 0x00000200 +#define GL_FRAMEBUFFER_BARRIER_BIT 0x00000400 +#define GL_TRANSFORM_FEEDBACK_BARRIER_BIT 0x00000800 +#define GL_ATOMIC_COUNTER_BARRIER_BIT 0x00001000 +#define GL_ALL_BARRIER_BITS 0xFFFFFFFF +#define GL_MAX_IMAGE_UNITS 0x8F38 +#define GL_MAX_COMBINED_IMAGE_UNITS_AND_FRAGMENT_OUTPUTS 0x8F39 +#define GL_IMAGE_BINDING_NAME 0x8F3A +#define GL_IMAGE_BINDING_LEVEL 0x8F3B +#define GL_IMAGE_BINDING_LAYERED 0x8F3C +#define GL_IMAGE_BINDING_LAYER 0x8F3D +#define GL_IMAGE_BINDING_ACCESS 0x8F3E +#define GL_IMAGE_1D 0x904C +#define GL_IMAGE_2D 0x904D +#define GL_IMAGE_3D 0x904E +#define GL_IMAGE_2D_RECT 0x904F +#define GL_IMAGE_CUBE 0x9050 +#define GL_IMAGE_BUFFER 0x9051 +#define GL_IMAGE_1D_ARRAY 0x9052 +#define GL_IMAGE_2D_ARRAY 0x9053 +#define GL_IMAGE_CUBE_MAP_ARRAY 0x9054 +#define GL_IMAGE_2D_MULTISAMPLE 0x9055 +#define GL_IMAGE_2D_MULTISAMPLE_ARRAY 0x9056 +#define GL_INT_IMAGE_1D 0x9057 +#define GL_INT_IMAGE_2D 0x9058 +#define GL_INT_IMAGE_3D 0x9059 +#define GL_INT_IMAGE_2D_RECT 0x905A +#define GL_INT_IMAGE_CUBE 0x905B +#define GL_INT_IMAGE_BUFFER 0x905C +#define GL_INT_IMAGE_1D_ARRAY 0x905D +#define GL_INT_IMAGE_2D_ARRAY 0x905E +#define GL_INT_IMAGE_CUBE_MAP_ARRAY 0x905F +#define GL_INT_IMAGE_2D_MULTISAMPLE 0x9060 +#define GL_INT_IMAGE_2D_MULTISAMPLE_ARRAY 0x9061 +#define GL_UNSIGNED_INT_IMAGE_1D 0x9062 +#define GL_UNSIGNED_INT_IMAGE_2D 0x9063 +#define GL_UNSIGNED_INT_IMAGE_3D 0x9064 +#define GL_UNSIGNED_INT_IMAGE_2D_RECT 0x9065 +#define GL_UNSIGNED_INT_IMAGE_CUBE 0x9066 
+#define GL_UNSIGNED_INT_IMAGE_BUFFER 0x9067 +#define GL_UNSIGNED_INT_IMAGE_1D_ARRAY 0x9068 +#define GL_UNSIGNED_INT_IMAGE_2D_ARRAY 0x9069 +#define GL_UNSIGNED_INT_IMAGE_CUBE_MAP_ARRAY 0x906A +#define GL_UNSIGNED_INT_IMAGE_2D_MULTISAMPLE 0x906B +#define GL_UNSIGNED_INT_IMAGE_2D_MULTISAMPLE_ARRAY 0x906C +#define GL_MAX_IMAGE_SAMPLES 0x906D +#define GL_IMAGE_BINDING_FORMAT 0x906E +#define GL_IMAGE_FORMAT_COMPATIBILITY_TYPE 0x90C7 +#define GL_IMAGE_FORMAT_COMPATIBILITY_BY_SIZE 0x90C8 +#define GL_IMAGE_FORMAT_COMPATIBILITY_BY_CLASS 0x90C9 +#define GL_MAX_VERTEX_IMAGE_UNIFORMS 0x90CA +#define GL_MAX_TESS_CONTROL_IMAGE_UNIFORMS 0x90CB +#define GL_MAX_TESS_EVALUATION_IMAGE_UNIFORMS 0x90CC +#define GL_MAX_GEOMETRY_IMAGE_UNIFORMS 0x90CD +#define GL_MAX_FRAGMENT_IMAGE_UNIFORMS 0x90CE +#define GL_MAX_COMBINED_IMAGE_UNIFORMS 0x90CF +#endif + +#ifndef GL_ARB_shading_language_packing +#endif + +#ifndef GL_ARB_texture_storage +#define GL_TEXTURE_IMMUTABLE_FORMAT 0x912F +#endif + +#ifndef GL_KHR_texture_compression_astc_ldr +#define GL_COMPRESSED_RGBA_ASTC_4x4_KHR 0x93B0 +#define GL_COMPRESSED_RGBA_ASTC_5x4_KHR 0x93B1 +#define GL_COMPRESSED_RGBA_ASTC_5x5_KHR 0x93B2 +#define GL_COMPRESSED_RGBA_ASTC_6x5_KHR 0x93B3 +#define GL_COMPRESSED_RGBA_ASTC_6x6_KHR 0x93B4 +#define GL_COMPRESSED_RGBA_ASTC_8x5_KHR 0x93B5 +#define GL_COMPRESSED_RGBA_ASTC_8x6_KHR 0x93B6 +#define GL_COMPRESSED_RGBA_ASTC_8x8_KHR 0x93B7 +#define GL_COMPRESSED_RGBA_ASTC_10x5_KHR 0x93B8 +#define GL_COMPRESSED_RGBA_ASTC_10x6_KHR 0x93B9 +#define GL_COMPRESSED_RGBA_ASTC_10x8_KHR 0x93BA +#define GL_COMPRESSED_RGBA_ASTC_10x10_KHR 0x93BB +#define GL_COMPRESSED_RGBA_ASTC_12x10_KHR 0x93BC +#define GL_COMPRESSED_RGBA_ASTC_12x12_KHR 0x93BD +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_4x4_KHR 0x93D0 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_5x4_KHR 0x93D1 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_5x5_KHR 0x93D2 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_6x5_KHR 0x93D3 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_6x6_KHR 0x93D4 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_8x5_KHR 0x93D5 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_8x6_KHR 0x93D6 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_8x8_KHR 0x93D7 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x5_KHR 0x93D8 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x6_KHR 0x93D9 +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x8_KHR 0x93DA +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x10_KHR 0x93DB +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_12x10_KHR 0x93DC +#define GL_COMPRESSED_SRGB8_ALPHA8_ASTC_12x12_KHR 0x93DD +#endif + +#ifndef GL_KHR_debug +#define GL_DEBUG_OUTPUT_SYNCHRONOUS 0x8242 +#define GL_DEBUG_NEXT_LOGGED_MESSAGE_LENGTH 0x8243 +#define GL_DEBUG_CALLBACK_FUNCTION 0x8244 +#define GL_DEBUG_CALLBACK_USER_PARAM 0x8245 +#define GL_DEBUG_SOURCE_API 0x8246 +#define GL_DEBUG_SOURCE_WINDOW_SYSTEM 0x8247 +#define GL_DEBUG_SOURCE_SHADER_COMPILER 0x8248 +#define GL_DEBUG_SOURCE_THIRD_PARTY 0x8249 +#define GL_DEBUG_SOURCE_APPLICATION 0x824A +#define GL_DEBUG_SOURCE_OTHER 0x824B +#define GL_DEBUG_TYPE_ERROR 0x824C +#define GL_DEBUG_TYPE_DEPRECATED_BEHAVIOR 0x824D +#define GL_DEBUG_TYPE_UNDEFINED_BEHAVIOR 0x824E +#define GL_DEBUG_TYPE_PORTABILITY 0x824F +#define GL_DEBUG_TYPE_PERFORMANCE 0x8250 +#define GL_DEBUG_TYPE_OTHER 0x8251 +#define GL_DEBUG_TYPE_MARKER 0x8268 +#define GL_DEBUG_TYPE_PUSH_GROUP 0x8269 +#define GL_DEBUG_TYPE_POP_GROUP 0x826A +#define GL_DEBUG_SEVERITY_NOTIFICATION 0x826B +#define GL_MAX_DEBUG_GROUP_STACK_DEPTH 0x826C +#define GL_DEBUG_GROUP_STACK_DEPTH 0x826D +#define GL_BUFFER 0x82E0 +#define GL_SHADER 0x82E1 +#define GL_PROGRAM 
0x82E2 +#define GL_QUERY 0x82E3 +#define GL_PROGRAM_PIPELINE 0x82E4 +#define GL_SAMPLER 0x82E6 +#define GL_DISPLAY_LIST 0x82E7 +/* DISPLAY_LIST used in compatibility profile only */ +#define GL_MAX_LABEL_LENGTH 0x82E8 +#define GL_MAX_DEBUG_MESSAGE_LENGTH 0x9143 +#define GL_MAX_DEBUG_LOGGED_MESSAGES 0x9144 +#define GL_DEBUG_LOGGED_MESSAGES 0x9145 +#define GL_DEBUG_SEVERITY_HIGH 0x9146 +#define GL_DEBUG_SEVERITY_MEDIUM 0x9147 +#define GL_DEBUG_SEVERITY_LOW 0x9148 +#define GL_DEBUG_OUTPUT 0x92E0 +#define GL_CONTEXT_FLAG_DEBUG_BIT 0x00000002 +/* reuse GL_STACK_UNDERFLOW */ +/* reuse GL_STACK_OVERFLOW */ +#endif + +#ifndef GL_ARB_arrays_of_arrays +#endif + +#ifndef GL_ARB_clear_buffer_object +#endif + +#ifndef GL_ARB_compute_shader +#define GL_COMPUTE_SHADER 0x91B9 +#define GL_MAX_COMPUTE_UNIFORM_BLOCKS 0x91BB +#define GL_MAX_COMPUTE_TEXTURE_IMAGE_UNITS 0x91BC +#define GL_MAX_COMPUTE_IMAGE_UNIFORMS 0x91BD +#define GL_MAX_COMPUTE_SHARED_MEMORY_SIZE 0x8262 +#define GL_MAX_COMPUTE_UNIFORM_COMPONENTS 0x8263 +#define GL_MAX_COMPUTE_ATOMIC_COUNTER_BUFFERS 0x8264 +#define GL_MAX_COMPUTE_ATOMIC_COUNTERS 0x8265 +#define GL_MAX_COMBINED_COMPUTE_UNIFORM_COMPONENTS 0x8266 +#define GL_MAX_COMPUTE_LOCAL_INVOCATIONS 0x90EB +#define GL_MAX_COMPUTE_WORK_GROUP_COUNT 0x91BE +#define GL_MAX_COMPUTE_WORK_GROUP_SIZE 0x91BF +#define GL_COMPUTE_LOCAL_WORK_SIZE 0x8267 +#define GL_UNIFORM_BLOCK_REFERENCED_BY_COMPUTE_SHADER 0x90EC +#define GL_ATOMIC_COUNTER_BUFFER_REFERENCED_BY_COMPUTE_SHADER 0x90ED +#define GL_DISPATCH_INDIRECT_BUFFER 0x90EE +#define GL_DISPATCH_INDIRECT_BUFFER_BINDING 0x90EF +#define GL_COMPUTE_SHADER_BIT 0x00000020 +#endif + +#ifndef GL_ARB_copy_image +#endif + +#ifndef GL_ARB_texture_view +#define GL_TEXTURE_VIEW_MIN_LEVEL 0x82DB +#define GL_TEXTURE_VIEW_NUM_LEVELS 0x82DC +#define GL_TEXTURE_VIEW_MIN_LAYER 0x82DD +#define GL_TEXTURE_VIEW_NUM_LAYERS 0x82DE +#define GL_TEXTURE_IMMUTABLE_LEVELS 0x82DF +#endif + +#ifndef GL_ARB_vertex_attrib_binding +#define GL_VERTEX_ATTRIB_BINDING 0x82D4 +#define GL_VERTEX_ATTRIB_RELATIVE_OFFSET 0x82D5 +#define GL_VERTEX_BINDING_DIVISOR 0x82D6 +#define GL_VERTEX_BINDING_OFFSET 0x82D7 +#define GL_VERTEX_BINDING_STRIDE 0x82D8 +#define GL_MAX_VERTEX_ATTRIB_RELATIVE_OFFSET 0x82D9 +#define GL_MAX_VERTEX_ATTRIB_BINDINGS 0x82DA +#endif + +#ifndef GL_ARB_robustness_isolation +#endif + +#ifndef GL_ARB_ES3_compatibility +#define GL_COMPRESSED_RGB8_ETC2 0x9274 +#define GL_COMPRESSED_SRGB8_ETC2 0x9275 +#define GL_COMPRESSED_RGB8_PUNCHTHROUGH_ALPHA1_ETC2 0x9276 +#define GL_COMPRESSED_SRGB8_PUNCHTHROUGH_ALPHA1_ETC2 0x9277 +#define GL_COMPRESSED_RGBA8_ETC2_EAC 0x9278 +#define GL_COMPRESSED_SRGB8_ALPHA8_ETC2_EAC 0x9279 +#define GL_COMPRESSED_R11_EAC 0x9270 +#define GL_COMPRESSED_SIGNED_R11_EAC 0x9271 +#define GL_COMPRESSED_RG11_EAC 0x9272 +#define GL_COMPRESSED_SIGNED_RG11_EAC 0x9273 +#define GL_PRIMITIVE_RESTART_FIXED_INDEX 0x8D69 +#define GL_ANY_SAMPLES_PASSED_CONSERVATIVE 0x8D6A +#define GL_MAX_ELEMENT_INDEX 0x8D6B +#endif + +#ifndef GL_ARB_explicit_uniform_location +#define GL_MAX_UNIFORM_LOCATIONS 0x826E +#endif + +#ifndef GL_ARB_fragment_layer_viewport +#endif + +#ifndef GL_ARB_framebuffer_no_attachments +#define GL_FRAMEBUFFER_DEFAULT_WIDTH 0x9310 +#define GL_FRAMEBUFFER_DEFAULT_HEIGHT 0x9311 +#define GL_FRAMEBUFFER_DEFAULT_LAYERS 0x9312 +#define GL_FRAMEBUFFER_DEFAULT_SAMPLES 0x9313 +#define GL_FRAMEBUFFER_DEFAULT_FIXED_SAMPLE_LOCATIONS 0x9314 +#define GL_MAX_FRAMEBUFFER_WIDTH 0x9315 +#define GL_MAX_FRAMEBUFFER_HEIGHT 0x9316 +#define GL_MAX_FRAMEBUFFER_LAYERS 0x9317 +#define 
GL_MAX_FRAMEBUFFER_SAMPLES 0x9318 +#endif + +#ifndef GL_ARB_internalformat_query2 +/* reuse GL_IMAGE_FORMAT_COMPATIBILITY_TYPE */ +/* reuse GL_NUM_SAMPLE_COUNTS */ +/* reuse GL_RENDERBUFFER */ +/* reuse GL_SAMPLES */ +/* reuse GL_TEXTURE_1D */ +/* reuse GL_TEXTURE_1D_ARRAY */ +/* reuse GL_TEXTURE_2D */ +/* reuse GL_TEXTURE_2D_ARRAY */ +/* reuse GL_TEXTURE_3D */ +/* reuse GL_TEXTURE_CUBE_MAP */ +/* reuse GL_TEXTURE_CUBE_MAP_ARRAY */ +/* reuse GL_TEXTURE_RECTANGLE */ +/* reuse GL_TEXTURE_BUFFER */ +/* reuse GL_TEXTURE_2D_MULTISAMPLE */ +/* reuse GL_TEXTURE_2D_MULTISAMPLE_ARRAY */ +/* reuse GL_TEXTURE_COMPRESSED */ +#define GL_INTERNALFORMAT_SUPPORTED 0x826F +#define GL_INTERNALFORMAT_PREFERRED 0x8270 +#define GL_INTERNALFORMAT_RED_SIZE 0x8271 +#define GL_INTERNALFORMAT_GREEN_SIZE 0x8272 +#define GL_INTERNALFORMAT_BLUE_SIZE 0x8273 +#define GL_INTERNALFORMAT_ALPHA_SIZE 0x8274 +#define GL_INTERNALFORMAT_DEPTH_SIZE 0x8275 +#define GL_INTERNALFORMAT_STENCIL_SIZE 0x8276 +#define GL_INTERNALFORMAT_SHARED_SIZE 0x8277 +#define GL_INTERNALFORMAT_RED_TYPE 0x8278 +#define GL_INTERNALFORMAT_GREEN_TYPE 0x8279 +#define GL_INTERNALFORMAT_BLUE_TYPE 0x827A +#define GL_INTERNALFORMAT_ALPHA_TYPE 0x827B +#define GL_INTERNALFORMAT_DEPTH_TYPE 0x827C +#define GL_INTERNALFORMAT_STENCIL_TYPE 0x827D +#define GL_MAX_WIDTH 0x827E +#define GL_MAX_HEIGHT 0x827F +#define GL_MAX_DEPTH 0x8280 +#define GL_MAX_LAYERS 0x8281 +#define GL_MAX_COMBINED_DIMENSIONS 0x8282 +#define GL_COLOR_COMPONENTS 0x8283 +#define GL_DEPTH_COMPONENTS 0x8284 +#define GL_STENCIL_COMPONENTS 0x8285 +#define GL_COLOR_RENDERABLE 0x8286 +#define GL_DEPTH_RENDERABLE 0x8287 +#define GL_STENCIL_RENDERABLE 0x8288 +#define GL_FRAMEBUFFER_RENDERABLE 0x8289 +#define GL_FRAMEBUFFER_RENDERABLE_LAYERED 0x828A +#define GL_FRAMEBUFFER_BLEND 0x828B +#define GL_READ_PIXELS 0x828C +#define GL_READ_PIXELS_FORMAT 0x828D +#define GL_READ_PIXELS_TYPE 0x828E +#define GL_TEXTURE_IMAGE_FORMAT 0x828F +#define GL_TEXTURE_IMAGE_TYPE 0x8290 +#define GL_GET_TEXTURE_IMAGE_FORMAT 0x8291 +#define GL_GET_TEXTURE_IMAGE_TYPE 0x8292 +#define GL_MIPMAP 0x8293 +#define GL_MANUAL_GENERATE_MIPMAP 0x8294 +#define GL_AUTO_GENERATE_MIPMAP 0x8295 +#define GL_COLOR_ENCODING 0x8296 +#define GL_SRGB_READ 0x8297 +#define GL_SRGB_WRITE 0x8298 +#define GL_SRGB_DECODE_ARB 0x8299 +#define GL_FILTER 0x829A +#define GL_VERTEX_TEXTURE 0x829B +#define GL_TESS_CONTROL_TEXTURE 0x829C +#define GL_TESS_EVALUATION_TEXTURE 0x829D +#define GL_GEOMETRY_TEXTURE 0x829E +#define GL_FRAGMENT_TEXTURE 0x829F +#define GL_COMPUTE_TEXTURE 0x82A0 +#define GL_TEXTURE_SHADOW 0x82A1 +#define GL_TEXTURE_GATHER 0x82A2 +#define GL_TEXTURE_GATHER_SHADOW 0x82A3 +#define GL_SHADER_IMAGE_LOAD 0x82A4 +#define GL_SHADER_IMAGE_STORE 0x82A5 +#define GL_SHADER_IMAGE_ATOMIC 0x82A6 +#define GL_IMAGE_TEXEL_SIZE 0x82A7 +#define GL_IMAGE_COMPATIBILITY_CLASS 0x82A8 +#define GL_IMAGE_PIXEL_FORMAT 0x82A9 +#define GL_IMAGE_PIXEL_TYPE 0x82AA +#define GL_SIMULTANEOUS_TEXTURE_AND_DEPTH_TEST 0x82AC +#define GL_SIMULTANEOUS_TEXTURE_AND_STENCIL_TEST 0x82AD +#define GL_SIMULTANEOUS_TEXTURE_AND_DEPTH_WRITE 0x82AE +#define GL_SIMULTANEOUS_TEXTURE_AND_STENCIL_WRITE 0x82AF +#define GL_TEXTURE_COMPRESSED_BLOCK_WIDTH 0x82B1 +#define GL_TEXTURE_COMPRESSED_BLOCK_HEIGHT 0x82B2 +#define GL_TEXTURE_COMPRESSED_BLOCK_SIZE 0x82B3 +#define GL_CLEAR_BUFFER 0x82B4 +#define GL_TEXTURE_VIEW 0x82B5 +#define GL_VIEW_COMPATIBILITY_CLASS 0x82B6 +#define GL_FULL_SUPPORT 0x82B7 +#define GL_CAVEAT_SUPPORT 0x82B8 +#define GL_IMAGE_CLASS_4_X_32 0x82B9 +#define GL_IMAGE_CLASS_2_X_32 
0x82BA +#define GL_IMAGE_CLASS_1_X_32 0x82BB +#define GL_IMAGE_CLASS_4_X_16 0x82BC +#define GL_IMAGE_CLASS_2_X_16 0x82BD +#define GL_IMAGE_CLASS_1_X_16 0x82BE +#define GL_IMAGE_CLASS_4_X_8 0x82BF +#define GL_IMAGE_CLASS_2_X_8 0x82C0 +#define GL_IMAGE_CLASS_1_X_8 0x82C1 +#define GL_IMAGE_CLASS_11_11_10 0x82C2 +#define GL_IMAGE_CLASS_10_10_10_2 0x82C3 +#define GL_VIEW_CLASS_128_BITS 0x82C4 +#define GL_VIEW_CLASS_96_BITS 0x82C5 +#define GL_VIEW_CLASS_64_BITS 0x82C6 +#define GL_VIEW_CLASS_48_BITS 0x82C7 +#define GL_VIEW_CLASS_32_BITS 0x82C8 +#define GL_VIEW_CLASS_24_BITS 0x82C9 +#define GL_VIEW_CLASS_16_BITS 0x82CA +#define GL_VIEW_CLASS_8_BITS 0x82CB +#define GL_VIEW_CLASS_S3TC_DXT1_RGB 0x82CC +#define GL_VIEW_CLASS_S3TC_DXT1_RGBA 0x82CD +#define GL_VIEW_CLASS_S3TC_DXT3_RGBA 0x82CE +#define GL_VIEW_CLASS_S3TC_DXT5_RGBA 0x82CF +#define GL_VIEW_CLASS_RGTC1_RED 0x82D0 +#define GL_VIEW_CLASS_RGTC2_RG 0x82D1 +#define GL_VIEW_CLASS_BPTC_UNORM 0x82D2 +#define GL_VIEW_CLASS_BPTC_FLOAT 0x82D3 +#endif + +#ifndef GL_ARB_invalidate_subdata +#endif + +#ifndef GL_ARB_multi_draw_indirect +#endif + +#ifndef GL_ARB_program_interface_query +#define GL_UNIFORM 0x92E1 +#define GL_UNIFORM_BLOCK 0x92E2 +#define GL_PROGRAM_INPUT 0x92E3 +#define GL_PROGRAM_OUTPUT 0x92E4 +#define GL_BUFFER_VARIABLE 0x92E5 +#define GL_SHADER_STORAGE_BLOCK 0x92E6 +/* reuse GL_ATOMIC_COUNTER_BUFFER */ +#define GL_VERTEX_SUBROUTINE 0x92E8 +#define GL_TESS_CONTROL_SUBROUTINE 0x92E9 +#define GL_TESS_EVALUATION_SUBROUTINE 0x92EA +#define GL_GEOMETRY_SUBROUTINE 0x92EB +#define GL_FRAGMENT_SUBROUTINE 0x92EC +#define GL_COMPUTE_SUBROUTINE 0x92ED +#define GL_VERTEX_SUBROUTINE_UNIFORM 0x92EE +#define GL_TESS_CONTROL_SUBROUTINE_UNIFORM 0x92EF +#define GL_TESS_EVALUATION_SUBROUTINE_UNIFORM 0x92F0 +#define GL_GEOMETRY_SUBROUTINE_UNIFORM 0x92F1 +#define GL_FRAGMENT_SUBROUTINE_UNIFORM 0x92F2 +#define GL_COMPUTE_SUBROUTINE_UNIFORM 0x92F3 +#define GL_TRANSFORM_FEEDBACK_VARYING 0x92F4 +#define GL_ACTIVE_RESOURCES 0x92F5 +#define GL_MAX_NAME_LENGTH 0x92F6 +#define GL_MAX_NUM_ACTIVE_VARIABLES 0x92F7 +#define GL_MAX_NUM_COMPATIBLE_SUBROUTINES 0x92F8 +#define GL_NAME_LENGTH 0x92F9 +#define GL_TYPE 0x92FA +#define GL_ARRAY_SIZE 0x92FB +#define GL_OFFSET 0x92FC +#define GL_BLOCK_INDEX 0x92FD +#define GL_ARRAY_STRIDE 0x92FE +#define GL_MATRIX_STRIDE 0x92FF +#define GL_IS_ROW_MAJOR 0x9300 +#define GL_ATOMIC_COUNTER_BUFFER_INDEX 0x9301 +#define GL_BUFFER_BINDING 0x9302 +#define GL_BUFFER_DATA_SIZE 0x9303 +#define GL_NUM_ACTIVE_VARIABLES 0x9304 +#define GL_ACTIVE_VARIABLES 0x9305 +#define GL_REFERENCED_BY_VERTEX_SHADER 0x9306 +#define GL_REFERENCED_BY_TESS_CONTROL_SHADER 0x9307 +#define GL_REFERENCED_BY_TESS_EVALUATION_SHADER 0x9308 +#define GL_REFERENCED_BY_GEOMETRY_SHADER 0x9309 +#define GL_REFERENCED_BY_FRAGMENT_SHADER 0x930A +#define GL_REFERENCED_BY_COMPUTE_SHADER 0x930B +#define GL_TOP_LEVEL_ARRAY_SIZE 0x930C +#define GL_TOP_LEVEL_ARRAY_STRIDE 0x930D +#define GL_LOCATION 0x930E +#define GL_LOCATION_INDEX 0x930F +#define GL_IS_PER_PATCH 0x92E7 +/* reuse GL_NUM_COMPATIBLE_SUBROUTINES */ +/* reuse GL_COMPATIBLE_SUBROUTINES */ +#endif + +#ifndef GL_ARB_robust_buffer_access_behavior +#endif + +#ifndef GL_ARB_shader_image_size +#endif + +#ifndef GL_ARB_shader_storage_buffer_object +#define GL_SHADER_STORAGE_BUFFER 0x90D2 +#define GL_SHADER_STORAGE_BUFFER_BINDING 0x90D3 +#define GL_SHADER_STORAGE_BUFFER_START 0x90D4 +#define GL_SHADER_STORAGE_BUFFER_SIZE 0x90D5 +#define GL_MAX_VERTEX_SHADER_STORAGE_BLOCKS 0x90D6 +#define GL_MAX_GEOMETRY_SHADER_STORAGE_BLOCKS 
0x90D7 +#define GL_MAX_TESS_CONTROL_SHADER_STORAGE_BLOCKS 0x90D8 +#define GL_MAX_TESS_EVALUATION_SHADER_STORAGE_BLOCKS 0x90D9 +#define GL_MAX_FRAGMENT_SHADER_STORAGE_BLOCKS 0x90DA +#define GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS 0x90DB +#define GL_MAX_COMBINED_SHADER_STORAGE_BLOCKS 0x90DC +#define GL_MAX_SHADER_STORAGE_BUFFER_BINDINGS 0x90DD +#define GL_MAX_SHADER_STORAGE_BLOCK_SIZE 0x90DE +#define GL_SHADER_STORAGE_BUFFER_OFFSET_ALIGNMENT 0x90DF +#define GL_SHADER_STORAGE_BARRIER_BIT 0x2000 +#define GL_MAX_COMBINED_SHADER_OUTPUT_RESOURCES GL_MAX_COMBINED_IMAGE_UNITS_AND_FRAGMENT_OUTPUTS +/* reuse GL_MAX_COMBINED_IMAGE_UNITS_AND_FRAGMENT_OUTPUTS */ +#endif + +#ifndef GL_ARB_stencil_texturing +#define GL_DEPTH_STENCIL_TEXTURE_MODE 0x90EA +#endif + +#ifndef GL_ARB_texture_buffer_range +#define GL_TEXTURE_BUFFER_OFFSET 0x919D +#define GL_TEXTURE_BUFFER_SIZE 0x919E +#define GL_TEXTURE_BUFFER_OFFSET_ALIGNMENT 0x919F +#endif + +#ifndef GL_ARB_texture_query_levels +#endif + +#ifndef GL_ARB_texture_storage_multisample +#endif + + +/*************************************************************/ + +#include +#ifndef GL_VERSION_2_0 +/* GL type for program/shader text */ +typedef char GLchar; +#endif + +#ifndef GL_VERSION_1_5 +/* GL types for handling large vertex buffer objects */ +typedef ptrdiff_t GLintptr; +typedef ptrdiff_t GLsizeiptr; +#endif + +#ifndef GL_ARB_vertex_buffer_object +/* GL types for handling large vertex buffer objects */ +typedef ptrdiff_t GLintptrARB; +typedef ptrdiff_t GLsizeiptrARB; +#endif + +#ifndef GL_ARB_shader_objects +/* GL types for program/shader text and shader object handles */ +typedef char GLcharARB; +typedef unsigned int GLhandleARB; +#endif + +/* GL type for "half" precision (s10e5) float data in host memory */ +#ifndef GL_ARB_half_float_pixel +typedef unsigned short GLhalfARB; +#endif + +#ifndef GL_NV_half_float +typedef unsigned short GLhalfNV; +#endif + +#ifndef GLEXT_64_TYPES_DEFINED +/* This code block is duplicated in glxext.h, so must be protected */ +#define GLEXT_64_TYPES_DEFINED +/* Define int32_t, int64_t, and uint64_t types for UST/MSC */ +/* (as used in the GL_EXT_timer_query extension). 
*/ +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#include <inttypes.h> +#elif defined(__sun__) || defined(__digital__) +#include <inttypes.h> +#if defined(__STDC__) +#if defined(__arch64__) || defined(_LP64) +typedef long int int64_t; +typedef unsigned long int uint64_t; +#else +typedef long long int int64_t; +typedef unsigned long long int uint64_t; +#endif /* __arch64__ */ +#endif /* __STDC__ */ +#elif defined( __VMS ) || defined(__sgi) +#include <inttypes.h> +#elif defined(__SCO__) || defined(__USLC__) +#include <stdint.h> +#elif defined(__UNIXOS2__) || defined(__SOL64__) +typedef long int int32_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; +#elif defined(_WIN32) && defined(__GNUC__) +#include <stdint.h> +#elif defined(_WIN32) +typedef __int32 int32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#else +/* Fallback if nothing above works */ +#include <inttypes.h> +#endif +#endif + +#ifndef GL_EXT_timer_query +typedef int64_t GLint64EXT; +typedef uint64_t GLuint64EXT; +#endif + +#ifndef GL_ARB_sync +typedef int64_t GLint64; +typedef uint64_t GLuint64; +typedef struct __GLsync *GLsync; +#endif + +#ifndef GL_ARB_cl_event +/* These incomplete types let us declare types compatible with OpenCL's cl_context and cl_event */ +struct _cl_context; +struct _cl_event; +#endif + +#ifndef GL_ARB_debug_output +typedef void (APIENTRY *GLDEBUGPROCARB)(GLenum source,GLenum type,GLuint id,GLenum severity,GLsizei length,const GLchar *message,GLvoid *userParam); +#endif + +#ifndef GL_AMD_debug_output +typedef void (APIENTRY *GLDEBUGPROCAMD)(GLuint id,GLenum category,GLenum severity,GLsizei length,const GLchar *message,GLvoid *userParam); +#endif + +#ifndef GL_KHR_debug +typedef void (APIENTRY *GLDEBUGPROC)(GLenum source,GLenum type,GLuint id,GLenum severity,GLsizei length,const GLchar *message,GLvoid *userParam); +#endif + +#ifndef GL_NV_vdpau_interop +typedef GLintptr GLvdpauSurfaceNV; +#endif + +#ifndef GL_VERSION_1_0 +#define GL_VERSION_1_0 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glCullFace (GLenum mode); +GLAPI void APIENTRY glFrontFace (GLenum mode); +GLAPI void APIENTRY glHint (GLenum target, GLenum mode); +GLAPI void APIENTRY glLineWidth (GLfloat width); +GLAPI void APIENTRY glPointSize (GLfloat size); +GLAPI void APIENTRY glPolygonMode (GLenum face, GLenum mode); +GLAPI void APIENTRY glScissor (GLint x, GLint y, GLsizei width, GLsizei height); +GLAPI void APIENTRY glTexParameterf (GLenum target, GLenum pname, GLfloat param); +GLAPI void APIENTRY glTexParameterfv (GLenum target, GLenum pname, const GLfloat *params); +GLAPI void APIENTRY glTexParameteri (GLenum target, GLenum pname, GLint param); +GLAPI void APIENTRY glTexParameteriv (GLenum target, GLenum pname, const GLint *params); +GLAPI void APIENTRY glTexImage1D (GLenum target, GLint level, GLint internalformat, GLsizei width, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glTexImage2D (GLenum target, GLint level, GLint internalformat, GLsizei width, GLsizei height, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glDrawBuffer (GLenum mode); +GLAPI void APIENTRY glClear (GLbitfield mask); +GLAPI void APIENTRY glClearColor (GLfloat red, GLfloat green, GLfloat blue, GLfloat alpha); +GLAPI void APIENTRY glClearStencil (GLint s); +GLAPI void APIENTRY glClearDepth (GLdouble depth); +GLAPI void APIENTRY glStencilMask (GLuint mask); +GLAPI void APIENTRY glColorMask (GLboolean red, GLboolean green, GLboolean blue, GLboolean alpha); +GLAPI void APIENTRY glDepthMask (GLboolean 
flag); +GLAPI void APIENTRY glDisable (GLenum cap); +GLAPI void APIENTRY glEnable (GLenum cap); +GLAPI void APIENTRY glFinish (void); +GLAPI void APIENTRY glFlush (void); +GLAPI void APIENTRY glBlendFunc (GLenum sfactor, GLenum dfactor); +GLAPI void APIENTRY glLogicOp (GLenum opcode); +GLAPI void APIENTRY glStencilFunc (GLenum func, GLint ref, GLuint mask); +GLAPI void APIENTRY glStencilOp (GLenum fail, GLenum zfail, GLenum zpass); +GLAPI void APIENTRY glDepthFunc (GLenum func); +GLAPI void APIENTRY glPixelStoref (GLenum pname, GLfloat param); +GLAPI void APIENTRY glPixelStorei (GLenum pname, GLint param); +GLAPI void APIENTRY glReadBuffer (GLenum mode); +GLAPI void APIENTRY glReadPixels (GLint x, GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type, GLvoid *pixels); +GLAPI void APIENTRY glGetBooleanv (GLenum pname, GLboolean *params); +GLAPI void APIENTRY glGetDoublev (GLenum pname, GLdouble *params); +GLAPI GLenum APIENTRY glGetError (void); +GLAPI void APIENTRY glGetFloatv (GLenum pname, GLfloat *params); +GLAPI void APIENTRY glGetIntegerv (GLenum pname, GLint *params); +GLAPI const GLubyte * APIENTRY glGetString (GLenum name); +GLAPI void APIENTRY glGetTexImage (GLenum target, GLint level, GLenum format, GLenum type, GLvoid *pixels); +GLAPI void APIENTRY glGetTexParameterfv (GLenum target, GLenum pname, GLfloat *params); +GLAPI void APIENTRY glGetTexParameteriv (GLenum target, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetTexLevelParameterfv (GLenum target, GLint level, GLenum pname, GLfloat *params); +GLAPI void APIENTRY glGetTexLevelParameteriv (GLenum target, GLint level, GLenum pname, GLint *params); +GLAPI GLboolean APIENTRY glIsEnabled (GLenum cap); +GLAPI void APIENTRY glDepthRange (GLdouble near, GLdouble far); +GLAPI void APIENTRY glViewport (GLint x, GLint y, GLsizei width, GLsizei height); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLCULLFACEPROC) (GLenum mode); +typedef void (APIENTRYP PFNGLFRONTFACEPROC) (GLenum mode); +typedef void (APIENTRYP PFNGLHINTPROC) (GLenum target, GLenum mode); +typedef void (APIENTRYP PFNGLLINEWIDTHPROC) (GLfloat width); +typedef void (APIENTRYP PFNGLPOINTSIZEPROC) (GLfloat size); +typedef void (APIENTRYP PFNGLPOLYGONMODEPROC) (GLenum face, GLenum mode); +typedef void (APIENTRYP PFNGLSCISSORPROC) (GLint x, GLint y, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLTEXPARAMETERFPROC) (GLenum target, GLenum pname, GLfloat param); +typedef void (APIENTRYP PFNGLTEXPARAMETERFVPROC) (GLenum target, GLenum pname, const GLfloat *params); +typedef void (APIENTRYP PFNGLTEXPARAMETERIPROC) (GLenum target, GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLTEXPARAMETERIVPROC) (GLenum target, GLenum pname, const GLint *params); +typedef void (APIENTRYP PFNGLTEXIMAGE1DPROC) (GLenum target, GLint level, GLint internalformat, GLsizei width, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP PFNGLTEXIMAGE2DPROC) (GLenum target, GLint level, GLint internalformat, GLsizei width, GLsizei height, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP PFNGLDRAWBUFFERPROC) (GLenum mode); +typedef void (APIENTRYP PFNGLCLEARPROC) (GLbitfield mask); +typedef void (APIENTRYP PFNGLCLEARCOLORPROC) (GLfloat red, GLfloat green, GLfloat blue, GLfloat alpha); +typedef void (APIENTRYP PFNGLCLEARSTENCILPROC) (GLint s); +typedef void (APIENTRYP PFNGLCLEARDEPTHPROC) (GLdouble depth); +typedef void (APIENTRYP PFNGLSTENCILMASKPROC) (GLuint mask); 
+typedef void (APIENTRYP PFNGLCOLORMASKPROC) (GLboolean red, GLboolean green, GLboolean blue, GLboolean alpha); +typedef void (APIENTRYP PFNGLDEPTHMASKPROC) (GLboolean flag); +typedef void (APIENTRYP PFNGLDISABLEPROC) (GLenum cap); +typedef void (APIENTRYP PFNGLENABLEPROC) (GLenum cap); +typedef void (APIENTRYP PFNGLFINISHPROC) (void); +typedef void (APIENTRYP PFNGLFLUSHPROC) (void); +typedef void (APIENTRYP PFNGLBLENDFUNCPROC) (GLenum sfactor, GLenum dfactor); +typedef void (APIENTRYP PFNGLLOGICOPPROC) (GLenum opcode); +typedef void (APIENTRYP PFNGLSTENCILFUNCPROC) (GLenum func, GLint ref, GLuint mask); +typedef void (APIENTRYP PFNGLSTENCILOPPROC) (GLenum fail, GLenum zfail, GLenum zpass); +typedef void (APIENTRYP PFNGLDEPTHFUNCPROC) (GLenum func); +typedef void (APIENTRYP PFNGLPIXELSTOREFPROC) (GLenum pname, GLfloat param); +typedef void (APIENTRYP PFNGLPIXELSTOREIPROC) (GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLREADBUFFERPROC) (GLenum mode); +typedef void (APIENTRYP PFNGLREADPIXELSPROC) (GLint x, GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type, GLvoid *pixels); +typedef void (APIENTRYP PFNGLGETBOOLEANVPROC) (GLenum pname, GLboolean *params); +typedef void (APIENTRYP PFNGLGETDOUBLEVPROC) (GLenum pname, GLdouble *params); +typedef GLenum (APIENTRYP PFNGLGETERRORPROC) (void); +typedef void (APIENTRYP PFNGLGETFLOATVPROC) (GLenum pname, GLfloat *params); +typedef void (APIENTRYP PFNGLGETINTEGERVPROC) (GLenum pname, GLint *params); +typedef const GLubyte * (APIENTRYP PFNGLGETSTRINGPROC) (GLenum name); +typedef void (APIENTRYP PFNGLGETTEXIMAGEPROC) (GLenum target, GLint level, GLenum format, GLenum type, GLvoid *pixels); +typedef void (APIENTRYP PFNGLGETTEXPARAMETERFVPROC) (GLenum target, GLenum pname, GLfloat *params); +typedef void (APIENTRYP PFNGLGETTEXPARAMETERIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETTEXLEVELPARAMETERFVPROC) (GLenum target, GLint level, GLenum pname, GLfloat *params); +typedef void (APIENTRYP PFNGLGETTEXLEVELPARAMETERIVPROC) (GLenum target, GLint level, GLenum pname, GLint *params); +typedef GLboolean (APIENTRYP PFNGLISENABLEDPROC) (GLenum cap); +typedef void (APIENTRYP PFNGLDEPTHRANGEPROC) (GLdouble near, GLdouble far); +typedef void (APIENTRYP PFNGLVIEWPORTPROC) (GLint x, GLint y, GLsizei width, GLsizei height); +#endif + +#ifndef GL_VERSION_1_1 +#define GL_VERSION_1_1 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawArrays (GLenum mode, GLint first, GLsizei count); +GLAPI void APIENTRY glDrawElements (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices); +GLAPI void APIENTRY glGetPointerv (GLenum pname, GLvoid* *params); +GLAPI void APIENTRY glPolygonOffset (GLfloat factor, GLfloat units); +GLAPI void APIENTRY glCopyTexImage1D (GLenum target, GLint level, GLenum internalformat, GLint x, GLint y, GLsizei width, GLint border); +GLAPI void APIENTRY glCopyTexImage2D (GLenum target, GLint level, GLenum internalformat, GLint x, GLint y, GLsizei width, GLsizei height, GLint border); +GLAPI void APIENTRY glCopyTexSubImage1D (GLenum target, GLint level, GLint xoffset, GLint x, GLint y, GLsizei width); +GLAPI void APIENTRY glCopyTexSubImage2D (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint x, GLint y, GLsizei width, GLsizei height); +GLAPI void APIENTRY glTexSubImage1D (GLenum target, GLint level, GLint xoffset, GLsizei width, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glTexSubImage2D (GLenum target, GLint level, GLint xoffset, GLint 
yoffset, GLsizei width, GLsizei height, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glBindTexture (GLenum target, GLuint texture); +GLAPI void APIENTRY glDeleteTextures (GLsizei n, const GLuint *textures); +GLAPI void APIENTRY glGenTextures (GLsizei n, GLuint *textures); +GLAPI GLboolean APIENTRY glIsTexture (GLuint texture); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWARRAYSPROC) (GLenum mode, GLint first, GLsizei count); +typedef void (APIENTRYP PFNGLDRAWELEMENTSPROC) (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices); +typedef void (APIENTRYP PFNGLGETPOINTERVPROC) (GLenum pname, GLvoid* *params); +typedef void (APIENTRYP PFNGLPOLYGONOFFSETPROC) (GLfloat factor, GLfloat units); +typedef void (APIENTRYP PFNGLCOPYTEXIMAGE1DPROC) (GLenum target, GLint level, GLenum internalformat, GLint x, GLint y, GLsizei width, GLint border); +typedef void (APIENTRYP PFNGLCOPYTEXIMAGE2DPROC) (GLenum target, GLint level, GLenum internalformat, GLint x, GLint y, GLsizei width, GLsizei height, GLint border); +typedef void (APIENTRYP PFNGLCOPYTEXSUBIMAGE1DPROC) (GLenum target, GLint level, GLint xoffset, GLint x, GLint y, GLsizei width); +typedef void (APIENTRYP PFNGLCOPYTEXSUBIMAGE2DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint x, GLint y, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLTEXSUBIMAGE1DPROC) (GLenum target, GLint level, GLint xoffset, GLsizei width, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP PFNGLTEXSUBIMAGE2DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width, GLsizei height, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP PFNGLBINDTEXTUREPROC) (GLenum target, GLuint texture); +typedef void (APIENTRYP PFNGLDELETETEXTURESPROC) (GLsizei n, const GLuint *textures); +typedef void (APIENTRYP PFNGLGENTEXTURESPROC) (GLsizei n, GLuint *textures); +typedef GLboolean (APIENTRYP PFNGLISTEXTUREPROC) (GLuint texture); +#endif + +#ifndef GL_VERSION_1_2 +#define GL_VERSION_1_2 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBlendColor (GLfloat red, GLfloat green, GLfloat blue, GLfloat alpha); +GLAPI void APIENTRY glBlendEquation (GLenum mode); +GLAPI void APIENTRY glDrawRangeElements (GLenum mode, GLuint start, GLuint end, GLsizei count, GLenum type, const GLvoid *indices); +GLAPI void APIENTRY glTexImage3D (GLenum target, GLint level, GLint internalformat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glTexSubImage3D (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth, GLenum format, GLenum type, const GLvoid *pixels); +GLAPI void APIENTRY glCopyTexSubImage3D (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLint x, GLint y, GLsizei width, GLsizei height); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBLENDCOLORPROC) (GLfloat red, GLfloat green, GLfloat blue, GLfloat alpha); +typedef void (APIENTRYP PFNGLBLENDEQUATIONPROC) (GLenum mode); +typedef void (APIENTRYP PFNGLDRAWRANGEELEMENTSPROC) (GLenum mode, GLuint start, GLuint end, GLsizei count, GLenum type, const GLvoid *indices); +typedef void (APIENTRYP PFNGLTEXIMAGE3DPROC) (GLenum target, GLint level, GLint internalformat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP 
PFNGLTEXSUBIMAGE3DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth, GLenum format, GLenum type, const GLvoid *pixels); +typedef void (APIENTRYP PFNGLCOPYTEXSUBIMAGE3DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLint x, GLint y, GLsizei width, GLsizei height); +#endif + +#ifndef GL_VERSION_1_3 +#define GL_VERSION_1_3 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glActiveTexture (GLenum texture); +GLAPI void APIENTRY glSampleCoverage (GLfloat value, GLboolean invert); +GLAPI void APIENTRY glCompressedTexImage3D (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glCompressedTexImage2D (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLsizei height, GLint border, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glCompressedTexImage1D (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLint border, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glCompressedTexSubImage3D (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth, GLenum format, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glCompressedTexSubImage2D (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width, GLsizei height, GLenum format, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glCompressedTexSubImage1D (GLenum target, GLint level, GLint xoffset, GLsizei width, GLenum format, GLsizei imageSize, const GLvoid *data); +GLAPI void APIENTRY glGetCompressedTexImage (GLenum target, GLint level, GLvoid *img); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLACTIVETEXTUREPROC) (GLenum texture); +typedef void (APIENTRYP PFNGLSAMPLECOVERAGEPROC) (GLfloat value, GLboolean invert); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXIMAGE3DPROC) (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLint border, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXIMAGE2DPROC) (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLsizei height, GLint border, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXIMAGE1DPROC) (GLenum target, GLint level, GLenum internalformat, GLsizei width, GLint border, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXSUBIMAGE3DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth, GLenum format, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXSUBIMAGE2DPROC) (GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width, GLsizei height, GLenum format, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLCOMPRESSEDTEXSUBIMAGE1DPROC) (GLenum target, GLint level, GLint xoffset, GLsizei width, GLenum format, GLsizei imageSize, const GLvoid *data); +typedef void (APIENTRYP PFNGLGETCOMPRESSEDTEXIMAGEPROC) (GLenum target, GLint level, GLvoid *img); +#endif + +#ifndef GL_VERSION_1_4 +#define GL_VERSION_1_4 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBlendFuncSeparate (GLenum sfactorRGB, GLenum dfactorRGB, GLenum sfactorAlpha, GLenum dfactorAlpha); +GLAPI void APIENTRY glMultiDrawArrays (GLenum mode, const GLint *first, 
const GLsizei *count, GLsizei drawcount); +GLAPI void APIENTRY glMultiDrawElements (GLenum mode, const GLsizei *count, GLenum type, const GLvoid* const *indices, GLsizei drawcount); +GLAPI void APIENTRY glPointParameterf (GLenum pname, GLfloat param); +GLAPI void APIENTRY glPointParameterfv (GLenum pname, const GLfloat *params); +GLAPI void APIENTRY glPointParameteri (GLenum pname, GLint param); +GLAPI void APIENTRY glPointParameteriv (GLenum pname, const GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBLENDFUNCSEPARATEPROC) (GLenum sfactorRGB, GLenum dfactorRGB, GLenum sfactorAlpha, GLenum dfactorAlpha); +typedef void (APIENTRYP PFNGLMULTIDRAWARRAYSPROC) (GLenum mode, const GLint *first, const GLsizei *count, GLsizei drawcount); +typedef void (APIENTRYP PFNGLMULTIDRAWELEMENTSPROC) (GLenum mode, const GLsizei *count, GLenum type, const GLvoid* const *indices, GLsizei drawcount); +typedef void (APIENTRYP PFNGLPOINTPARAMETERFPROC) (GLenum pname, GLfloat param); +typedef void (APIENTRYP PFNGLPOINTPARAMETERFVPROC) (GLenum pname, const GLfloat *params); +typedef void (APIENTRYP PFNGLPOINTPARAMETERIPROC) (GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLPOINTPARAMETERIVPROC) (GLenum pname, const GLint *params); +#endif + +#ifndef GL_VERSION_1_5 +#define GL_VERSION_1_5 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGenQueries (GLsizei n, GLuint *ids); +GLAPI void APIENTRY glDeleteQueries (GLsizei n, const GLuint *ids); +GLAPI GLboolean APIENTRY glIsQuery (GLuint id); +GLAPI void APIENTRY glBeginQuery (GLenum target, GLuint id); +GLAPI void APIENTRY glEndQuery (GLenum target); +GLAPI void APIENTRY glGetQueryiv (GLenum target, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetQueryObjectiv (GLuint id, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetQueryObjectuiv (GLuint id, GLenum pname, GLuint *params); +GLAPI void APIENTRY glBindBuffer (GLenum target, GLuint buffer); +GLAPI void APIENTRY glDeleteBuffers (GLsizei n, const GLuint *buffers); +GLAPI void APIENTRY glGenBuffers (GLsizei n, GLuint *buffers); +GLAPI GLboolean APIENTRY glIsBuffer (GLuint buffer); +GLAPI void APIENTRY glBufferData (GLenum target, GLsizeiptr size, const GLvoid *data, GLenum usage); +GLAPI void APIENTRY glBufferSubData (GLenum target, GLintptr offset, GLsizeiptr size, const GLvoid *data); +GLAPI void APIENTRY glGetBufferSubData (GLenum target, GLintptr offset, GLsizeiptr size, GLvoid *data); +GLAPI GLvoid* APIENTRY glMapBuffer (GLenum target, GLenum access); +GLAPI GLboolean APIENTRY glUnmapBuffer (GLenum target); +GLAPI void APIENTRY glGetBufferParameteriv (GLenum target, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetBufferPointerv (GLenum target, GLenum pname, GLvoid* *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGENQUERIESPROC) (GLsizei n, GLuint *ids); +typedef void (APIENTRYP PFNGLDELETEQUERIESPROC) (GLsizei n, const GLuint *ids); +typedef GLboolean (APIENTRYP PFNGLISQUERYPROC) (GLuint id); +typedef void (APIENTRYP PFNGLBEGINQUERYPROC) (GLenum target, GLuint id); +typedef void (APIENTRYP PFNGLENDQUERYPROC) (GLenum target); +typedef void (APIENTRYP PFNGLGETQUERYIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETQUERYOBJECTIVPROC) (GLuint id, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETQUERYOBJECTUIVPROC) (GLuint id, GLenum pname, GLuint *params); +typedef void (APIENTRYP PFNGLBINDBUFFERPROC) (GLenum target, GLuint buffer); +typedef void (APIENTRYP 
PFNGLDELETEBUFFERSPROC) (GLsizei n, const GLuint *buffers); +typedef void (APIENTRYP PFNGLGENBUFFERSPROC) (GLsizei n, GLuint *buffers); +typedef GLboolean (APIENTRYP PFNGLISBUFFERPROC) (GLuint buffer); +typedef void (APIENTRYP PFNGLBUFFERDATAPROC) (GLenum target, GLsizeiptr size, const GLvoid *data, GLenum usage); +typedef void (APIENTRYP PFNGLBUFFERSUBDATAPROC) (GLenum target, GLintptr offset, GLsizeiptr size, const GLvoid *data); +typedef void (APIENTRYP PFNGLGETBUFFERSUBDATAPROC) (GLenum target, GLintptr offset, GLsizeiptr size, GLvoid *data); +typedef GLvoid* (APIENTRYP PFNGLMAPBUFFERPROC) (GLenum target, GLenum access); +typedef GLboolean (APIENTRYP PFNGLUNMAPBUFFERPROC) (GLenum target); +typedef void (APIENTRYP PFNGLGETBUFFERPARAMETERIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETBUFFERPOINTERVPROC) (GLenum target, GLenum pname, GLvoid* *params); +#endif + +#ifndef GL_VERSION_2_0 +#define GL_VERSION_2_0 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBlendEquationSeparate (GLenum modeRGB, GLenum modeAlpha); +GLAPI void APIENTRY glDrawBuffers (GLsizei n, const GLenum *bufs); +GLAPI void APIENTRY glStencilOpSeparate (GLenum face, GLenum sfail, GLenum dpfail, GLenum dppass); +GLAPI void APIENTRY glStencilFuncSeparate (GLenum face, GLenum func, GLint ref, GLuint mask); +GLAPI void APIENTRY glStencilMaskSeparate (GLenum face, GLuint mask); +GLAPI void APIENTRY glAttachShader (GLuint program, GLuint shader); +GLAPI void APIENTRY glBindAttribLocation (GLuint program, GLuint index, const GLchar *name); +GLAPI void APIENTRY glCompileShader (GLuint shader); +GLAPI GLuint APIENTRY glCreateProgram (void); +GLAPI GLuint APIENTRY glCreateShader (GLenum type); +GLAPI void APIENTRY glDeleteProgram (GLuint program); +GLAPI void APIENTRY glDeleteShader (GLuint shader); +GLAPI void APIENTRY glDetachShader (GLuint program, GLuint shader); +GLAPI void APIENTRY glDisableVertexAttribArray (GLuint index); +GLAPI void APIENTRY glEnableVertexAttribArray (GLuint index); +GLAPI void APIENTRY glGetActiveAttrib (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLint *size, GLenum *type, GLchar *name); +GLAPI void APIENTRY glGetActiveUniform (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLint *size, GLenum *type, GLchar *name); +GLAPI void APIENTRY glGetAttachedShaders (GLuint program, GLsizei maxCount, GLsizei *count, GLuint *obj); +GLAPI GLint APIENTRY glGetAttribLocation (GLuint program, const GLchar *name); +GLAPI void APIENTRY glGetProgramiv (GLuint program, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetProgramInfoLog (GLuint program, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +GLAPI void APIENTRY glGetShaderiv (GLuint shader, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetShaderInfoLog (GLuint shader, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +GLAPI void APIENTRY glGetShaderSource (GLuint shader, GLsizei bufSize, GLsizei *length, GLchar *source); +GLAPI GLint APIENTRY glGetUniformLocation (GLuint program, const GLchar *name); +GLAPI void APIENTRY glGetUniformfv (GLuint program, GLint location, GLfloat *params); +GLAPI void APIENTRY glGetUniformiv (GLuint program, GLint location, GLint *params); +GLAPI void APIENTRY glGetVertexAttribdv (GLuint index, GLenum pname, GLdouble *params); +GLAPI void APIENTRY glGetVertexAttribfv (GLuint index, GLenum pname, GLfloat *params); +GLAPI void APIENTRY glGetVertexAttribiv (GLuint index, GLenum pname, GLint *params); +GLAPI void APIENTRY 
glGetVertexAttribPointerv (GLuint index, GLenum pname, GLvoid* *pointer); +GLAPI GLboolean APIENTRY glIsProgram (GLuint program); +GLAPI GLboolean APIENTRY glIsShader (GLuint shader); +GLAPI void APIENTRY glLinkProgram (GLuint program); +GLAPI void APIENTRY glShaderSource (GLuint shader, GLsizei count, const GLchar* const *string, const GLint *length); +GLAPI void APIENTRY glUseProgram (GLuint program); +GLAPI void APIENTRY glUniform1f (GLint location, GLfloat v0); +GLAPI void APIENTRY glUniform2f (GLint location, GLfloat v0, GLfloat v1); +GLAPI void APIENTRY glUniform3f (GLint location, GLfloat v0, GLfloat v1, GLfloat v2); +GLAPI void APIENTRY glUniform4f (GLint location, GLfloat v0, GLfloat v1, GLfloat v2, GLfloat v3); +GLAPI void APIENTRY glUniform1i (GLint location, GLint v0); +GLAPI void APIENTRY glUniform2i (GLint location, GLint v0, GLint v1); +GLAPI void APIENTRY glUniform3i (GLint location, GLint v0, GLint v1, GLint v2); +GLAPI void APIENTRY glUniform4i (GLint location, GLint v0, GLint v1, GLint v2, GLint v3); +GLAPI void APIENTRY glUniform1fv (GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glUniform2fv (GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glUniform3fv (GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glUniform4fv (GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glUniform1iv (GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glUniform2iv (GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glUniform3iv (GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glUniform4iv (GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glUniformMatrix2fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix3fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix4fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glValidateProgram (GLuint program); +GLAPI void APIENTRY glVertexAttrib1d (GLuint index, GLdouble x); +GLAPI void APIENTRY glVertexAttrib1dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttrib1f (GLuint index, GLfloat x); +GLAPI void APIENTRY glVertexAttrib1fv (GLuint index, const GLfloat *v); +GLAPI void APIENTRY glVertexAttrib1s (GLuint index, GLshort x); +GLAPI void APIENTRY glVertexAttrib1sv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttrib2d (GLuint index, GLdouble x, GLdouble y); +GLAPI void APIENTRY glVertexAttrib2dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttrib2f (GLuint index, GLfloat x, GLfloat y); +GLAPI void APIENTRY glVertexAttrib2fv (GLuint index, const GLfloat *v); +GLAPI void APIENTRY glVertexAttrib2s (GLuint index, GLshort x, GLshort y); +GLAPI void APIENTRY glVertexAttrib2sv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttrib3d (GLuint index, GLdouble x, GLdouble y, GLdouble z); +GLAPI void APIENTRY glVertexAttrib3dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttrib3f (GLuint index, GLfloat x, GLfloat y, GLfloat z); +GLAPI void APIENTRY glVertexAttrib3fv (GLuint index, const GLfloat *v); +GLAPI void APIENTRY glVertexAttrib3s (GLuint index, GLshort x, GLshort y, GLshort z); +GLAPI void APIENTRY glVertexAttrib3sv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttrib4Nbv (GLuint index, const 
GLbyte *v); +GLAPI void APIENTRY glVertexAttrib4Niv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttrib4Nsv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttrib4Nub (GLuint index, GLubyte x, GLubyte y, GLubyte z, GLubyte w); +GLAPI void APIENTRY glVertexAttrib4Nubv (GLuint index, const GLubyte *v); +GLAPI void APIENTRY glVertexAttrib4Nuiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttrib4Nusv (GLuint index, const GLushort *v); +GLAPI void APIENTRY glVertexAttrib4bv (GLuint index, const GLbyte *v); +GLAPI void APIENTRY glVertexAttrib4d (GLuint index, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +GLAPI void APIENTRY glVertexAttrib4dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttrib4f (GLuint index, GLfloat x, GLfloat y, GLfloat z, GLfloat w); +GLAPI void APIENTRY glVertexAttrib4fv (GLuint index, const GLfloat *v); +GLAPI void APIENTRY glVertexAttrib4iv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttrib4s (GLuint index, GLshort x, GLshort y, GLshort z, GLshort w); +GLAPI void APIENTRY glVertexAttrib4sv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttrib4ubv (GLuint index, const GLubyte *v); +GLAPI void APIENTRY glVertexAttrib4uiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttrib4usv (GLuint index, const GLushort *v); +GLAPI void APIENTRY glVertexAttribPointer (GLuint index, GLint size, GLenum type, GLboolean normalized, GLsizei stride, const GLvoid *pointer); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBLENDEQUATIONSEPARATEPROC) (GLenum modeRGB, GLenum modeAlpha); +typedef void (APIENTRYP PFNGLDRAWBUFFERSPROC) (GLsizei n, const GLenum *bufs); +typedef void (APIENTRYP PFNGLSTENCILOPSEPARATEPROC) (GLenum face, GLenum sfail, GLenum dpfail, GLenum dppass); +typedef void (APIENTRYP PFNGLSTENCILFUNCSEPARATEPROC) (GLenum face, GLenum func, GLint ref, GLuint mask); +typedef void (APIENTRYP PFNGLSTENCILMASKSEPARATEPROC) (GLenum face, GLuint mask); +typedef void (APIENTRYP PFNGLATTACHSHADERPROC) (GLuint program, GLuint shader); +typedef void (APIENTRYP PFNGLBINDATTRIBLOCATIONPROC) (GLuint program, GLuint index, const GLchar *name); +typedef void (APIENTRYP PFNGLCOMPILESHADERPROC) (GLuint shader); +typedef GLuint (APIENTRYP PFNGLCREATEPROGRAMPROC) (void); +typedef GLuint (APIENTRYP PFNGLCREATESHADERPROC) (GLenum type); +typedef void (APIENTRYP PFNGLDELETEPROGRAMPROC) (GLuint program); +typedef void (APIENTRYP PFNGLDELETESHADERPROC) (GLuint shader); +typedef void (APIENTRYP PFNGLDETACHSHADERPROC) (GLuint program, GLuint shader); +typedef void (APIENTRYP PFNGLDISABLEVERTEXATTRIBARRAYPROC) (GLuint index); +typedef void (APIENTRYP PFNGLENABLEVERTEXATTRIBARRAYPROC) (GLuint index); +typedef void (APIENTRYP PFNGLGETACTIVEATTRIBPROC) (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLint *size, GLenum *type, GLchar *name); +typedef void (APIENTRYP PFNGLGETACTIVEUNIFORMPROC) (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLint *size, GLenum *type, GLchar *name); +typedef void (APIENTRYP PFNGLGETATTACHEDSHADERSPROC) (GLuint program, GLsizei maxCount, GLsizei *count, GLuint *obj); +typedef GLint (APIENTRYP PFNGLGETATTRIBLOCATIONPROC) (GLuint program, const GLchar *name); +typedef void (APIENTRYP PFNGLGETPROGRAMIVPROC) (GLuint program, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETPROGRAMINFOLOGPROC) (GLuint program, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +typedef void (APIENTRYP PFNGLGETSHADERIVPROC) 
(GLuint shader, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETSHADERINFOLOGPROC) (GLuint shader, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +typedef void (APIENTRYP PFNGLGETSHADERSOURCEPROC) (GLuint shader, GLsizei bufSize, GLsizei *length, GLchar *source); +typedef GLint (APIENTRYP PFNGLGETUNIFORMLOCATIONPROC) (GLuint program, const GLchar *name); +typedef void (APIENTRYP PFNGLGETUNIFORMFVPROC) (GLuint program, GLint location, GLfloat *params); +typedef void (APIENTRYP PFNGLGETUNIFORMIVPROC) (GLuint program, GLint location, GLint *params); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBDVPROC) (GLuint index, GLenum pname, GLdouble *params); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBFVPROC) (GLuint index, GLenum pname, GLfloat *params); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBIVPROC) (GLuint index, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBPOINTERVPROC) (GLuint index, GLenum pname, GLvoid* *pointer); +typedef GLboolean (APIENTRYP PFNGLISPROGRAMPROC) (GLuint program); +typedef GLboolean (APIENTRYP PFNGLISSHADERPROC) (GLuint shader); +typedef void (APIENTRYP PFNGLLINKPROGRAMPROC) (GLuint program); +typedef void (APIENTRYP PFNGLSHADERSOURCEPROC) (GLuint shader, GLsizei count, const GLchar* const *string, const GLint *length); +typedef void (APIENTRYP PFNGLUSEPROGRAMPROC) (GLuint program); +typedef void (APIENTRYP PFNGLUNIFORM1FPROC) (GLint location, GLfloat v0); +typedef void (APIENTRYP PFNGLUNIFORM2FPROC) (GLint location, GLfloat v0, GLfloat v1); +typedef void (APIENTRYP PFNGLUNIFORM3FPROC) (GLint location, GLfloat v0, GLfloat v1, GLfloat v2); +typedef void (APIENTRYP PFNGLUNIFORM4FPROC) (GLint location, GLfloat v0, GLfloat v1, GLfloat v2, GLfloat v3); +typedef void (APIENTRYP PFNGLUNIFORM1IPROC) (GLint location, GLint v0); +typedef void (APIENTRYP PFNGLUNIFORM2IPROC) (GLint location, GLint v0, GLint v1); +typedef void (APIENTRYP PFNGLUNIFORM3IPROC) (GLint location, GLint v0, GLint v1, GLint v2); +typedef void (APIENTRYP PFNGLUNIFORM4IPROC) (GLint location, GLint v0, GLint v1, GLint v2, GLint v3); +typedef void (APIENTRYP PFNGLUNIFORM1FVPROC) (GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORM2FVPROC) (GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORM3FVPROC) (GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORM4FVPROC) (GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORM1IVPROC) (GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLUNIFORM2IVPROC) (GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLUNIFORM3IVPROC) (GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLUNIFORM4IVPROC) (GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLVALIDATEPROGRAMPROC) (GLuint program); +typedef void (APIENTRYP PFNGLVERTEXATTRIB1DPROC) (GLuint index, GLdouble x); +typedef void (APIENTRYP PFNGLVERTEXATTRIB1DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP 
PFNGLVERTEXATTRIB1FPROC) (GLuint index, GLfloat x); +typedef void (APIENTRYP PFNGLVERTEXATTRIB1FVPROC) (GLuint index, const GLfloat *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB1SPROC) (GLuint index, GLshort x); +typedef void (APIENTRYP PFNGLVERTEXATTRIB1SVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2DPROC) (GLuint index, GLdouble x, GLdouble y); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2FPROC) (GLuint index, GLfloat x, GLfloat y); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2FVPROC) (GLuint index, const GLfloat *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2SPROC) (GLuint index, GLshort x, GLshort y); +typedef void (APIENTRYP PFNGLVERTEXATTRIB2SVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3DPROC) (GLuint index, GLdouble x, GLdouble y, GLdouble z); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3FPROC) (GLuint index, GLfloat x, GLfloat y, GLfloat z); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3FVPROC) (GLuint index, const GLfloat *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3SPROC) (GLuint index, GLshort x, GLshort y, GLshort z); +typedef void (APIENTRYP PFNGLVERTEXATTRIB3SVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NBVPROC) (GLuint index, const GLbyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NIVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NSVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NUBPROC) (GLuint index, GLubyte x, GLubyte y, GLubyte z, GLubyte w); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NUBVPROC) (GLuint index, const GLubyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NUIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4NUSVPROC) (GLuint index, const GLushort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4BVPROC) (GLuint index, const GLbyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4DPROC) (GLuint index, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4FPROC) (GLuint index, GLfloat x, GLfloat y, GLfloat z, GLfloat w); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4FVPROC) (GLuint index, const GLfloat *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4IVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4SPROC) (GLuint index, GLshort x, GLshort y, GLshort z, GLshort w); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4SVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4UBVPROC) (GLuint index, const GLubyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4UIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIB4USVPROC) (GLuint index, const GLushort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBPOINTERPROC) (GLuint index, GLint size, GLenum type, GLboolean normalized, GLsizei stride, const GLvoid *pointer); +#endif + +#ifndef GL_VERSION_2_1 +#define GL_VERSION_2_1 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glUniformMatrix2x3fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix3x2fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix2x4fv (GLint 
location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix4x2fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix3x4fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glUniformMatrix4x3fv (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2X3FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3X2FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2X4FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4X2FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3X4FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4X3FVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +#endif + +#ifndef GL_VERSION_3_0 +#define GL_VERSION_3_0 1 +/* OpenGL 3.0 also reuses entry points from these extensions: */ +/* ARB_framebuffer_object */ +/* ARB_map_buffer_range */ +/* ARB_vertex_array_object */ +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glColorMaski (GLuint index, GLboolean r, GLboolean g, GLboolean b, GLboolean a); +GLAPI void APIENTRY glGetBooleani_v (GLenum target, GLuint index, GLboolean *data); +GLAPI void APIENTRY glGetIntegeri_v (GLenum target, GLuint index, GLint *data); +GLAPI void APIENTRY glEnablei (GLenum target, GLuint index); +GLAPI void APIENTRY glDisablei (GLenum target, GLuint index); +GLAPI GLboolean APIENTRY glIsEnabledi (GLenum target, GLuint index); +GLAPI void APIENTRY glBeginTransformFeedback (GLenum primitiveMode); +GLAPI void APIENTRY glEndTransformFeedback (void); +GLAPI void APIENTRY glBindBufferRange (GLenum target, GLuint index, GLuint buffer, GLintptr offset, GLsizeiptr size); +GLAPI void APIENTRY glBindBufferBase (GLenum target, GLuint index, GLuint buffer); +GLAPI void APIENTRY glTransformFeedbackVaryings (GLuint program, GLsizei count, const GLchar* const *varyings, GLenum bufferMode); +GLAPI void APIENTRY glGetTransformFeedbackVarying (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLsizei *size, GLenum *type, GLchar *name); +GLAPI void APIENTRY glClampColor (GLenum target, GLenum clamp); +GLAPI void APIENTRY glBeginConditionalRender (GLuint id, GLenum mode); +GLAPI void APIENTRY glEndConditionalRender (void); +GLAPI void APIENTRY glVertexAttribIPointer (GLuint index, GLint size, GLenum type, GLsizei stride, const GLvoid *pointer); +GLAPI void APIENTRY glGetVertexAttribIiv (GLuint index, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetVertexAttribIuiv (GLuint index, GLenum pname, GLuint *params); +GLAPI void APIENTRY glVertexAttribI1i (GLuint index, GLint x); +GLAPI void APIENTRY glVertexAttribI2i (GLuint index, GLint x, GLint y); +GLAPI void APIENTRY glVertexAttribI3i (GLuint index, GLint x, GLint y, GLint z); +GLAPI void APIENTRY glVertexAttribI4i (GLuint index, GLint x, GLint y, GLint z, GLint w); +GLAPI void APIENTRY glVertexAttribI1ui (GLuint index, GLuint x); +GLAPI void APIENTRY glVertexAttribI2ui (GLuint index, GLuint x, GLuint y); +GLAPI void APIENTRY glVertexAttribI3ui (GLuint 
index, GLuint x, GLuint y, GLuint z); +GLAPI void APIENTRY glVertexAttribI4ui (GLuint index, GLuint x, GLuint y, GLuint z, GLuint w); +GLAPI void APIENTRY glVertexAttribI1iv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttribI2iv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttribI3iv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttribI4iv (GLuint index, const GLint *v); +GLAPI void APIENTRY glVertexAttribI1uiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttribI2uiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttribI3uiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttribI4uiv (GLuint index, const GLuint *v); +GLAPI void APIENTRY glVertexAttribI4bv (GLuint index, const GLbyte *v); +GLAPI void APIENTRY glVertexAttribI4sv (GLuint index, const GLshort *v); +GLAPI void APIENTRY glVertexAttribI4ubv (GLuint index, const GLubyte *v); +GLAPI void APIENTRY glVertexAttribI4usv (GLuint index, const GLushort *v); +GLAPI void APIENTRY glGetUniformuiv (GLuint program, GLint location, GLuint *params); +GLAPI void APIENTRY glBindFragDataLocation (GLuint program, GLuint color, const GLchar *name); +GLAPI GLint APIENTRY glGetFragDataLocation (GLuint program, const GLchar *name); +GLAPI void APIENTRY glUniform1ui (GLint location, GLuint v0); +GLAPI void APIENTRY glUniform2ui (GLint location, GLuint v0, GLuint v1); +GLAPI void APIENTRY glUniform3ui (GLint location, GLuint v0, GLuint v1, GLuint v2); +GLAPI void APIENTRY glUniform4ui (GLint location, GLuint v0, GLuint v1, GLuint v2, GLuint v3); +GLAPI void APIENTRY glUniform1uiv (GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glUniform2uiv (GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glUniform3uiv (GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glUniform4uiv (GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glTexParameterIiv (GLenum target, GLenum pname, const GLint *params); +GLAPI void APIENTRY glTexParameterIuiv (GLenum target, GLenum pname, const GLuint *params); +GLAPI void APIENTRY glGetTexParameterIiv (GLenum target, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetTexParameterIuiv (GLenum target, GLenum pname, GLuint *params); +GLAPI void APIENTRY glClearBufferiv (GLenum buffer, GLint drawbuffer, const GLint *value); +GLAPI void APIENTRY glClearBufferuiv (GLenum buffer, GLint drawbuffer, const GLuint *value); +GLAPI void APIENTRY glClearBufferfv (GLenum buffer, GLint drawbuffer, const GLfloat *value); +GLAPI void APIENTRY glClearBufferfi (GLenum buffer, GLint drawbuffer, GLfloat depth, GLint stencil); +GLAPI const GLubyte * APIENTRY glGetStringi (GLenum name, GLuint index); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLCOLORMASKIPROC) (GLuint index, GLboolean r, GLboolean g, GLboolean b, GLboolean a); +typedef void (APIENTRYP PFNGLGETBOOLEANI_VPROC) (GLenum target, GLuint index, GLboolean *data); +typedef void (APIENTRYP PFNGLGETINTEGERI_VPROC) (GLenum target, GLuint index, GLint *data); +typedef void (APIENTRYP PFNGLENABLEIPROC) (GLenum target, GLuint index); +typedef void (APIENTRYP PFNGLDISABLEIPROC) (GLenum target, GLuint index); +typedef GLboolean (APIENTRYP PFNGLISENABLEDIPROC) (GLenum target, GLuint index); +typedef void (APIENTRYP PFNGLBEGINTRANSFORMFEEDBACKPROC) (GLenum primitiveMode); +typedef void (APIENTRYP PFNGLENDTRANSFORMFEEDBACKPROC) (void); +typedef void (APIENTRYP PFNGLBINDBUFFERRANGEPROC) (GLenum 
target, GLuint index, GLuint buffer, GLintptr offset, GLsizeiptr size); +typedef void (APIENTRYP PFNGLBINDBUFFERBASEPROC) (GLenum target, GLuint index, GLuint buffer); +typedef void (APIENTRYP PFNGLTRANSFORMFEEDBACKVARYINGSPROC) (GLuint program, GLsizei count, const GLchar* const *varyings, GLenum bufferMode); +typedef void (APIENTRYP PFNGLGETTRANSFORMFEEDBACKVARYINGPROC) (GLuint program, GLuint index, GLsizei bufSize, GLsizei *length, GLsizei *size, GLenum *type, GLchar *name); +typedef void (APIENTRYP PFNGLCLAMPCOLORPROC) (GLenum target, GLenum clamp); +typedef void (APIENTRYP PFNGLBEGINCONDITIONALRENDERPROC) (GLuint id, GLenum mode); +typedef void (APIENTRYP PFNGLENDCONDITIONALRENDERPROC) (void); +typedef void (APIENTRYP PFNGLVERTEXATTRIBIPOINTERPROC) (GLuint index, GLint size, GLenum type, GLsizei stride, const GLvoid *pointer); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBIIVPROC) (GLuint index, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBIUIVPROC) (GLuint index, GLenum pname, GLuint *params); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI1IPROC) (GLuint index, GLint x); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI2IPROC) (GLuint index, GLint x, GLint y); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI3IPROC) (GLuint index, GLint x, GLint y, GLint z); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4IPROC) (GLuint index, GLint x, GLint y, GLint z, GLint w); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI1UIPROC) (GLuint index, GLuint x); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI2UIPROC) (GLuint index, GLuint x, GLuint y); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI3UIPROC) (GLuint index, GLuint x, GLuint y, GLuint z); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4UIPROC) (GLuint index, GLuint x, GLuint y, GLuint z, GLuint w); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI1IVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI2IVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI3IVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4IVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI1UIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI2UIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI3UIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4UIVPROC) (GLuint index, const GLuint *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4BVPROC) (GLuint index, const GLbyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4SVPROC) (GLuint index, const GLshort *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4UBVPROC) (GLuint index, const GLubyte *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBI4USVPROC) (GLuint index, const GLushort *v); +typedef void (APIENTRYP PFNGLGETUNIFORMUIVPROC) (GLuint program, GLint location, GLuint *params); +typedef void (APIENTRYP PFNGLBINDFRAGDATALOCATIONPROC) (GLuint program, GLuint color, const GLchar *name); +typedef GLint (APIENTRYP PFNGLGETFRAGDATALOCATIONPROC) (GLuint program, const GLchar *name); +typedef void (APIENTRYP PFNGLUNIFORM1UIPROC) (GLint location, GLuint v0); +typedef void (APIENTRYP PFNGLUNIFORM2UIPROC) (GLint location, GLuint v0, GLuint v1); +typedef void (APIENTRYP PFNGLUNIFORM3UIPROC) (GLint location, GLuint v0, GLuint v1, GLuint v2); +typedef void (APIENTRYP PFNGLUNIFORM4UIPROC) (GLint location, GLuint v0, GLuint v1, GLuint v2, GLuint v3); +typedef void (APIENTRYP PFNGLUNIFORM1UIVPROC) (GLint location, GLsizei count, const 
GLuint *value); +typedef void (APIENTRYP PFNGLUNIFORM2UIVPROC) (GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLUNIFORM3UIVPROC) (GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLUNIFORM4UIVPROC) (GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLTEXPARAMETERIIVPROC) (GLenum target, GLenum pname, const GLint *params); +typedef void (APIENTRYP PFNGLTEXPARAMETERIUIVPROC) (GLenum target, GLenum pname, const GLuint *params); +typedef void (APIENTRYP PFNGLGETTEXPARAMETERIIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETTEXPARAMETERIUIVPROC) (GLenum target, GLenum pname, GLuint *params); +typedef void (APIENTRYP PFNGLCLEARBUFFERIVPROC) (GLenum buffer, GLint drawbuffer, const GLint *value); +typedef void (APIENTRYP PFNGLCLEARBUFFERUIVPROC) (GLenum buffer, GLint drawbuffer, const GLuint *value); +typedef void (APIENTRYP PFNGLCLEARBUFFERFVPROC) (GLenum buffer, GLint drawbuffer, const GLfloat *value); +typedef void (APIENTRYP PFNGLCLEARBUFFERFIPROC) (GLenum buffer, GLint drawbuffer, GLfloat depth, GLint stencil); +typedef const GLubyte * (APIENTRYP PFNGLGETSTRINGIPROC) (GLenum name, GLuint index); +#endif + +#ifndef GL_VERSION_3_1 +#define GL_VERSION_3_1 1 +/* OpenGL 3.1 also reuses entry points from these extensions: */ +/* ARB_copy_buffer */ +/* ARB_uniform_buffer_object */ +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawArraysInstanced (GLenum mode, GLint first, GLsizei count, GLsizei instancecount); +GLAPI void APIENTRY glDrawElementsInstanced (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLsizei instancecount); +GLAPI void APIENTRY glTexBuffer (GLenum target, GLenum internalformat, GLuint buffer); +GLAPI void APIENTRY glPrimitiveRestartIndex (GLuint index); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWARRAYSINSTANCEDPROC) (GLenum mode, GLint first, GLsizei count, GLsizei instancecount); +typedef void (APIENTRYP PFNGLDRAWELEMENTSINSTANCEDPROC) (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLsizei instancecount); +typedef void (APIENTRYP PFNGLTEXBUFFERPROC) (GLenum target, GLenum internalformat, GLuint buffer); +typedef void (APIENTRYP PFNGLPRIMITIVERESTARTINDEXPROC) (GLuint index); +#endif + +#ifndef GL_VERSION_3_2 +#define GL_VERSION_3_2 1 +/* OpenGL 3.2 also reuses entry points from these extensions: */ +/* ARB_draw_elements_base_vertex */ +/* ARB_provoking_vertex */ +/* ARB_sync */ +/* ARB_texture_multisample */ +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetInteger64i_v (GLenum target, GLuint index, GLint64 *data); +GLAPI void APIENTRY glGetBufferParameteri64v (GLenum target, GLenum pname, GLint64 *params); +GLAPI void APIENTRY glFramebufferTexture (GLenum target, GLenum attachment, GLuint texture, GLint level); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETINTEGER64I_VPROC) (GLenum target, GLuint index, GLint64 *data); +typedef void (APIENTRYP PFNGLGETBUFFERPARAMETERI64VPROC) (GLenum target, GLenum pname, GLint64 *params); +typedef void (APIENTRYP PFNGLFRAMEBUFFERTEXTUREPROC) (GLenum target, GLenum attachment, GLuint texture, GLint level); +#endif + +#ifndef GL_VERSION_3_3 +#define GL_VERSION_3_3 1 +/* OpenGL 3.3 also reuses entry points from these extensions: */ +/* ARB_blend_func_extended */ +/* ARB_sampler_objects */ +/* ARB_explicit_attrib_location, but it has none */ +/* ARB_occlusion_query2 (no entry points) */ +/* ARB_shader_bit_encoding (no entry 
points) */ +/* ARB_texture_rgb10_a2ui (no entry points) */ +/* ARB_texture_swizzle (no entry points) */ +/* ARB_timer_query */ +/* ARB_vertex_type_2_10_10_10_rev */ +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glVertexAttribDivisor (GLuint index, GLuint divisor); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLVERTEXATTRIBDIVISORPROC) (GLuint index, GLuint divisor); +#endif + +#ifndef GL_VERSION_4_0 +#define GL_VERSION_4_0 1 +/* OpenGL 4.0 also reuses entry points from these extensions: */ +/* ARB_texture_query_lod (no entry points) */ +/* ARB_draw_indirect */ +/* ARB_gpu_shader5 (no entry points) */ +/* ARB_gpu_shader_fp64 */ +/* ARB_shader_subroutine */ +/* ARB_tessellation_shader */ +/* ARB_texture_buffer_object_rgb32 (no entry points) */ +/* ARB_texture_cube_map_array (no entry points) */ +/* ARB_texture_gather (no entry points) */ +/* ARB_transform_feedback2 */ +/* ARB_transform_feedback3 */ +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glMinSampleShading (GLfloat value); +GLAPI void APIENTRY glBlendEquationi (GLuint buf, GLenum mode); +GLAPI void APIENTRY glBlendEquationSeparatei (GLuint buf, GLenum modeRGB, GLenum modeAlpha); +GLAPI void APIENTRY glBlendFunci (GLuint buf, GLenum src, GLenum dst); +GLAPI void APIENTRY glBlendFuncSeparatei (GLuint buf, GLenum srcRGB, GLenum dstRGB, GLenum srcAlpha, GLenum dstAlpha); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLMINSAMPLESHADINGPROC) (GLfloat value); +typedef void (APIENTRYP PFNGLBLENDEQUATIONIPROC) (GLuint buf, GLenum mode); +typedef void (APIENTRYP PFNGLBLENDEQUATIONSEPARATEIPROC) (GLuint buf, GLenum modeRGB, GLenum modeAlpha); +typedef void (APIENTRYP PFNGLBLENDFUNCIPROC) (GLuint buf, GLenum src, GLenum dst); +typedef void (APIENTRYP PFNGLBLENDFUNCSEPARATEIPROC) (GLuint buf, GLenum srcRGB, GLenum dstRGB, GLenum srcAlpha, GLenum dstAlpha); +#endif + +#ifndef GL_VERSION_4_1 +#define GL_VERSION_4_1 1 +/* OpenGL 4.1 reuses entry points from these extensions: */ +/* ARB_ES2_compatibility */ +/* ARB_get_program_binary */ +/* ARB_separate_shader_objects */ +/* ARB_shader_precision (no entry points) */ +/* ARB_vertex_attrib_64bit */ +/* ARB_viewport_array */ +#endif + +#ifndef GL_VERSION_4_2 +#define GL_VERSION_4_2 1 +/* OpenGL 4.2 reuses entry points from these extensions: */ +/* ARB_base_instance */ +/* ARB_shading_language_420pack (no entry points) */ +/* ARB_transform_feedback_instanced */ +/* ARB_compressed_texture_pixel_storage (no entry points) */ +/* ARB_conservative_depth (no entry points) */ +/* ARB_internalformat_query */ +/* ARB_map_buffer_alignment (no entry points) */ +/* ARB_shader_atomic_counters */ +/* ARB_shader_image_load_store */ +/* ARB_shading_language_packing (no entry points) */ +/* ARB_texture_storage */ +#endif + +#ifndef GL_VERSION_4_3 +#define GL_VERSION_4_3 1 +/* OpenGL 4.3 reuses entry points from these extensions: */ +/* ARB_arrays_of_arrays (no entry points, GLSL only) */ +/* ARB_fragment_layer_viewport (no entry points, GLSL only) */ +/* ARB_shader_image_size (no entry points, GLSL only) */ +/* ARB_ES3_compatibility (no entry points) */ +/* ARB_clear_buffer_object */ +/* ARB_compute_shader */ +/* ARB_copy_image */ +/* KHR_debug (includes ARB_debug_output commands promoted to KHR without suffixes) */ +/* ARB_explicit_uniform_location (no entry points) */ +/* ARB_framebuffer_no_attachments */ +/* ARB_internalformat_query2 */ +/* ARB_invalidate_subdata */ +/* ARB_multi_draw_indirect */ +/* ARB_program_interface_query */ +/* ARB_robust_buffer_access_behavior (no entry 
points) */ +/* ARB_shader_storage_buffer_object */ +/* ARB_stencil_texturing (no entry points) */ +/* ARB_texture_buffer_range */ +/* ARB_texture_query_levels (no entry points) */ +/* ARB_texture_storage_multisample */ +/* ARB_texture_view */ +/* ARB_vertex_attrib_binding */ +#endif + +#ifndef GL_ARB_depth_buffer_float +#define GL_ARB_depth_buffer_float 1 +#endif + +#ifndef GL_ARB_framebuffer_object +#define GL_ARB_framebuffer_object 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLboolean APIENTRY glIsRenderbuffer (GLuint renderbuffer); +GLAPI void APIENTRY glBindRenderbuffer (GLenum target, GLuint renderbuffer); +GLAPI void APIENTRY glDeleteRenderbuffers (GLsizei n, const GLuint *renderbuffers); +GLAPI void APIENTRY glGenRenderbuffers (GLsizei n, GLuint *renderbuffers); +GLAPI void APIENTRY glRenderbufferStorage (GLenum target, GLenum internalformat, GLsizei width, GLsizei height); +GLAPI void APIENTRY glGetRenderbufferParameteriv (GLenum target, GLenum pname, GLint *params); +GLAPI GLboolean APIENTRY glIsFramebuffer (GLuint framebuffer); +GLAPI void APIENTRY glBindFramebuffer (GLenum target, GLuint framebuffer); +GLAPI void APIENTRY glDeleteFramebuffers (GLsizei n, const GLuint *framebuffers); +GLAPI void APIENTRY glGenFramebuffers (GLsizei n, GLuint *framebuffers); +GLAPI GLenum APIENTRY glCheckFramebufferStatus (GLenum target); +GLAPI void APIENTRY glFramebufferTexture1D (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level); +GLAPI void APIENTRY glFramebufferTexture2D (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level); +GLAPI void APIENTRY glFramebufferTexture3D (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level, GLint zoffset); +GLAPI void APIENTRY glFramebufferRenderbuffer (GLenum target, GLenum attachment, GLenum renderbuffertarget, GLuint renderbuffer); +GLAPI void APIENTRY glGetFramebufferAttachmentParameteriv (GLenum target, GLenum attachment, GLenum pname, GLint *params); +GLAPI void APIENTRY glGenerateMipmap (GLenum target); +GLAPI void APIENTRY glBlitFramebuffer (GLint srcX0, GLint srcY0, GLint srcX1, GLint srcY1, GLint dstX0, GLint dstY0, GLint dstX1, GLint dstY1, GLbitfield mask, GLenum filter); +GLAPI void APIENTRY glRenderbufferStorageMultisample (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height); +GLAPI void APIENTRY glFramebufferTextureLayer (GLenum target, GLenum attachment, GLuint texture, GLint level, GLint layer); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLboolean (APIENTRYP PFNGLISRENDERBUFFERPROC) (GLuint renderbuffer); +typedef void (APIENTRYP PFNGLBINDRENDERBUFFERPROC) (GLenum target, GLuint renderbuffer); +typedef void (APIENTRYP PFNGLDELETERENDERBUFFERSPROC) (GLsizei n, const GLuint *renderbuffers); +typedef void (APIENTRYP PFNGLGENRENDERBUFFERSPROC) (GLsizei n, GLuint *renderbuffers); +typedef void (APIENTRYP PFNGLRENDERBUFFERSTORAGEPROC) (GLenum target, GLenum internalformat, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLGETRENDERBUFFERPARAMETERIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef GLboolean (APIENTRYP PFNGLISFRAMEBUFFERPROC) (GLuint framebuffer); +typedef void (APIENTRYP PFNGLBINDFRAMEBUFFERPROC) (GLenum target, GLuint framebuffer); +typedef void (APIENTRYP PFNGLDELETEFRAMEBUFFERSPROC) (GLsizei n, const GLuint *framebuffers); +typedef void (APIENTRYP PFNGLGENFRAMEBUFFERSPROC) (GLsizei n, GLuint *framebuffers); +typedef GLenum (APIENTRYP PFNGLCHECKFRAMEBUFFERSTATUSPROC) (GLenum target); +typedef void 
(APIENTRYP PFNGLFRAMEBUFFERTEXTURE1DPROC) (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level); +typedef void (APIENTRYP PFNGLFRAMEBUFFERTEXTURE2DPROC) (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level); +typedef void (APIENTRYP PFNGLFRAMEBUFFERTEXTURE3DPROC) (GLenum target, GLenum attachment, GLenum textarget, GLuint texture, GLint level, GLint zoffset); +typedef void (APIENTRYP PFNGLFRAMEBUFFERRENDERBUFFERPROC) (GLenum target, GLenum attachment, GLenum renderbuffertarget, GLuint renderbuffer); +typedef void (APIENTRYP PFNGLGETFRAMEBUFFERATTACHMENTPARAMETERIVPROC) (GLenum target, GLenum attachment, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGENERATEMIPMAPPROC) (GLenum target); +typedef void (APIENTRYP PFNGLBLITFRAMEBUFFERPROC) (GLint srcX0, GLint srcY0, GLint srcX1, GLint srcY1, GLint dstX0, GLint dstY0, GLint dstX1, GLint dstY1, GLbitfield mask, GLenum filter); +typedef void (APIENTRYP PFNGLRENDERBUFFERSTORAGEMULTISAMPLEPROC) (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLFRAMEBUFFERTEXTURELAYERPROC) (GLenum target, GLenum attachment, GLuint texture, GLint level, GLint layer); +#endif + +#ifndef GL_ARB_framebuffer_sRGB +#define GL_ARB_framebuffer_sRGB 1 +#endif + +#ifndef GL_ARB_half_float_vertex +#define GL_ARB_half_float_vertex 1 +#endif + +#ifndef GL_ARB_map_buffer_range +#define GL_ARB_map_buffer_range 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLvoid* APIENTRY glMapBufferRange (GLenum target, GLintptr offset, GLsizeiptr length, GLbitfield access); +GLAPI void APIENTRY glFlushMappedBufferRange (GLenum target, GLintptr offset, GLsizeiptr length); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLvoid* (APIENTRYP PFNGLMAPBUFFERRANGEPROC) (GLenum target, GLintptr offset, GLsizeiptr length, GLbitfield access); +typedef void (APIENTRYP PFNGLFLUSHMAPPEDBUFFERRANGEPROC) (GLenum target, GLintptr offset, GLsizeiptr length); +#endif + +#ifndef GL_ARB_texture_compression_rgtc +#define GL_ARB_texture_compression_rgtc 1 +#endif + +#ifndef GL_ARB_texture_rg +#define GL_ARB_texture_rg 1 +#endif + +#ifndef GL_ARB_vertex_array_object +#define GL_ARB_vertex_array_object 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBindVertexArray (GLuint array); +GLAPI void APIENTRY glDeleteVertexArrays (GLsizei n, const GLuint *arrays); +GLAPI void APIENTRY glGenVertexArrays (GLsizei n, GLuint *arrays); +GLAPI GLboolean APIENTRY glIsVertexArray (GLuint array); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBINDVERTEXARRAYPROC) (GLuint array); +typedef void (APIENTRYP PFNGLDELETEVERTEXARRAYSPROC) (GLsizei n, const GLuint *arrays); +typedef void (APIENTRYP PFNGLGENVERTEXARRAYSPROC) (GLsizei n, GLuint *arrays); +typedef GLboolean (APIENTRYP PFNGLISVERTEXARRAYPROC) (GLuint array); +#endif + +#ifndef GL_ARB_uniform_buffer_object +#define GL_ARB_uniform_buffer_object 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetUniformIndices (GLuint program, GLsizei uniformCount, const GLchar* const *uniformNames, GLuint *uniformIndices); +GLAPI void APIENTRY glGetActiveUniformsiv (GLuint program, GLsizei uniformCount, const GLuint *uniformIndices, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetActiveUniformName (GLuint program, GLuint uniformIndex, GLsizei bufSize, GLsizei *length, GLchar *uniformName); +GLAPI GLuint APIENTRY glGetUniformBlockIndex (GLuint program, const GLchar *uniformBlockName); +GLAPI void APIENTRY glGetActiveUniformBlockiv (GLuint 
program, GLuint uniformBlockIndex, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetActiveUniformBlockName (GLuint program, GLuint uniformBlockIndex, GLsizei bufSize, GLsizei *length, GLchar *uniformBlockName); +GLAPI void APIENTRY glUniformBlockBinding (GLuint program, GLuint uniformBlockIndex, GLuint uniformBlockBinding); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETUNIFORMINDICESPROC) (GLuint program, GLsizei uniformCount, const GLchar* const *uniformNames, GLuint *uniformIndices); +typedef void (APIENTRYP PFNGLGETACTIVEUNIFORMSIVPROC) (GLuint program, GLsizei uniformCount, const GLuint *uniformIndices, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETACTIVEUNIFORMNAMEPROC) (GLuint program, GLuint uniformIndex, GLsizei bufSize, GLsizei *length, GLchar *uniformName); +typedef GLuint (APIENTRYP PFNGLGETUNIFORMBLOCKINDEXPROC) (GLuint program, const GLchar *uniformBlockName); +typedef void (APIENTRYP PFNGLGETACTIVEUNIFORMBLOCKIVPROC) (GLuint program, GLuint uniformBlockIndex, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETACTIVEUNIFORMBLOCKNAMEPROC) (GLuint program, GLuint uniformBlockIndex, GLsizei bufSize, GLsizei *length, GLchar *uniformBlockName); +typedef void (APIENTRYP PFNGLUNIFORMBLOCKBINDINGPROC) (GLuint program, GLuint uniformBlockIndex, GLuint uniformBlockBinding); +#endif + +#ifndef GL_ARB_copy_buffer +#define GL_ARB_copy_buffer 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glCopyBufferSubData (GLenum readTarget, GLenum writeTarget, GLintptr readOffset, GLintptr writeOffset, GLsizeiptr size); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLCOPYBUFFERSUBDATAPROC) (GLenum readTarget, GLenum writeTarget, GLintptr readOffset, GLintptr writeOffset, GLsizeiptr size); +#endif + +#ifndef GL_ARB_depth_clamp +#define GL_ARB_depth_clamp 1 +#endif + +#ifndef GL_ARB_draw_elements_base_vertex +#define GL_ARB_draw_elements_base_vertex 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawElementsBaseVertex (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLint basevertex); +GLAPI void APIENTRY glDrawRangeElementsBaseVertex (GLenum mode, GLuint start, GLuint end, GLsizei count, GLenum type, const GLvoid *indices, GLint basevertex); +GLAPI void APIENTRY glDrawElementsInstancedBaseVertex (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLsizei instancecount, GLint basevertex); +GLAPI void APIENTRY glMultiDrawElementsBaseVertex (GLenum mode, const GLsizei *count, GLenum type, const GLvoid* const *indices, GLsizei drawcount, const GLint *basevertex); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWELEMENTSBASEVERTEXPROC) (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLint basevertex); +typedef void (APIENTRYP PFNGLDRAWRANGEELEMENTSBASEVERTEXPROC) (GLenum mode, GLuint start, GLuint end, GLsizei count, GLenum type, const GLvoid *indices, GLint basevertex); +typedef void (APIENTRYP PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXPROC) (GLenum mode, GLsizei count, GLenum type, const GLvoid *indices, GLsizei instancecount, GLint basevertex); +typedef void (APIENTRYP PFNGLMULTIDRAWELEMENTSBASEVERTEXPROC) (GLenum mode, const GLsizei *count, GLenum type, const GLvoid* const *indices, GLsizei drawcount, const GLint *basevertex); +#endif + +#ifndef GL_ARB_fragment_coord_conventions +#define GL_ARB_fragment_coord_conventions 1 +#endif + +#ifndef GL_ARB_provoking_vertex +#define GL_ARB_provoking_vertex 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY 
glProvokingVertex (GLenum mode); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLPROVOKINGVERTEXPROC) (GLenum mode); +#endif + +#ifndef GL_ARB_seamless_cube_map +#define GL_ARB_seamless_cube_map 1 +#endif + +#ifndef GL_ARB_sync +#define GL_ARB_sync 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLsync APIENTRY glFenceSync (GLenum condition, GLbitfield flags); +GLAPI GLboolean APIENTRY glIsSync (GLsync sync); +GLAPI void APIENTRY glDeleteSync (GLsync sync); +GLAPI GLenum APIENTRY glClientWaitSync (GLsync sync, GLbitfield flags, GLuint64 timeout); +GLAPI void APIENTRY glWaitSync (GLsync sync, GLbitfield flags, GLuint64 timeout); +GLAPI void APIENTRY glGetInteger64v (GLenum pname, GLint64 *params); +GLAPI void APIENTRY glGetSynciv (GLsync sync, GLenum pname, GLsizei bufSize, GLsizei *length, GLint *values); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLsync (APIENTRYP PFNGLFENCESYNCPROC) (GLenum condition, GLbitfield flags); +typedef GLboolean (APIENTRYP PFNGLISSYNCPROC) (GLsync sync); +typedef void (APIENTRYP PFNGLDELETESYNCPROC) (GLsync sync); +typedef GLenum (APIENTRYP PFNGLCLIENTWAITSYNCPROC) (GLsync sync, GLbitfield flags, GLuint64 timeout); +typedef void (APIENTRYP PFNGLWAITSYNCPROC) (GLsync sync, GLbitfield flags, GLuint64 timeout); +typedef void (APIENTRYP PFNGLGETINTEGER64VPROC) (GLenum pname, GLint64 *params); +typedef void (APIENTRYP PFNGLGETSYNCIVPROC) (GLsync sync, GLenum pname, GLsizei bufSize, GLsizei *length, GLint *values); +#endif + +#ifndef GL_ARB_texture_multisample +#define GL_ARB_texture_multisample 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glTexImage2DMultisample (GLenum target, GLsizei samples, GLint internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +GLAPI void APIENTRY glTexImage3DMultisample (GLenum target, GLsizei samples, GLint internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +GLAPI void APIENTRY glGetMultisamplefv (GLenum pname, GLuint index, GLfloat *val); +GLAPI void APIENTRY glSampleMaski (GLuint index, GLbitfield mask); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLTEXIMAGE2DMULTISAMPLEPROC) (GLenum target, GLsizei samples, GLint internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +typedef void (APIENTRYP PFNGLTEXIMAGE3DMULTISAMPLEPROC) (GLenum target, GLsizei samples, GLint internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +typedef void (APIENTRYP PFNGLGETMULTISAMPLEFVPROC) (GLenum pname, GLuint index, GLfloat *val); +typedef void (APIENTRYP PFNGLSAMPLEMASKIPROC) (GLuint index, GLbitfield mask); +#endif + +#ifndef GL_ARB_vertex_array_bgra +#define GL_ARB_vertex_array_bgra 1 +#endif + +#ifndef GL_ARB_draw_buffers_blend +#define GL_ARB_draw_buffers_blend 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBlendEquationiARB (GLuint buf, GLenum mode); +GLAPI void APIENTRY glBlendEquationSeparateiARB (GLuint buf, GLenum modeRGB, GLenum modeAlpha); +GLAPI void APIENTRY glBlendFunciARB (GLuint buf, GLenum src, GLenum dst); +GLAPI void APIENTRY glBlendFuncSeparateiARB (GLuint buf, GLenum srcRGB, GLenum dstRGB, GLenum srcAlpha, GLenum dstAlpha); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBLENDEQUATIONIARBPROC) (GLuint buf, GLenum mode); +typedef void (APIENTRYP PFNGLBLENDEQUATIONSEPARATEIARBPROC) (GLuint buf, GLenum modeRGB, GLenum modeAlpha); +typedef void (APIENTRYP PFNGLBLENDFUNCIARBPROC) (GLuint buf, GLenum src, GLenum dst); +typedef void (APIENTRYP 
PFNGLBLENDFUNCSEPARATEIARBPROC) (GLuint buf, GLenum srcRGB, GLenum dstRGB, GLenum srcAlpha, GLenum dstAlpha); +#endif + +#ifndef GL_ARB_sample_shading +#define GL_ARB_sample_shading 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glMinSampleShadingARB (GLfloat value); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLMINSAMPLESHADINGARBPROC) (GLfloat value); +#endif + +#ifndef GL_ARB_texture_cube_map_array +#define GL_ARB_texture_cube_map_array 1 +#endif + +#ifndef GL_ARB_texture_gather +#define GL_ARB_texture_gather 1 +#endif + +#ifndef GL_ARB_texture_query_lod +#define GL_ARB_texture_query_lod 1 +#endif + +#ifndef GL_ARB_shading_language_include +#define GL_ARB_shading_language_include 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glNamedStringARB (GLenum type, GLint namelen, const GLchar *name, GLint stringlen, const GLchar *string); +GLAPI void APIENTRY glDeleteNamedStringARB (GLint namelen, const GLchar *name); +GLAPI void APIENTRY glCompileShaderIncludeARB (GLuint shader, GLsizei count, const GLchar* *path, const GLint *length); +GLAPI GLboolean APIENTRY glIsNamedStringARB (GLint namelen, const GLchar *name); +GLAPI void APIENTRY glGetNamedStringARB (GLint namelen, const GLchar *name, GLsizei bufSize, GLint *stringlen, GLchar *string); +GLAPI void APIENTRY glGetNamedStringivARB (GLint namelen, const GLchar *name, GLenum pname, GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLNAMEDSTRINGARBPROC) (GLenum type, GLint namelen, const GLchar *name, GLint stringlen, const GLchar *string); +typedef void (APIENTRYP PFNGLDELETENAMEDSTRINGARBPROC) (GLint namelen, const GLchar *name); +typedef void (APIENTRYP PFNGLCOMPILESHADERINCLUDEARBPROC) (GLuint shader, GLsizei count, const GLchar* *path, const GLint *length); +typedef GLboolean (APIENTRYP PFNGLISNAMEDSTRINGARBPROC) (GLint namelen, const GLchar *name); +typedef void (APIENTRYP PFNGLGETNAMEDSTRINGARBPROC) (GLint namelen, const GLchar *name, GLsizei bufSize, GLint *stringlen, GLchar *string); +typedef void (APIENTRYP PFNGLGETNAMEDSTRINGIVARBPROC) (GLint namelen, const GLchar *name, GLenum pname, GLint *params); +#endif + +#ifndef GL_ARB_texture_compression_bptc +#define GL_ARB_texture_compression_bptc 1 +#endif + +#ifndef GL_ARB_blend_func_extended +#define GL_ARB_blend_func_extended 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBindFragDataLocationIndexed (GLuint program, GLuint colorNumber, GLuint index, const GLchar *name); +GLAPI GLint APIENTRY glGetFragDataIndex (GLuint program, const GLchar *name); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBINDFRAGDATALOCATIONINDEXEDPROC) (GLuint program, GLuint colorNumber, GLuint index, const GLchar *name); +typedef GLint (APIENTRYP PFNGLGETFRAGDATAINDEXPROC) (GLuint program, const GLchar *name); +#endif + +#ifndef GL_ARB_explicit_attrib_location +#define GL_ARB_explicit_attrib_location 1 +#endif + +#ifndef GL_ARB_occlusion_query2 +#define GL_ARB_occlusion_query2 1 +#endif + +#ifndef GL_ARB_sampler_objects +#define GL_ARB_sampler_objects 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGenSamplers (GLsizei count, GLuint *samplers); +GLAPI void APIENTRY glDeleteSamplers (GLsizei count, const GLuint *samplers); +GLAPI GLboolean APIENTRY glIsSampler (GLuint sampler); +GLAPI void APIENTRY glBindSampler (GLuint unit, GLuint sampler); +GLAPI void APIENTRY glSamplerParameteri (GLuint sampler, GLenum pname, GLint param); +GLAPI void APIENTRY glSamplerParameteriv (GLuint sampler, GLenum pname, const GLint *param); 
+GLAPI void APIENTRY glSamplerParameterf (GLuint sampler, GLenum pname, GLfloat param); +GLAPI void APIENTRY glSamplerParameterfv (GLuint sampler, GLenum pname, const GLfloat *param); +GLAPI void APIENTRY glSamplerParameterIiv (GLuint sampler, GLenum pname, const GLint *param); +GLAPI void APIENTRY glSamplerParameterIuiv (GLuint sampler, GLenum pname, const GLuint *param); +GLAPI void APIENTRY glGetSamplerParameteriv (GLuint sampler, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetSamplerParameterIiv (GLuint sampler, GLenum pname, GLint *params); +GLAPI void APIENTRY glGetSamplerParameterfv (GLuint sampler, GLenum pname, GLfloat *params); +GLAPI void APIENTRY glGetSamplerParameterIuiv (GLuint sampler, GLenum pname, GLuint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGENSAMPLERSPROC) (GLsizei count, GLuint *samplers); +typedef void (APIENTRYP PFNGLDELETESAMPLERSPROC) (GLsizei count, const GLuint *samplers); +typedef GLboolean (APIENTRYP PFNGLISSAMPLERPROC) (GLuint sampler); +typedef void (APIENTRYP PFNGLBINDSAMPLERPROC) (GLuint unit, GLuint sampler); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERIPROC) (GLuint sampler, GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERIVPROC) (GLuint sampler, GLenum pname, const GLint *param); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERFPROC) (GLuint sampler, GLenum pname, GLfloat param); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERFVPROC) (GLuint sampler, GLenum pname, const GLfloat *param); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERIIVPROC) (GLuint sampler, GLenum pname, const GLint *param); +typedef void (APIENTRYP PFNGLSAMPLERPARAMETERIUIVPROC) (GLuint sampler, GLenum pname, const GLuint *param); +typedef void (APIENTRYP PFNGLGETSAMPLERPARAMETERIVPROC) (GLuint sampler, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETSAMPLERPARAMETERIIVPROC) (GLuint sampler, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLGETSAMPLERPARAMETERFVPROC) (GLuint sampler, GLenum pname, GLfloat *params); +typedef void (APIENTRYP PFNGLGETSAMPLERPARAMETERIUIVPROC) (GLuint sampler, GLenum pname, GLuint *params); +#endif + +#ifndef GL_ARB_shader_bit_encoding +#define GL_ARB_shader_bit_encoding 1 +#endif + +#ifndef GL_ARB_texture_rgb10_a2ui +#define GL_ARB_texture_rgb10_a2ui 1 +#endif + +#ifndef GL_ARB_texture_swizzle +#define GL_ARB_texture_swizzle 1 +#endif + +#ifndef GL_ARB_timer_query +#define GL_ARB_timer_query 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glQueryCounter (GLuint id, GLenum target); +GLAPI void APIENTRY glGetQueryObjecti64v (GLuint id, GLenum pname, GLint64 *params); +GLAPI void APIENTRY glGetQueryObjectui64v (GLuint id, GLenum pname, GLuint64 *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLQUERYCOUNTERPROC) (GLuint id, GLenum target); +typedef void (APIENTRYP PFNGLGETQUERYOBJECTI64VPROC) (GLuint id, GLenum pname, GLint64 *params); +typedef void (APIENTRYP PFNGLGETQUERYOBJECTUI64VPROC) (GLuint id, GLenum pname, GLuint64 *params); +#endif + +#ifndef GL_ARB_vertex_type_2_10_10_10_rev +#define GL_ARB_vertex_type_2_10_10_10_rev 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glVertexP2ui (GLenum type, GLuint value); +GLAPI void APIENTRY glVertexP2uiv (GLenum type, const GLuint *value); +GLAPI void APIENTRY glVertexP3ui (GLenum type, GLuint value); +GLAPI void APIENTRY glVertexP3uiv (GLenum type, const GLuint *value); +GLAPI void APIENTRY glVertexP4ui (GLenum type, GLuint value); +GLAPI void APIENTRY glVertexP4uiv (GLenum type, const 
GLuint *value); +GLAPI void APIENTRY glTexCoordP1ui (GLenum type, GLuint coords); +GLAPI void APIENTRY glTexCoordP1uiv (GLenum type, const GLuint *coords); +GLAPI void APIENTRY glTexCoordP2ui (GLenum type, GLuint coords); +GLAPI void APIENTRY glTexCoordP2uiv (GLenum type, const GLuint *coords); +GLAPI void APIENTRY glTexCoordP3ui (GLenum type, GLuint coords); +GLAPI void APIENTRY glTexCoordP3uiv (GLenum type, const GLuint *coords); +GLAPI void APIENTRY glTexCoordP4ui (GLenum type, GLuint coords); +GLAPI void APIENTRY glTexCoordP4uiv (GLenum type, const GLuint *coords); +GLAPI void APIENTRY glMultiTexCoordP1ui (GLenum texture, GLenum type, GLuint coords); +GLAPI void APIENTRY glMultiTexCoordP1uiv (GLenum texture, GLenum type, const GLuint *coords); +GLAPI void APIENTRY glMultiTexCoordP2ui (GLenum texture, GLenum type, GLuint coords); +GLAPI void APIENTRY glMultiTexCoordP2uiv (GLenum texture, GLenum type, const GLuint *coords); +GLAPI void APIENTRY glMultiTexCoordP3ui (GLenum texture, GLenum type, GLuint coords); +GLAPI void APIENTRY glMultiTexCoordP3uiv (GLenum texture, GLenum type, const GLuint *coords); +GLAPI void APIENTRY glMultiTexCoordP4ui (GLenum texture, GLenum type, GLuint coords); +GLAPI void APIENTRY glMultiTexCoordP4uiv (GLenum texture, GLenum type, const GLuint *coords); +GLAPI void APIENTRY glNormalP3ui (GLenum type, GLuint coords); +GLAPI void APIENTRY glNormalP3uiv (GLenum type, const GLuint *coords); +GLAPI void APIENTRY glColorP3ui (GLenum type, GLuint color); +GLAPI void APIENTRY glColorP3uiv (GLenum type, const GLuint *color); +GLAPI void APIENTRY glColorP4ui (GLenum type, GLuint color); +GLAPI void APIENTRY glColorP4uiv (GLenum type, const GLuint *color); +GLAPI void APIENTRY glSecondaryColorP3ui (GLenum type, GLuint color); +GLAPI void APIENTRY glSecondaryColorP3uiv (GLenum type, const GLuint *color); +GLAPI void APIENTRY glVertexAttribP1ui (GLuint index, GLenum type, GLboolean normalized, GLuint value); +GLAPI void APIENTRY glVertexAttribP1uiv (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +GLAPI void APIENTRY glVertexAttribP2ui (GLuint index, GLenum type, GLboolean normalized, GLuint value); +GLAPI void APIENTRY glVertexAttribP2uiv (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +GLAPI void APIENTRY glVertexAttribP3ui (GLuint index, GLenum type, GLboolean normalized, GLuint value); +GLAPI void APIENTRY glVertexAttribP3uiv (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +GLAPI void APIENTRY glVertexAttribP4ui (GLuint index, GLenum type, GLboolean normalized, GLuint value); +GLAPI void APIENTRY glVertexAttribP4uiv (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLVERTEXP2UIPROC) (GLenum type, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXP2UIVPROC) (GLenum type, const GLuint *value); +typedef void (APIENTRYP PFNGLVERTEXP3UIPROC) (GLenum type, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXP3UIVPROC) (GLenum type, const GLuint *value); +typedef void (APIENTRYP PFNGLVERTEXP4UIPROC) (GLenum type, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXP4UIVPROC) (GLenum type, const GLuint *value); +typedef void (APIENTRYP PFNGLTEXCOORDP1UIPROC) (GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLTEXCOORDP1UIVPROC) (GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLTEXCOORDP2UIPROC) (GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLTEXCOORDP2UIVPROC) (GLenum type, const GLuint 
*coords); +typedef void (APIENTRYP PFNGLTEXCOORDP3UIPROC) (GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLTEXCOORDP3UIVPROC) (GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLTEXCOORDP4UIPROC) (GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLTEXCOORDP4UIVPROC) (GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP1UIPROC) (GLenum texture, GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP1UIVPROC) (GLenum texture, GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP2UIPROC) (GLenum texture, GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP2UIVPROC) (GLenum texture, GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP3UIPROC) (GLenum texture, GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP3UIVPROC) (GLenum texture, GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP4UIPROC) (GLenum texture, GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLMULTITEXCOORDP4UIVPROC) (GLenum texture, GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLNORMALP3UIPROC) (GLenum type, GLuint coords); +typedef void (APIENTRYP PFNGLNORMALP3UIVPROC) (GLenum type, const GLuint *coords); +typedef void (APIENTRYP PFNGLCOLORP3UIPROC) (GLenum type, GLuint color); +typedef void (APIENTRYP PFNGLCOLORP3UIVPROC) (GLenum type, const GLuint *color); +typedef void (APIENTRYP PFNGLCOLORP4UIPROC) (GLenum type, GLuint color); +typedef void (APIENTRYP PFNGLCOLORP4UIVPROC) (GLenum type, const GLuint *color); +typedef void (APIENTRYP PFNGLSECONDARYCOLORP3UIPROC) (GLenum type, GLuint color); +typedef void (APIENTRYP PFNGLSECONDARYCOLORP3UIVPROC) (GLenum type, const GLuint *color); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP1UIPROC) (GLuint index, GLenum type, GLboolean normalized, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP1UIVPROC) (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP2UIPROC) (GLuint index, GLenum type, GLboolean normalized, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP2UIVPROC) (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP3UIPROC) (GLuint index, GLenum type, GLboolean normalized, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP3UIVPROC) (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP4UIPROC) (GLuint index, GLenum type, GLboolean normalized, GLuint value); +typedef void (APIENTRYP PFNGLVERTEXATTRIBP4UIVPROC) (GLuint index, GLenum type, GLboolean normalized, const GLuint *value); +#endif + +#ifndef GL_ARB_draw_indirect +#define GL_ARB_draw_indirect 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawArraysIndirect (GLenum mode, const GLvoid *indirect); +GLAPI void APIENTRY glDrawElementsIndirect (GLenum mode, GLenum type, const GLvoid *indirect); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWARRAYSINDIRECTPROC) (GLenum mode, const GLvoid *indirect); +typedef void (APIENTRYP PFNGLDRAWELEMENTSINDIRECTPROC) (GLenum mode, GLenum type, const GLvoid *indirect); +#endif + +#ifndef GL_ARB_gpu_shader5 +#define GL_ARB_gpu_shader5 1 +#endif + +#ifndef GL_ARB_gpu_shader_fp64 +#define GL_ARB_gpu_shader_fp64 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glUniform1d (GLint location, GLdouble x); +GLAPI void APIENTRY 
glUniform2d (GLint location, GLdouble x, GLdouble y); +GLAPI void APIENTRY glUniform3d (GLint location, GLdouble x, GLdouble y, GLdouble z); +GLAPI void APIENTRY glUniform4d (GLint location, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +GLAPI void APIENTRY glUniform1dv (GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glUniform2dv (GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glUniform3dv (GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glUniform4dv (GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix2dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix3dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix4dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix2x3dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix2x4dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix3x2dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix3x4dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix4x2dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glUniformMatrix4x3dv (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glGetUniformdv (GLuint program, GLint location, GLdouble *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLUNIFORM1DPROC) (GLint location, GLdouble x); +typedef void (APIENTRYP PFNGLUNIFORM2DPROC) (GLint location, GLdouble x, GLdouble y); +typedef void (APIENTRYP PFNGLUNIFORM3DPROC) (GLint location, GLdouble x, GLdouble y, GLdouble z); +typedef void (APIENTRYP PFNGLUNIFORM4DPROC) (GLint location, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +typedef void (APIENTRYP PFNGLUNIFORM1DVPROC) (GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORM2DVPROC) (GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORM3DVPROC) (GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORM4DVPROC) (GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2X3DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX2X4DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3X2DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX3X4DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4X2DVPROC) (GLint location, 
GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLUNIFORMMATRIX4X3DVPROC) (GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLGETUNIFORMDVPROC) (GLuint program, GLint location, GLdouble *params); +#endif + +#ifndef GL_ARB_shader_subroutine +#define GL_ARB_shader_subroutine 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLint APIENTRY glGetSubroutineUniformLocation (GLuint program, GLenum shadertype, const GLchar *name); +GLAPI GLuint APIENTRY glGetSubroutineIndex (GLuint program, GLenum shadertype, const GLchar *name); +GLAPI void APIENTRY glGetActiveSubroutineUniformiv (GLuint program, GLenum shadertype, GLuint index, GLenum pname, GLint *values); +GLAPI void APIENTRY glGetActiveSubroutineUniformName (GLuint program, GLenum shadertype, GLuint index, GLsizei bufsize, GLsizei *length, GLchar *name); +GLAPI void APIENTRY glGetActiveSubroutineName (GLuint program, GLenum shadertype, GLuint index, GLsizei bufsize, GLsizei *length, GLchar *name); +GLAPI void APIENTRY glUniformSubroutinesuiv (GLenum shadertype, GLsizei count, const GLuint *indices); +GLAPI void APIENTRY glGetUniformSubroutineuiv (GLenum shadertype, GLint location, GLuint *params); +GLAPI void APIENTRY glGetProgramStageiv (GLuint program, GLenum shadertype, GLenum pname, GLint *values); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLint (APIENTRYP PFNGLGETSUBROUTINEUNIFORMLOCATIONPROC) (GLuint program, GLenum shadertype, const GLchar *name); +typedef GLuint (APIENTRYP PFNGLGETSUBROUTINEINDEXPROC) (GLuint program, GLenum shadertype, const GLchar *name); +typedef void (APIENTRYP PFNGLGETACTIVESUBROUTINEUNIFORMIVPROC) (GLuint program, GLenum shadertype, GLuint index, GLenum pname, GLint *values); +typedef void (APIENTRYP PFNGLGETACTIVESUBROUTINEUNIFORMNAMEPROC) (GLuint program, GLenum shadertype, GLuint index, GLsizei bufsize, GLsizei *length, GLchar *name); +typedef void (APIENTRYP PFNGLGETACTIVESUBROUTINENAMEPROC) (GLuint program, GLenum shadertype, GLuint index, GLsizei bufsize, GLsizei *length, GLchar *name); +typedef void (APIENTRYP PFNGLUNIFORMSUBROUTINESUIVPROC) (GLenum shadertype, GLsizei count, const GLuint *indices); +typedef void (APIENTRYP PFNGLGETUNIFORMSUBROUTINEUIVPROC) (GLenum shadertype, GLint location, GLuint *params); +typedef void (APIENTRYP PFNGLGETPROGRAMSTAGEIVPROC) (GLuint program, GLenum shadertype, GLenum pname, GLint *values); +#endif + +#ifndef GL_ARB_tessellation_shader +#define GL_ARB_tessellation_shader 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glPatchParameteri (GLenum pname, GLint value); +GLAPI void APIENTRY glPatchParameterfv (GLenum pname, const GLfloat *values); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLPATCHPARAMETERIPROC) (GLenum pname, GLint value); +typedef void (APIENTRYP PFNGLPATCHPARAMETERFVPROC) (GLenum pname, const GLfloat *values); +#endif + +#ifndef GL_ARB_texture_buffer_object_rgb32 +#define GL_ARB_texture_buffer_object_rgb32 1 +#endif + +#ifndef GL_ARB_transform_feedback2 +#define GL_ARB_transform_feedback2 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBindTransformFeedback (GLenum target, GLuint id); +GLAPI void APIENTRY glDeleteTransformFeedbacks (GLsizei n, const GLuint *ids); +GLAPI void APIENTRY glGenTransformFeedbacks (GLsizei n, GLuint *ids); +GLAPI GLboolean APIENTRY glIsTransformFeedback (GLuint id); +GLAPI void APIENTRY glPauseTransformFeedback (void); +GLAPI void APIENTRY glResumeTransformFeedback (void); +GLAPI void APIENTRY glDrawTransformFeedback 
(GLenum mode, GLuint id); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBINDTRANSFORMFEEDBACKPROC) (GLenum target, GLuint id); +typedef void (APIENTRYP PFNGLDELETETRANSFORMFEEDBACKSPROC) (GLsizei n, const GLuint *ids); +typedef void (APIENTRYP PFNGLGENTRANSFORMFEEDBACKSPROC) (GLsizei n, GLuint *ids); +typedef GLboolean (APIENTRYP PFNGLISTRANSFORMFEEDBACKPROC) (GLuint id); +typedef void (APIENTRYP PFNGLPAUSETRANSFORMFEEDBACKPROC) (void); +typedef void (APIENTRYP PFNGLRESUMETRANSFORMFEEDBACKPROC) (void); +typedef void (APIENTRYP PFNGLDRAWTRANSFORMFEEDBACKPROC) (GLenum mode, GLuint id); +#endif + +#ifndef GL_ARB_transform_feedback3 +#define GL_ARB_transform_feedback3 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawTransformFeedbackStream (GLenum mode, GLuint id, GLuint stream); +GLAPI void APIENTRY glBeginQueryIndexed (GLenum target, GLuint index, GLuint id); +GLAPI void APIENTRY glEndQueryIndexed (GLenum target, GLuint index); +GLAPI void APIENTRY glGetQueryIndexediv (GLenum target, GLuint index, GLenum pname, GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWTRANSFORMFEEDBACKSTREAMPROC) (GLenum mode, GLuint id, GLuint stream); +typedef void (APIENTRYP PFNGLBEGINQUERYINDEXEDPROC) (GLenum target, GLuint index, GLuint id); +typedef void (APIENTRYP PFNGLENDQUERYINDEXEDPROC) (GLenum target, GLuint index); +typedef void (APIENTRYP PFNGLGETQUERYINDEXEDIVPROC) (GLenum target, GLuint index, GLenum pname, GLint *params); +#endif + +#ifndef GL_ARB_ES2_compatibility +#define GL_ARB_ES2_compatibility 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glReleaseShaderCompiler (void); +GLAPI void APIENTRY glShaderBinary (GLsizei count, const GLuint *shaders, GLenum binaryformat, const GLvoid *binary, GLsizei length); +GLAPI void APIENTRY glGetShaderPrecisionFormat (GLenum shadertype, GLenum precisiontype, GLint *range, GLint *precision); +GLAPI void APIENTRY glDepthRangef (GLfloat n, GLfloat f); +GLAPI void APIENTRY glClearDepthf (GLfloat d); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLRELEASESHADERCOMPILERPROC) (void); +typedef void (APIENTRYP PFNGLSHADERBINARYPROC) (GLsizei count, const GLuint *shaders, GLenum binaryformat, const GLvoid *binary, GLsizei length); +typedef void (APIENTRYP PFNGLGETSHADERPRECISIONFORMATPROC) (GLenum shadertype, GLenum precisiontype, GLint *range, GLint *precision); +typedef void (APIENTRYP PFNGLDEPTHRANGEFPROC) (GLfloat n, GLfloat f); +typedef void (APIENTRYP PFNGLCLEARDEPTHFPROC) (GLfloat d); +#endif + +#ifndef GL_ARB_get_program_binary +#define GL_ARB_get_program_binary 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetProgramBinary (GLuint program, GLsizei bufSize, GLsizei *length, GLenum *binaryFormat, GLvoid *binary); +GLAPI void APIENTRY glProgramBinary (GLuint program, GLenum binaryFormat, const GLvoid *binary, GLsizei length); +GLAPI void APIENTRY glProgramParameteri (GLuint program, GLenum pname, GLint value); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETPROGRAMBINARYPROC) (GLuint program, GLsizei bufSize, GLsizei *length, GLenum *binaryFormat, GLvoid *binary); +typedef void (APIENTRYP PFNGLPROGRAMBINARYPROC) (GLuint program, GLenum binaryFormat, const GLvoid *binary, GLsizei length); +typedef void (APIENTRYP PFNGLPROGRAMPARAMETERIPROC) (GLuint program, GLenum pname, GLint value); +#endif + +#ifndef GL_ARB_separate_shader_objects +#define GL_ARB_separate_shader_objects 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glUseProgramStages (GLuint 
pipeline, GLbitfield stages, GLuint program); +GLAPI void APIENTRY glActiveShaderProgram (GLuint pipeline, GLuint program); +GLAPI GLuint APIENTRY glCreateShaderProgramv (GLenum type, GLsizei count, const GLchar* const *strings); +GLAPI void APIENTRY glBindProgramPipeline (GLuint pipeline); +GLAPI void APIENTRY glDeleteProgramPipelines (GLsizei n, const GLuint *pipelines); +GLAPI void APIENTRY glGenProgramPipelines (GLsizei n, GLuint *pipelines); +GLAPI GLboolean APIENTRY glIsProgramPipeline (GLuint pipeline); +GLAPI void APIENTRY glGetProgramPipelineiv (GLuint pipeline, GLenum pname, GLint *params); +GLAPI void APIENTRY glProgramUniform1i (GLuint program, GLint location, GLint v0); +GLAPI void APIENTRY glProgramUniform1iv (GLuint program, GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glProgramUniform1f (GLuint program, GLint location, GLfloat v0); +GLAPI void APIENTRY glProgramUniform1fv (GLuint program, GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glProgramUniform1d (GLuint program, GLint location, GLdouble v0); +GLAPI void APIENTRY glProgramUniform1dv (GLuint program, GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glProgramUniform1ui (GLuint program, GLint location, GLuint v0); +GLAPI void APIENTRY glProgramUniform1uiv (GLuint program, GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glProgramUniform2i (GLuint program, GLint location, GLint v0, GLint v1); +GLAPI void APIENTRY glProgramUniform2iv (GLuint program, GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glProgramUniform2f (GLuint program, GLint location, GLfloat v0, GLfloat v1); +GLAPI void APIENTRY glProgramUniform2fv (GLuint program, GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glProgramUniform2d (GLuint program, GLint location, GLdouble v0, GLdouble v1); +GLAPI void APIENTRY glProgramUniform2dv (GLuint program, GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glProgramUniform2ui (GLuint program, GLint location, GLuint v0, GLuint v1); +GLAPI void APIENTRY glProgramUniform2uiv (GLuint program, GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glProgramUniform3i (GLuint program, GLint location, GLint v0, GLint v1, GLint v2); +GLAPI void APIENTRY glProgramUniform3iv (GLuint program, GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glProgramUniform3f (GLuint program, GLint location, GLfloat v0, GLfloat v1, GLfloat v2); +GLAPI void APIENTRY glProgramUniform3fv (GLuint program, GLint location, GLsizei count, const GLfloat *value); +GLAPI void APIENTRY glProgramUniform3d (GLuint program, GLint location, GLdouble v0, GLdouble v1, GLdouble v2); +GLAPI void APIENTRY glProgramUniform3dv (GLuint program, GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glProgramUniform3ui (GLuint program, GLint location, GLuint v0, GLuint v1, GLuint v2); +GLAPI void APIENTRY glProgramUniform3uiv (GLuint program, GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glProgramUniform4i (GLuint program, GLint location, GLint v0, GLint v1, GLint v2, GLint v3); +GLAPI void APIENTRY glProgramUniform4iv (GLuint program, GLint location, GLsizei count, const GLint *value); +GLAPI void APIENTRY glProgramUniform4f (GLuint program, GLint location, GLfloat v0, GLfloat v1, GLfloat v2, GLfloat v3); +GLAPI void APIENTRY glProgramUniform4fv (GLuint program, GLint location, GLsizei count, 
const GLfloat *value); +GLAPI void APIENTRY glProgramUniform4d (GLuint program, GLint location, GLdouble v0, GLdouble v1, GLdouble v2, GLdouble v3); +GLAPI void APIENTRY glProgramUniform4dv (GLuint program, GLint location, GLsizei count, const GLdouble *value); +GLAPI void APIENTRY glProgramUniform4ui (GLuint program, GLint location, GLuint v0, GLuint v1, GLuint v2, GLuint v3); +GLAPI void APIENTRY glProgramUniform4uiv (GLuint program, GLint location, GLsizei count, const GLuint *value); +GLAPI void APIENTRY glProgramUniformMatrix2fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix3fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix4fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix2dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix3dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix4dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix2x3fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix3x2fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix2x4fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix4x2fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix3x4fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix4x3fv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +GLAPI void APIENTRY glProgramUniformMatrix2x3dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix3x2dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix2x4dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix4x2dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix3x4dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glProgramUniformMatrix4x3dv (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +GLAPI void APIENTRY glValidateProgramPipeline (GLuint pipeline); +GLAPI void APIENTRY glGetProgramPipelineInfoLog (GLuint pipeline, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLUSEPROGRAMSTAGESPROC) (GLuint pipeline, GLbitfield stages, GLuint program); +typedef void (APIENTRYP PFNGLACTIVESHADERPROGRAMPROC) (GLuint pipeline, GLuint program); +typedef GLuint (APIENTRYP PFNGLCREATESHADERPROGRAMVPROC) (GLenum type, GLsizei count, const GLchar* const 
*strings); +typedef void (APIENTRYP PFNGLBINDPROGRAMPIPELINEPROC) (GLuint pipeline); +typedef void (APIENTRYP PFNGLDELETEPROGRAMPIPELINESPROC) (GLsizei n, const GLuint *pipelines); +typedef void (APIENTRYP PFNGLGENPROGRAMPIPELINESPROC) (GLsizei n, GLuint *pipelines); +typedef GLboolean (APIENTRYP PFNGLISPROGRAMPIPELINEPROC) (GLuint pipeline); +typedef void (APIENTRYP PFNGLGETPROGRAMPIPELINEIVPROC) (GLuint pipeline, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1IPROC) (GLuint program, GLint location, GLint v0); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1IVPROC) (GLuint program, GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1FPROC) (GLuint program, GLint location, GLfloat v0); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1FVPROC) (GLuint program, GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1DPROC) (GLuint program, GLint location, GLdouble v0); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1DVPROC) (GLuint program, GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1UIPROC) (GLuint program, GLint location, GLuint v0); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM1UIVPROC) (GLuint program, GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2IPROC) (GLuint program, GLint location, GLint v0, GLint v1); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2IVPROC) (GLuint program, GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2FPROC) (GLuint program, GLint location, GLfloat v0, GLfloat v1); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2FVPROC) (GLuint program, GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2DPROC) (GLuint program, GLint location, GLdouble v0, GLdouble v1); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2DVPROC) (GLuint program, GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2UIPROC) (GLuint program, GLint location, GLuint v0, GLuint v1); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM2UIVPROC) (GLuint program, GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3IPROC) (GLuint program, GLint location, GLint v0, GLint v1, GLint v2); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3IVPROC) (GLuint program, GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3FPROC) (GLuint program, GLint location, GLfloat v0, GLfloat v1, GLfloat v2); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3FVPROC) (GLuint program, GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3DPROC) (GLuint program, GLint location, GLdouble v0, GLdouble v1, GLdouble v2); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3DVPROC) (GLuint program, GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3UIPROC) (GLuint program, GLint location, GLuint v0, GLuint v1, GLuint v2); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM3UIVPROC) (GLuint program, GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4IPROC) (GLuint program, GLint location, GLint v0, GLint v1, GLint v2, GLint v3); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4IVPROC) (GLuint program, GLint location, GLsizei count, const GLint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4FPROC) (GLuint 
program, GLint location, GLfloat v0, GLfloat v1, GLfloat v2, GLfloat v3); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4FVPROC) (GLuint program, GLint location, GLsizei count, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4DPROC) (GLuint program, GLint location, GLdouble v0, GLdouble v1, GLdouble v2, GLdouble v3); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4DVPROC) (GLuint program, GLint location, GLsizei count, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4UIPROC) (GLuint program, GLint location, GLuint v0, GLuint v1, GLuint v2, GLuint v3); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORM4UIVPROC) (GLuint program, GLint location, GLsizei count, const GLuint *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2X3FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3X2FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2X4FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4X2FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3X4FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4X3FVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLfloat *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2X3DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3X2DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX2X4DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4X2DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX3X4DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLPROGRAMUNIFORMMATRIX4X3DVPROC) (GLuint program, GLint location, GLsizei count, GLboolean transpose, const GLdouble *value); +typedef void (APIENTRYP PFNGLVALIDATEPROGRAMPIPELINEPROC) (GLuint pipeline); +typedef 
void (APIENTRYP PFNGLGETPROGRAMPIPELINEINFOLOGPROC) (GLuint pipeline, GLsizei bufSize, GLsizei *length, GLchar *infoLog); +#endif + +#ifndef GL_ARB_vertex_attrib_64bit +#define GL_ARB_vertex_attrib_64bit 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glVertexAttribL1d (GLuint index, GLdouble x); +GLAPI void APIENTRY glVertexAttribL2d (GLuint index, GLdouble x, GLdouble y); +GLAPI void APIENTRY glVertexAttribL3d (GLuint index, GLdouble x, GLdouble y, GLdouble z); +GLAPI void APIENTRY glVertexAttribL4d (GLuint index, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +GLAPI void APIENTRY glVertexAttribL1dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttribL2dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttribL3dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttribL4dv (GLuint index, const GLdouble *v); +GLAPI void APIENTRY glVertexAttribLPointer (GLuint index, GLint size, GLenum type, GLsizei stride, const GLvoid *pointer); +GLAPI void APIENTRY glGetVertexAttribLdv (GLuint index, GLenum pname, GLdouble *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLVERTEXATTRIBL1DPROC) (GLuint index, GLdouble x); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL2DPROC) (GLuint index, GLdouble x, GLdouble y); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL3DPROC) (GLuint index, GLdouble x, GLdouble y, GLdouble z); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL4DPROC) (GLuint index, GLdouble x, GLdouble y, GLdouble z, GLdouble w); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL1DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL2DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL3DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBL4DVPROC) (GLuint index, const GLdouble *v); +typedef void (APIENTRYP PFNGLVERTEXATTRIBLPOINTERPROC) (GLuint index, GLint size, GLenum type, GLsizei stride, const GLvoid *pointer); +typedef void (APIENTRYP PFNGLGETVERTEXATTRIBLDVPROC) (GLuint index, GLenum pname, GLdouble *params); +#endif + +#ifndef GL_ARB_viewport_array +#define GL_ARB_viewport_array 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glViewportArrayv (GLuint first, GLsizei count, const GLfloat *v); +GLAPI void APIENTRY glViewportIndexedf (GLuint index, GLfloat x, GLfloat y, GLfloat w, GLfloat h); +GLAPI void APIENTRY glViewportIndexedfv (GLuint index, const GLfloat *v); +GLAPI void APIENTRY glScissorArrayv (GLuint first, GLsizei count, const GLint *v); +GLAPI void APIENTRY glScissorIndexed (GLuint index, GLint left, GLint bottom, GLsizei width, GLsizei height); +GLAPI void APIENTRY glScissorIndexedv (GLuint index, const GLint *v); +GLAPI void APIENTRY glDepthRangeArrayv (GLuint first, GLsizei count, const GLdouble *v); +GLAPI void APIENTRY glDepthRangeIndexed (GLuint index, GLdouble n, GLdouble f); +GLAPI void APIENTRY glGetFloati_v (GLenum target, GLuint index, GLfloat *data); +GLAPI void APIENTRY glGetDoublei_v (GLenum target, GLuint index, GLdouble *data); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLVIEWPORTARRAYVPROC) (GLuint first, GLsizei count, const GLfloat *v); +typedef void (APIENTRYP PFNGLVIEWPORTINDEXEDFPROC) (GLuint index, GLfloat x, GLfloat y, GLfloat w, GLfloat h); +typedef void (APIENTRYP PFNGLVIEWPORTINDEXEDFVPROC) (GLuint index, const GLfloat *v); +typedef void (APIENTRYP PFNGLSCISSORARRAYVPROC) (GLuint first, GLsizei count, const GLint *v); +typedef void (APIENTRYP PFNGLSCISSORINDEXEDPROC) (GLuint 
index, GLint left, GLint bottom, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLSCISSORINDEXEDVPROC) (GLuint index, const GLint *v); +typedef void (APIENTRYP PFNGLDEPTHRANGEARRAYVPROC) (GLuint first, GLsizei count, const GLdouble *v); +typedef void (APIENTRYP PFNGLDEPTHRANGEINDEXEDPROC) (GLuint index, GLdouble n, GLdouble f); +typedef void (APIENTRYP PFNGLGETFLOATI_VPROC) (GLenum target, GLuint index, GLfloat *data); +typedef void (APIENTRYP PFNGLGETDOUBLEI_VPROC) (GLenum target, GLuint index, GLdouble *data); +#endif + +#ifndef GL_ARB_cl_event +#define GL_ARB_cl_event 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLsync APIENTRY glCreateSyncFromCLeventARB (struct _cl_context * context, struct _cl_event * event, GLbitfield flags); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLsync (APIENTRYP PFNGLCREATESYNCFROMCLEVENTARBPROC) (struct _cl_context * context, struct _cl_event * event, GLbitfield flags); +#endif + +#ifndef GL_ARB_debug_output +#define GL_ARB_debug_output 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDebugMessageControlARB (GLenum source, GLenum type, GLenum severity, GLsizei count, const GLuint *ids, GLboolean enabled); +GLAPI void APIENTRY glDebugMessageInsertARB (GLenum source, GLenum type, GLuint id, GLenum severity, GLsizei length, const GLchar *buf); +GLAPI void APIENTRY glDebugMessageCallbackARB (GLDEBUGPROCARB callback, const GLvoid *userParam); +GLAPI GLuint APIENTRY glGetDebugMessageLogARB (GLuint count, GLsizei bufsize, GLenum *sources, GLenum *types, GLuint *ids, GLenum *severities, GLsizei *lengths, GLchar *messageLog); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDEBUGMESSAGECONTROLARBPROC) (GLenum source, GLenum type, GLenum severity, GLsizei count, const GLuint *ids, GLboolean enabled); +typedef void (APIENTRYP PFNGLDEBUGMESSAGEINSERTARBPROC) (GLenum source, GLenum type, GLuint id, GLenum severity, GLsizei length, const GLchar *buf); +typedef void (APIENTRYP PFNGLDEBUGMESSAGECALLBACKARBPROC) (GLDEBUGPROCARB callback, const GLvoid *userParam); +typedef GLuint (APIENTRYP PFNGLGETDEBUGMESSAGELOGARBPROC) (GLuint count, GLsizei bufsize, GLenum *sources, GLenum *types, GLuint *ids, GLenum *severities, GLsizei *lengths, GLchar *messageLog); +#endif + +#ifndef GL_ARB_robustness +#define GL_ARB_robustness 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI GLenum APIENTRY glGetGraphicsResetStatusARB (void); +GLAPI void APIENTRY glGetnTexImageARB (GLenum target, GLint level, GLenum format, GLenum type, GLsizei bufSize, GLvoid *img); +GLAPI void APIENTRY glReadnPixelsARB (GLint x, GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type, GLsizei bufSize, GLvoid *data); +GLAPI void APIENTRY glGetnCompressedTexImageARB (GLenum target, GLint lod, GLsizei bufSize, GLvoid *img); +GLAPI void APIENTRY glGetnUniformfvARB (GLuint program, GLint location, GLsizei bufSize, GLfloat *params); +GLAPI void APIENTRY glGetnUniformivARB (GLuint program, GLint location, GLsizei bufSize, GLint *params); +GLAPI void APIENTRY glGetnUniformuivARB (GLuint program, GLint location, GLsizei bufSize, GLuint *params); +GLAPI void APIENTRY glGetnUniformdvARB (GLuint program, GLint location, GLsizei bufSize, GLdouble *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef GLenum (APIENTRYP PFNGLGETGRAPHICSRESETSTATUSARBPROC) (void); +typedef void (APIENTRYP PFNGLGETNTEXIMAGEARBPROC) (GLenum target, GLint level, GLenum format, GLenum type, GLsizei bufSize, GLvoid *img); +typedef void (APIENTRYP PFNGLREADNPIXELSARBPROC) (GLint x, GLint y, GLsizei width, GLsizei height, GLenum 
format, GLenum type, GLsizei bufSize, GLvoid *data); +typedef void (APIENTRYP PFNGLGETNCOMPRESSEDTEXIMAGEARBPROC) (GLenum target, GLint lod, GLsizei bufSize, GLvoid *img); +typedef void (APIENTRYP PFNGLGETNUNIFORMFVARBPROC) (GLuint program, GLint location, GLsizei bufSize, GLfloat *params); +typedef void (APIENTRYP PFNGLGETNUNIFORMIVARBPROC) (GLuint program, GLint location, GLsizei bufSize, GLint *params); +typedef void (APIENTRYP PFNGLGETNUNIFORMUIVARBPROC) (GLuint program, GLint location, GLsizei bufSize, GLuint *params); +typedef void (APIENTRYP PFNGLGETNUNIFORMDVARBPROC) (GLuint program, GLint location, GLsizei bufSize, GLdouble *params); +#endif + +#ifndef GL_ARB_shader_stencil_export +#define GL_ARB_shader_stencil_export 1 +#endif + +#ifndef GL_ARB_base_instance +#define GL_ARB_base_instance 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawArraysInstancedBaseInstance (GLenum mode, GLint first, GLsizei count, GLsizei instancecount, GLuint baseinstance); +GLAPI void APIENTRY glDrawElementsInstancedBaseInstance (GLenum mode, GLsizei count, GLenum type, const void *indices, GLsizei instancecount, GLuint baseinstance); +GLAPI void APIENTRY glDrawElementsInstancedBaseVertexBaseInstance (GLenum mode, GLsizei count, GLenum type, const void *indices, GLsizei instancecount, GLint basevertex, GLuint baseinstance); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWARRAYSINSTANCEDBASEINSTANCEPROC) (GLenum mode, GLint first, GLsizei count, GLsizei instancecount, GLuint baseinstance); +typedef void (APIENTRYP PFNGLDRAWELEMENTSINSTANCEDBASEINSTANCEPROC) (GLenum mode, GLsizei count, GLenum type, const void *indices, GLsizei instancecount, GLuint baseinstance); +typedef void (APIENTRYP PFNGLDRAWELEMENTSINSTANCEDBASEVERTEXBASEINSTANCEPROC) (GLenum mode, GLsizei count, GLenum type, const void *indices, GLsizei instancecount, GLint basevertex, GLuint baseinstance); +#endif + +#ifndef GL_ARB_shading_language_420pack +#define GL_ARB_shading_language_420pack 1 +#endif + +#ifndef GL_ARB_transform_feedback_instanced +#define GL_ARB_transform_feedback_instanced 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDrawTransformFeedbackInstanced (GLenum mode, GLuint id, GLsizei instancecount); +GLAPI void APIENTRY glDrawTransformFeedbackStreamInstanced (GLenum mode, GLuint id, GLuint stream, GLsizei instancecount); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDRAWTRANSFORMFEEDBACKINSTANCEDPROC) (GLenum mode, GLuint id, GLsizei instancecount); +typedef void (APIENTRYP PFNGLDRAWTRANSFORMFEEDBACKSTREAMINSTANCEDPROC) (GLenum mode, GLuint id, GLuint stream, GLsizei instancecount); +#endif + +#ifndef GL_ARB_compressed_texture_pixel_storage +#define GL_ARB_compressed_texture_pixel_storage 1 +#endif + +#ifndef GL_ARB_conservative_depth +#define GL_ARB_conservative_depth 1 +#endif + +#ifndef GL_ARB_internalformat_query +#define GL_ARB_internalformat_query 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetInternalformativ (GLenum target, GLenum internalformat, GLenum pname, GLsizei bufSize, GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETINTERNALFORMATIVPROC) (GLenum target, GLenum internalformat, GLenum pname, GLsizei bufSize, GLint *params); +#endif + +#ifndef GL_ARB_map_buffer_alignment +#define GL_ARB_map_buffer_alignment 1 +#endif + +#ifndef GL_ARB_shader_atomic_counters +#define GL_ARB_shader_atomic_counters 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetActiveAtomicCounterBufferiv (GLuint program, GLuint 
bufferIndex, GLenum pname, GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETACTIVEATOMICCOUNTERBUFFERIVPROC) (GLuint program, GLuint bufferIndex, GLenum pname, GLint *params); +#endif + +#ifndef GL_ARB_shader_image_load_store +#define GL_ARB_shader_image_load_store 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBindImageTexture (GLuint unit, GLuint texture, GLint level, GLboolean layered, GLint layer, GLenum access, GLenum format); +GLAPI void APIENTRY glMemoryBarrier (GLbitfield barriers); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBINDIMAGETEXTUREPROC) (GLuint unit, GLuint texture, GLint level, GLboolean layered, GLint layer, GLenum access, GLenum format); +typedef void (APIENTRYP PFNGLMEMORYBARRIERPROC) (GLbitfield barriers); +#endif + +#ifndef GL_ARB_shading_language_packing +#define GL_ARB_shading_language_packing 1 +#endif + +#ifndef GL_ARB_texture_storage +#define GL_ARB_texture_storage 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glTexStorage1D (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width); +GLAPI void APIENTRY glTexStorage2D (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height); +GLAPI void APIENTRY glTexStorage3D (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth); +GLAPI void APIENTRY glTextureStorage1DEXT (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width); +GLAPI void APIENTRY glTextureStorage2DEXT (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height); +GLAPI void APIENTRY glTextureStorage3DEXT (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLTEXSTORAGE1DPROC) (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width); +typedef void (APIENTRYP PFNGLTEXSTORAGE2DPROC) (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLTEXSTORAGE3DPROC) (GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth); +typedef void (APIENTRYP PFNGLTEXTURESTORAGE1DEXTPROC) (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width); +typedef void (APIENTRYP PFNGLTEXTURESTORAGE2DEXTPROC) (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height); +typedef void (APIENTRYP PFNGLTEXTURESTORAGE3DEXTPROC) (GLuint texture, GLenum target, GLsizei levels, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth); +#endif + +#ifndef GL_KHR_texture_compression_astc_ldr +#define GL_KHR_texture_compression_astc_ldr 1 +#endif + +#ifndef GL_KHR_debug +#define GL_KHR_debug 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDebugMessageControl (GLenum source, GLenum type, GLenum severity, GLsizei count, const GLuint *ids, GLboolean enabled); +GLAPI void APIENTRY glDebugMessageInsert (GLenum source, GLenum type, GLuint id, GLenum severity, GLsizei length, const GLchar *buf); +GLAPI void APIENTRY glDebugMessageCallback (GLDEBUGPROC callback, const void *userParam); +GLAPI GLuint APIENTRY glGetDebugMessageLog (GLuint count, GLsizei bufsize, GLenum *sources, GLenum *types, GLuint *ids, GLenum *severities, GLsizei *lengths, GLchar *messageLog); +GLAPI void APIENTRY glPushDebugGroup (GLenum source, GLuint id, GLsizei length, const GLchar 
*message); +GLAPI void APIENTRY glPopDebugGroup (void); +GLAPI void APIENTRY glObjectLabel (GLenum identifier, GLuint name, GLsizei length, const GLchar *label); +GLAPI void APIENTRY glGetObjectLabel (GLenum identifier, GLuint name, GLsizei bufSize, GLsizei *length, GLchar *label); +GLAPI void APIENTRY glObjectPtrLabel (const void *ptr, GLsizei length, const GLchar *label); +GLAPI void APIENTRY glGetObjectPtrLabel (const void *ptr, GLsizei bufSize, GLsizei *length, GLchar *label); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLDEBUGMESSAGECONTROLPROC) (GLenum source, GLenum type, GLenum severity, GLsizei count, const GLuint *ids, GLboolean enabled); +typedef void (APIENTRYP PFNGLDEBUGMESSAGEINSERTPROC) (GLenum source, GLenum type, GLuint id, GLenum severity, GLsizei length, const GLchar *buf); +typedef void (APIENTRYP PFNGLDEBUGMESSAGECALLBACKPROC) (GLDEBUGPROC callback, const void *userParam); +typedef GLuint (APIENTRYP PFNGLGETDEBUGMESSAGELOGPROC) (GLuint count, GLsizei bufsize, GLenum *sources, GLenum *types, GLuint *ids, GLenum *severities, GLsizei *lengths, GLchar *messageLog); +typedef void (APIENTRYP PFNGLPUSHDEBUGGROUPPROC) (GLenum source, GLuint id, GLsizei length, const GLchar *message); +typedef void (APIENTRYP PFNGLPOPDEBUGGROUPPROC) (void); +typedef void (APIENTRYP PFNGLOBJECTLABELPROC) (GLenum identifier, GLuint name, GLsizei length, const GLchar *label); +typedef void (APIENTRYP PFNGLGETOBJECTLABELPROC) (GLenum identifier, GLuint name, GLsizei bufSize, GLsizei *length, GLchar *label); +typedef void (APIENTRYP PFNGLOBJECTPTRLABELPROC) (const void *ptr, GLsizei length, const GLchar *label); +typedef void (APIENTRYP PFNGLGETOBJECTPTRLABELPROC) (const void *ptr, GLsizei bufSize, GLsizei *length, GLchar *label); +#endif + +#ifndef GL_ARB_arrays_of_arrays +#define GL_ARB_arrays_of_arrays 1 +#endif + +#ifndef GL_ARB_clear_buffer_object +#define GL_ARB_clear_buffer_object 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glClearBufferData (GLenum target, GLenum internalformat, GLenum format, GLenum type, const void *data); +GLAPI void APIENTRY glClearBufferSubData (GLenum target, GLenum internalformat, GLintptr offset, GLsizeiptr size, GLenum format, GLenum type, const void *data); +GLAPI void APIENTRY glClearNamedBufferDataEXT (GLuint buffer, GLenum internalformat, GLenum format, GLenum type, const void *data); +GLAPI void APIENTRY glClearNamedBufferSubDataEXT (GLuint buffer, GLenum internalformat, GLenum format, GLenum type, GLsizeiptr offset, GLsizeiptr size, const void *data); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLCLEARBUFFERDATAPROC) (GLenum target, GLenum internalformat, GLenum format, GLenum type, const void *data); +typedef void (APIENTRYP PFNGLCLEARBUFFERSUBDATAPROC) (GLenum target, GLenum internalformat, GLintptr offset, GLsizeiptr size, GLenum format, GLenum type, const void *data); +typedef void (APIENTRYP PFNGLCLEARNAMEDBUFFERDATAEXTPROC) (GLuint buffer, GLenum internalformat, GLenum format, GLenum type, const void *data); +typedef void (APIENTRYP PFNGLCLEARNAMEDBUFFERSUBDATAEXTPROC) (GLuint buffer, GLenum internalformat, GLenum format, GLenum type, GLsizeiptr offset, GLsizeiptr size, const void *data); +#endif + +#ifndef GL_ARB_compute_shader +#define GL_ARB_compute_shader 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glDispatchCompute (GLuint num_groups_x, GLuint num_groups_y, GLuint num_groups_z); +GLAPI void APIENTRY glDispatchComputeIndirect (GLintptr indirect); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void 
(APIENTRYP PFNGLDISPATCHCOMPUTEPROC) (GLuint num_groups_x, GLuint num_groups_y, GLuint num_groups_z); +typedef void (APIENTRYP PFNGLDISPATCHCOMPUTEINDIRECTPROC) (GLintptr indirect); +#endif + +#ifndef GL_ARB_copy_image +#define GL_ARB_copy_image 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glCopyImageSubData (GLuint srcName, GLenum srcTarget, GLint srcLevel, GLint srcX, GLint srcY, GLint srcZ, GLuint dstName, GLenum dstTarget, GLint dstLevel, GLint dstX, GLint dstY, GLint dstZ, GLsizei srcWidth, GLsizei srcHeight, GLsizei srcDepth); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLCOPYIMAGESUBDATAPROC) (GLuint srcName, GLenum srcTarget, GLint srcLevel, GLint srcX, GLint srcY, GLint srcZ, GLuint dstName, GLenum dstTarget, GLint dstLevel, GLint dstX, GLint dstY, GLint dstZ, GLsizei srcWidth, GLsizei srcHeight, GLsizei srcDepth); +#endif + +#ifndef GL_ARB_texture_view +#define GL_ARB_texture_view 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glTextureView (GLuint texture, GLenum target, GLuint origtexture, GLenum internalformat, GLuint minlevel, GLuint numlevels, GLuint minlayer, GLuint numlayers); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLTEXTUREVIEWPROC) (GLuint texture, GLenum target, GLuint origtexture, GLenum internalformat, GLuint minlevel, GLuint numlevels, GLuint minlayer, GLuint numlayers); +#endif + +#ifndef GL_ARB_vertex_attrib_binding +#define GL_ARB_vertex_attrib_binding 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glBindVertexBuffer (GLuint bindingindex, GLuint buffer, GLintptr offset, GLsizei stride); +GLAPI void APIENTRY glVertexAttribFormat (GLuint attribindex, GLint size, GLenum type, GLboolean normalized, GLuint relativeoffset); +GLAPI void APIENTRY glVertexAttribIFormat (GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +GLAPI void APIENTRY glVertexAttribLFormat (GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +GLAPI void APIENTRY glVertexAttribBinding (GLuint attribindex, GLuint bindingindex); +GLAPI void APIENTRY glVertexBindingDivisor (GLuint bindingindex, GLuint divisor); +GLAPI void APIENTRY glVertexArrayBindVertexBufferEXT (GLuint vaobj, GLuint bindingindex, GLuint buffer, GLintptr offset, GLsizei stride); +GLAPI void APIENTRY glVertexArrayVertexAttribFormatEXT (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLboolean normalized, GLuint relativeoffset); +GLAPI void APIENTRY glVertexArrayVertexAttribIFormatEXT (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +GLAPI void APIENTRY glVertexArrayVertexAttribLFormatEXT (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +GLAPI void APIENTRY glVertexArrayVertexAttribBindingEXT (GLuint vaobj, GLuint attribindex, GLuint bindingindex); +GLAPI void APIENTRY glVertexArrayVertexBindingDivisorEXT (GLuint vaobj, GLuint bindingindex, GLuint divisor); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLBINDVERTEXBUFFERPROC) (GLuint bindingindex, GLuint buffer, GLintptr offset, GLsizei stride); +typedef void (APIENTRYP PFNGLVERTEXATTRIBFORMATPROC) (GLuint attribindex, GLint size, GLenum type, GLboolean normalized, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXATTRIBIFORMATPROC) (GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXATTRIBLFORMATPROC) (GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXATTRIBBINDINGPROC) (GLuint 
attribindex, GLuint bindingindex); +typedef void (APIENTRYP PFNGLVERTEXBINDINGDIVISORPROC) (GLuint bindingindex, GLuint divisor); +typedef void (APIENTRYP PFNGLVERTEXARRAYBINDVERTEXBUFFEREXTPROC) (GLuint vaobj, GLuint bindingindex, GLuint buffer, GLintptr offset, GLsizei stride); +typedef void (APIENTRYP PFNGLVERTEXARRAYVERTEXATTRIBFORMATEXTPROC) (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLboolean normalized, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXARRAYVERTEXATTRIBIFORMATEXTPROC) (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXARRAYVERTEXATTRIBLFORMATEXTPROC) (GLuint vaobj, GLuint attribindex, GLint size, GLenum type, GLuint relativeoffset); +typedef void (APIENTRYP PFNGLVERTEXARRAYVERTEXATTRIBBINDINGEXTPROC) (GLuint vaobj, GLuint attribindex, GLuint bindingindex); +typedef void (APIENTRYP PFNGLVERTEXARRAYVERTEXBINDINGDIVISOREXTPROC) (GLuint vaobj, GLuint bindingindex, GLuint divisor); +#endif + +#ifndef GL_ARB_robustness_isolation +#define GL_ARB_robustness_isolation 1 +#endif + +#ifndef GL_ARB_ES3_compatibility +#define GL_ARB_ES3_compatibility 1 +#endif + +#ifndef GL_ARB_explicit_uniform_location +#define GL_ARB_explicit_uniform_location 1 +#endif + +#ifndef GL_ARB_fragment_layer_viewport +#define GL_ARB_fragment_layer_viewport 1 +#endif + +#ifndef GL_ARB_framebuffer_no_attachments +#define GL_ARB_framebuffer_no_attachments 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glFramebufferParameteri (GLenum target, GLenum pname, GLint param); +GLAPI void APIENTRY glGetFramebufferParameteriv (GLenum target, GLenum pname, GLint *params); +GLAPI void APIENTRY glNamedFramebufferParameteriEXT (GLuint framebuffer, GLenum pname, GLint param); +GLAPI void APIENTRY glGetNamedFramebufferParameterivEXT (GLuint framebuffer, GLenum pname, GLint *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLFRAMEBUFFERPARAMETERIPROC) (GLenum target, GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLGETFRAMEBUFFERPARAMETERIVPROC) (GLenum target, GLenum pname, GLint *params); +typedef void (APIENTRYP PFNGLNAMEDFRAMEBUFFERPARAMETERIEXTPROC) (GLuint framebuffer, GLenum pname, GLint param); +typedef void (APIENTRYP PFNGLGETNAMEDFRAMEBUFFERPARAMETERIVEXTPROC) (GLuint framebuffer, GLenum pname, GLint *params); +#endif + +#ifndef GL_ARB_internalformat_query2 +#define GL_ARB_internalformat_query2 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetInternalformati64v (GLenum target, GLenum internalformat, GLenum pname, GLsizei bufSize, GLint64 *params); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETINTERNALFORMATI64VPROC) (GLenum target, GLenum internalformat, GLenum pname, GLsizei bufSize, GLint64 *params); +#endif + +#ifndef GL_ARB_invalidate_subdata +#define GL_ARB_invalidate_subdata 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glInvalidateTexSubImage (GLuint texture, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth); +GLAPI void APIENTRY glInvalidateTexImage (GLuint texture, GLint level); +GLAPI void APIENTRY glInvalidateBufferSubData (GLuint buffer, GLintptr offset, GLsizeiptr length); +GLAPI void APIENTRY glInvalidateBufferData (GLuint buffer); +GLAPI void APIENTRY glInvalidateFramebuffer (GLenum target, GLsizei numAttachments, const GLenum *attachments); +GLAPI void APIENTRY glInvalidateSubFramebuffer (GLenum target, GLsizei numAttachments, const GLenum *attachments, GLint x, GLint y, 
GLsizei width, GLsizei height); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLINVALIDATETEXSUBIMAGEPROC) (GLuint texture, GLint level, GLint xoffset, GLint yoffset, GLint zoffset, GLsizei width, GLsizei height, GLsizei depth); +typedef void (APIENTRYP PFNGLINVALIDATETEXIMAGEPROC) (GLuint texture, GLint level); +typedef void (APIENTRYP PFNGLINVALIDATEBUFFERSUBDATAPROC) (GLuint buffer, GLintptr offset, GLsizeiptr length); +typedef void (APIENTRYP PFNGLINVALIDATEBUFFERDATAPROC) (GLuint buffer); +typedef void (APIENTRYP PFNGLINVALIDATEFRAMEBUFFERPROC) (GLenum target, GLsizei numAttachments, const GLenum *attachments); +typedef void (APIENTRYP PFNGLINVALIDATESUBFRAMEBUFFERPROC) (GLenum target, GLsizei numAttachments, const GLenum *attachments, GLint x, GLint y, GLsizei width, GLsizei height); +#endif + +#ifndef GL_ARB_multi_draw_indirect +#define GL_ARB_multi_draw_indirect 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glMultiDrawArraysIndirect (GLenum mode, const void *indirect, GLsizei drawcount, GLsizei stride); +GLAPI void APIENTRY glMultiDrawElementsIndirect (GLenum mode, GLenum type, const void *indirect, GLsizei drawcount, GLsizei stride); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLMULTIDRAWARRAYSINDIRECTPROC) (GLenum mode, const void *indirect, GLsizei drawcount, GLsizei stride); +typedef void (APIENTRYP PFNGLMULTIDRAWELEMENTSINDIRECTPROC) (GLenum mode, GLenum type, const void *indirect, GLsizei drawcount, GLsizei stride); +#endif + +#ifndef GL_ARB_program_interface_query +#define GL_ARB_program_interface_query 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glGetProgramInterfaceiv (GLuint program, GLenum programInterface, GLenum pname, GLint *params); +GLAPI GLuint APIENTRY glGetProgramResourceIndex (GLuint program, GLenum programInterface, const GLchar *name); +GLAPI void APIENTRY glGetProgramResourceName (GLuint program, GLenum programInterface, GLuint index, GLsizei bufSize, GLsizei *length, GLchar *name); +GLAPI void APIENTRY glGetProgramResourceiv (GLuint program, GLenum programInterface, GLuint index, GLsizei propCount, const GLenum *props, GLsizei bufSize, GLsizei *length, GLint *params); +GLAPI GLint APIENTRY glGetProgramResourceLocation (GLuint program, GLenum programInterface, const GLchar *name); +GLAPI GLint APIENTRY glGetProgramResourceLocationIndex (GLuint program, GLenum programInterface, const GLchar *name); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLGETPROGRAMINTERFACEIVPROC) (GLuint program, GLenum programInterface, GLenum pname, GLint *params); +typedef GLuint (APIENTRYP PFNGLGETPROGRAMRESOURCEINDEXPROC) (GLuint program, GLenum programInterface, const GLchar *name); +typedef void (APIENTRYP PFNGLGETPROGRAMRESOURCENAMEPROC) (GLuint program, GLenum programInterface, GLuint index, GLsizei bufSize, GLsizei *length, GLchar *name); +typedef void (APIENTRYP PFNGLGETPROGRAMRESOURCEIVPROC) (GLuint program, GLenum programInterface, GLuint index, GLsizei propCount, const GLenum *props, GLsizei bufSize, GLsizei *length, GLint *params); +typedef GLint (APIENTRYP PFNGLGETPROGRAMRESOURCELOCATIONPROC) (GLuint program, GLenum programInterface, const GLchar *name); +typedef GLint (APIENTRYP PFNGLGETPROGRAMRESOURCELOCATIONINDEXPROC) (GLuint program, GLenum programInterface, const GLchar *name); +#endif + +#ifndef GL_ARB_robust_buffer_access_behavior +#define GL_ARB_robust_buffer_access_behavior 1 +#endif + +#ifndef GL_ARB_shader_image_size +#define GL_ARB_shader_image_size 1 +#endif + +#ifndef 
GL_ARB_shader_storage_buffer_object +#define GL_ARB_shader_storage_buffer_object 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glShaderStorageBlockBinding (GLuint program, GLuint storageBlockIndex, GLuint storageBlockBinding); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLSHADERSTORAGEBLOCKBINDINGPROC) (GLuint program, GLuint storageBlockIndex, GLuint storageBlockBinding); +#endif + +#ifndef GL_ARB_stencil_texturing +#define GL_ARB_stencil_texturing 1 +#endif + +#ifndef GL_ARB_texture_buffer_range +#define GL_ARB_texture_buffer_range 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glTexBufferRange (GLenum target, GLenum internalformat, GLuint buffer, GLintptr offset, GLsizeiptr size); +GLAPI void APIENTRY glTextureBufferRangeEXT (GLuint texture, GLenum target, GLenum internalformat, GLuint buffer, GLintptr offset, GLsizeiptr size); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLTEXBUFFERRANGEPROC) (GLenum target, GLenum internalformat, GLuint buffer, GLintptr offset, GLsizeiptr size); +typedef void (APIENTRYP PFNGLTEXTUREBUFFERRANGEEXTPROC) (GLuint texture, GLenum target, GLenum internalformat, GLuint buffer, GLintptr offset, GLsizeiptr size); +#endif + +#ifndef GL_ARB_texture_query_levels +#define GL_ARB_texture_query_levels 1 +#endif + +#ifndef GL_ARB_texture_storage_multisample +#define GL_ARB_texture_storage_multisample 1 +#ifdef GLCOREARB_PROTOTYPES +GLAPI void APIENTRY glTexStorage2DMultisample (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +GLAPI void APIENTRY glTexStorage3DMultisample (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +GLAPI void APIENTRY glTextureStorage2DMultisampleEXT (GLuint texture, GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +GLAPI void APIENTRY glTextureStorage3DMultisampleEXT (GLuint texture, GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +#endif /* GLCOREARB_PROTOTYPES */ +typedef void (APIENTRYP PFNGLTEXSTORAGE2DMULTISAMPLEPROC) (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +typedef void (APIENTRYP PFNGLTEXSTORAGE3DMULTISAMPLEPROC) (GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +typedef void (APIENTRYP PFNGLTEXTURESTORAGE2DMULTISAMPLEEXTPROC) (GLuint texture, GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLboolean fixedsamplelocations); +typedef void (APIENTRYP PFNGLTEXTURESTORAGE3DMULTISAMPLEEXTPROC) (GLuint texture, GLenum target, GLsizei samples, GLenum internalformat, GLsizei width, GLsizei height, GLsizei depth, GLboolean fixedsamplelocations); +#endif + + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/gui/dependencies/imguizmo/ImGuizmo.cpp b/gui/dependencies/imguizmo/ImGuizmo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48ca7086cb118c7fed9c16c4f3e5d470072da688 --- /dev/null +++ b/gui/dependencies/imguizmo/ImGuizmo.cpp @@ -0,0 +1,2883 @@ +// https://github.com/CedricGuillemet/ImGuizmo +// v 1.84 WIP +// +// The MIT License(MIT) +// +// Copyright(c) 2021 Cedric Guillemet +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// 
of this software and associated documentation files(the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions : +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// + +#ifndef IMGUI_DEFINE_MATH_OPERATORS +#define IMGUI_DEFINE_MATH_OPERATORS +#endif +#include "imgui.h" +#include "imgui_internal.h" +#include "ImGuizmo.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include <malloc.h> +#endif +#if !defined(_MSC_VER) && !defined(__MINGW64_VERSION_MAJOR) +#define _malloca(x) alloca(x) +#define _freea(x) +#endif + +// includes patches for multiview from +// https://github.com/CedricGuillemet/ImGuizmo/issues/15 + +namespace IMGUIZMO_NAMESPACE +{ + static const float ZPI = 3.14159265358979323846f; + static const float RAD2DEG = (180.f / ZPI); + static const float DEG2RAD = (ZPI / 180.f); + const float screenRotateSize = 0.06f; + // scale a bit so translate axis do not touch when in universal + const float rotationDisplayFactor = 1.2f; + + static OPERATION operator&(OPERATION lhs, OPERATION rhs) + { + return static_cast<OPERATION>(static_cast<int>(lhs) & static_cast<int>(rhs)); + } + + static bool operator!=(OPERATION lhs, int rhs) + { + return static_cast<int>(lhs) != rhs; + } + + static bool operator==(OPERATION lhs, int rhs) + { + return static_cast<int>(lhs) == rhs; + } + + static bool Intersects(OPERATION lhs, OPERATION rhs) + { + return (lhs & rhs) != 0; + } + + // True if lhs contains rhs + static bool Contains(OPERATION lhs, OPERATION rhs) + { + return (lhs & rhs) == rhs; + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // utility and math + + void FPU_MatrixF_x_MatrixF(const float* a, const float* b, float* r) + { + r[0] = a[0] * b[0] + a[1] * b[4] + a[2] * b[8] + a[3] * b[12]; + r[1] = a[0] * b[1] + a[1] * b[5] + a[2] * b[9] + a[3] * b[13]; + r[2] = a[0] * b[2] + a[1] * b[6] + a[2] * b[10] + a[3] * b[14]; + r[3] = a[0] * b[3] + a[1] * b[7] + a[2] * b[11] + a[3] * b[15]; + + r[4] = a[4] * b[0] + a[5] * b[4] + a[6] * b[8] + a[7] * b[12]; + r[5] = a[4] * b[1] + a[5] * b[5] + a[6] * b[9] + a[7] * b[13]; + r[6] = a[4] * b[2] + a[5] * b[6] + a[6] * b[10] + a[7] * b[14]; + r[7] = a[4] * b[3] + a[5] * b[7] + a[6] * b[11] + a[7] * b[15]; + + r[8] = a[8] * b[0] + a[9] * b[4] + a[10] * b[8] + a[11] * b[12]; + r[9] = a[8] * b[1] + a[9] * b[5] + a[10] * b[9] + a[11] * b[13]; + r[10] = a[8] * b[2] + a[9] * b[6] + a[10] * b[10] + a[11] * b[14]; + r[11] = a[8] * b[3] + a[9] * b[7] + a[10] * b[11] + a[11] * b[15]; + + r[12] = a[12] * b[0] + a[13] * b[4] + a[14] * b[8] + a[15] * b[12]; + r[13] = a[12] * b[1] + a[13] * b[5] + a[14] * b[9] + a[15] * b[13]; + r[14] = a[12] * b[2] + a[13] * b[6] + 
a[14] * b[10] + a[15] * b[14]; + r[15] = a[12] * b[3] + a[13] * b[7] + a[14] * b[11] + a[15] * b[15]; + } + + void Frustum(float left, float right, float bottom, float top, float znear, float zfar, float* m16) + { + float temp, temp2, temp3, temp4; + temp = 2.0f * znear; + temp2 = right - left; + temp3 = top - bottom; + temp4 = zfar - znear; + m16[0] = temp / temp2; + m16[1] = 0.0; + m16[2] = 0.0; + m16[3] = 0.0; + m16[4] = 0.0; + m16[5] = temp / temp3; + m16[6] = 0.0; + m16[7] = 0.0; + m16[8] = (right + left) / temp2; + m16[9] = (top + bottom) / temp3; + m16[10] = (-zfar - znear) / temp4; + m16[11] = -1.0f; + m16[12] = 0.0; + m16[13] = 0.0; + m16[14] = (-temp * zfar) / temp4; + m16[15] = 0.0; + } + + void Perspective(float fovyInDegrees, float aspectRatio, float znear, float zfar, float* m16) + { + float ymax, xmax; + ymax = znear * tanf(fovyInDegrees * DEG2RAD); + xmax = ymax * aspectRatio; + Frustum(-xmax, xmax, -ymax, ymax, znear, zfar, m16); + } + + void Cross(const float* a, const float* b, float* r) + { + r[0] = a[1] * b[2] - a[2] * b[1]; + r[1] = a[2] * b[0] - a[0] * b[2]; + r[2] = a[0] * b[1] - a[1] * b[0]; + } + + float Dot(const float* a, const float* b) + { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; + } + + void Normalize(const float* a, float* r) + { + float il = 1.f / (sqrtf(Dot(a, a)) + FLT_EPSILON); + r[0] = a[0] * il; + r[1] = a[1] * il; + r[2] = a[2] * il; + } + + void LookAt(const float* eye, const float* at, const float* up, float* m16) + { + float X[3], Y[3], Z[3], tmp[3]; + + tmp[0] = eye[0] - at[0]; + tmp[1] = eye[1] - at[1]; + tmp[2] = eye[2] - at[2]; + Normalize(tmp, Z); + Normalize(up, Y); + Cross(Y, Z, tmp); + Normalize(tmp, X); + Cross(Z, X, tmp); + Normalize(tmp, Y); + + m16[0] = X[0]; + m16[1] = Y[0]; + m16[2] = Z[0]; + m16[3] = 0.0f; + m16[4] = X[1]; + m16[5] = Y[1]; + m16[6] = Z[1]; + m16[7] = 0.0f; + m16[8] = X[2]; + m16[9] = Y[2]; + m16[10] = Z[2]; + m16[11] = 0.0f; + m16[12] = -Dot(X, eye); + m16[13] = -Dot(Y, eye); + m16[14] = -Dot(Z, eye); + m16[15] = 1.0f; + } + + template <typename T> T Clamp(T x, T y, T z) { return ((x < y) ? y : ((x > z) ? z : x)); } + template <typename T> T max(T x, T y) { return (x > y) ? x : y; } + template <typename T> T min(T x, T y) { return (x < y) ? 
x : y; } + template <typename T> bool IsWithin(T x, T y, T z) { return (x >= y) && (x <= z); } + + struct matrix_t; + struct vec_t + { + public: + float x, y, z, w; + + void Lerp(const vec_t& v, float t) + { + x += (v.x - x) * t; + y += (v.y - y) * t; + z += (v.z - z) * t; + w += (v.w - w) * t; + } + + void Set(float v) { x = y = z = w = v; } + void Set(float _x, float _y, float _z = 0.f, float _w = 0.f) { x = _x; y = _y; z = _z; w = _w; } + + vec_t& operator -= (const vec_t& v) { x -= v.x; y -= v.y; z -= v.z; w -= v.w; return *this; } + vec_t& operator += (const vec_t& v) { x += v.x; y += v.y; z += v.z; w += v.w; return *this; } + vec_t& operator *= (const vec_t& v) { x *= v.x; y *= v.y; z *= v.z; w *= v.w; return *this; } + vec_t& operator *= (float v) { x *= v; y *= v; z *= v; w *= v; return *this; } + + vec_t operator * (float f) const; + vec_t operator - () const; + vec_t operator - (const vec_t& v) const; + vec_t operator + (const vec_t& v) const; + vec_t operator * (const vec_t& v) const; + + const vec_t& operator + () const { return (*this); } + float Length() const { return sqrtf(x * x + y * y + z * z); }; + float LengthSq() const { return (x * x + y * y + z * z); }; + vec_t Normalize() { (*this) *= (1.f / Length()); return (*this); } + vec_t Normalize(const vec_t& v) { this->Set(v.x, v.y, v.z, v.w); this->Normalize(); return (*this); } + vec_t Abs() const; + + void Cross(const vec_t& v) + { + vec_t res; + res.x = y * v.z - z * v.y; + res.y = z * v.x - x * v.z; + res.z = x * v.y - y * v.x; + + x = res.x; + y = res.y; + z = res.z; + w = 0.f; + } + + void Cross(const vec_t& v1, const vec_t& v2) + { + x = v1.y * v2.z - v1.z * v2.y; + y = v1.z * v2.x - v1.x * v2.z; + z = v1.x * v2.y - v1.y * v2.x; + w = 0.f; + } + + float Dot(const vec_t& v) const + { + return (x * v.x) + (y * v.y) + (z * v.z) + (w * v.w); + } + + float Dot3(const vec_t& v) const + { + return (x * v.x) + (y * v.y) + (z * v.z); + } + + void Transform(const matrix_t& matrix); + void Transform(const vec_t& s, const matrix_t& matrix); + + void TransformVector(const matrix_t& matrix); + void TransformPoint(const matrix_t& matrix); + void TransformVector(const vec_t& v, const matrix_t& matrix) { (*this) = v; this->TransformVector(matrix); } + void TransformPoint(const vec_t& v, const matrix_t& matrix) { (*this) = v; this->TransformPoint(matrix); } + + float& operator [] (size_t index) { return ((float*)&x)[index]; } + const float& operator [] (size_t index) const { return ((float*)&x)[index]; } + bool operator!=(const vec_t& other) const { return memcmp(this, &other, sizeof(vec_t)); } + }; + + vec_t makeVect(float _x, float _y, float _z = 0.f, float _w = 0.f) { vec_t res; res.x = _x; res.y = _y; res.z = _z; res.w = _w; return res; } + vec_t makeVect(ImVec2 v) { vec_t res; res.x = v.x; res.y = v.y; res.z = 0.f; res.w = 0.f; return res; } + vec_t vec_t::operator * (float f) const { return makeVect(x * f, y * f, z * f, w * f); } + vec_t vec_t::operator - () const { return makeVect(-x, -y, -z, -w); } + vec_t vec_t::operator - (const vec_t& v) const { return makeVect(x - v.x, y - v.y, z - v.z, w - v.w); } + vec_t vec_t::operator + (const vec_t& v) const { return makeVect(x + v.x, y + v.y, z + v.z, w + v.w); } + vec_t vec_t::operator * (const vec_t& v) const { return makeVect(x * v.x, y * v.y, z * v.z, w * v.w); } + vec_t vec_t::Abs() const { return makeVect(fabsf(x), fabsf(y), fabsf(z)); } + + vec_t Normalized(const vec_t& v) { vec_t res; res = v; res.Normalize(); return res; } + vec_t Cross(const vec_t& v1, const vec_t& v2) + { + vec_t res; 
+ res.x = v1.y * v2.z - v1.z * v2.y; + res.y = v1.z * v2.x - v1.x * v2.z; + res.z = v1.x * v2.y - v1.y * v2.x; + res.w = 0.f; + return res; + } + + float Dot(const vec_t& v1, const vec_t& v2) + { + return (v1.x * v2.x) + (v1.y * v2.y) + (v1.z * v2.z); + } + + vec_t BuildPlan(const vec_t& p_point1, const vec_t& p_normal) + { + vec_t normal, res; + normal.Normalize(p_normal); + res.w = normal.Dot(p_point1); + res.x = normal.x; + res.y = normal.y; + res.z = normal.z; + return res; + } + + struct matrix_t + { + public: + + union + { + float m[4][4]; + float m16[16]; + struct + { + vec_t right, up, dir, position; + } v; + vec_t component[4]; + }; + + matrix_t(const matrix_t& other) { memcpy(&m16[0], &other.m16[0], sizeof(float) * 16); } + matrix_t() {} + + operator float* () { return m16; } + operator const float* () const { return m16; } + void Translation(float _x, float _y, float _z) { this->Translation(makeVect(_x, _y, _z)); } + + void Translation(const vec_t& vt) + { + v.right.Set(1.f, 0.f, 0.f, 0.f); + v.up.Set(0.f, 1.f, 0.f, 0.f); + v.dir.Set(0.f, 0.f, 1.f, 0.f); + v.position.Set(vt.x, vt.y, vt.z, 1.f); + } + + void Scale(float _x, float _y, float _z) + { + v.right.Set(_x, 0.f, 0.f, 0.f); + v.up.Set(0.f, _y, 0.f, 0.f); + v.dir.Set(0.f, 0.f, _z, 0.f); + v.position.Set(0.f, 0.f, 0.f, 1.f); + } + void Scale(const vec_t& s) { Scale(s.x, s.y, s.z); } + + matrix_t& operator *= (const matrix_t& mat) + { + matrix_t tmpMat; + tmpMat = *this; + tmpMat.Multiply(mat); + *this = tmpMat; + return *this; + } + matrix_t operator * (const matrix_t& mat) const + { + matrix_t matT; + matT.Multiply(*this, mat); + return matT; + } + + void Multiply(const matrix_t& matrix) + { + matrix_t tmp; + tmp = *this; + + FPU_MatrixF_x_MatrixF((float*)&tmp, (float*)&matrix, (float*)this); + } + + void Multiply(const matrix_t& m1, const matrix_t& m2) + { + FPU_MatrixF_x_MatrixF((float*)&m1, (float*)&m2, (float*)this); + } + + float GetDeterminant() const + { + return m[0][0] * m[1][1] * m[2][2] + m[0][1] * m[1][2] * m[2][0] + m[0][2] * m[1][0] * m[2][1] - + m[0][2] * m[1][1] * m[2][0] - m[0][1] * m[1][0] * m[2][2] - m[0][0] * m[1][2] * m[2][1]; + } + + float Inverse(const matrix_t& srcMatrix, bool affine = false); + void SetToIdentity() + { + v.right.Set(1.f, 0.f, 0.f, 0.f); + v.up.Set(0.f, 1.f, 0.f, 0.f); + v.dir.Set(0.f, 0.f, 1.f, 0.f); + v.position.Set(0.f, 0.f, 0.f, 1.f); + } + void Transpose() + { + matrix_t tmpm; + for (int l = 0; l < 4; l++) + { + for (int c = 0; c < 4; c++) + { + tmpm.m[l][c] = m[c][l]; + } + } + (*this) = tmpm; + } + + void RotationAxis(const vec_t& axis, float angle); + + void OrthoNormalize() + { + v.right.Normalize(); + v.up.Normalize(); + v.dir.Normalize(); + } + }; + + void vec_t::Transform(const matrix_t& matrix) + { + vec_t out; + + out.x = x * matrix.m[0][0] + y * matrix.m[1][0] + z * matrix.m[2][0] + w * matrix.m[3][0]; + out.y = x * matrix.m[0][1] + y * matrix.m[1][1] + z * matrix.m[2][1] + w * matrix.m[3][1]; + out.z = x * matrix.m[0][2] + y * matrix.m[1][2] + z * matrix.m[2][2] + w * matrix.m[3][2]; + out.w = x * matrix.m[0][3] + y * matrix.m[1][3] + z * matrix.m[2][3] + w * matrix.m[3][3]; + + x = out.x; + y = out.y; + z = out.z; + w = out.w; + } + + void vec_t::Transform(const vec_t& s, const matrix_t& matrix) + { + *this = s; + Transform(matrix); + } + + void vec_t::TransformPoint(const matrix_t& matrix) + { + vec_t out; + + out.x = x * matrix.m[0][0] + y * matrix.m[1][0] + z * matrix.m[2][0] + matrix.m[3][0]; + out.y = x * matrix.m[0][1] + y * matrix.m[1][1] + z * matrix.m[2][1] 
+ matrix.m[3][1]; + out.z = x * matrix.m[0][2] + y * matrix.m[1][2] + z * matrix.m[2][2] + matrix.m[3][2]; + out.w = x * matrix.m[0][3] + y * matrix.m[1][3] + z * matrix.m[2][3] + matrix.m[3][3]; + + x = out.x; + y = out.y; + z = out.z; + w = out.w; + } + + void vec_t::TransformVector(const matrix_t& matrix) + { + vec_t out; + + out.x = x * matrix.m[0][0] + y * matrix.m[1][0] + z * matrix.m[2][0]; + out.y = x * matrix.m[0][1] + y * matrix.m[1][1] + z * matrix.m[2][1]; + out.z = x * matrix.m[0][2] + y * matrix.m[1][2] + z * matrix.m[2][2]; + out.w = x * matrix.m[0][3] + y * matrix.m[1][3] + z * matrix.m[2][3]; + + x = out.x; + y = out.y; + z = out.z; + w = out.w; + } + + float matrix_t::Inverse(const matrix_t& srcMatrix, bool affine) + { + float det = 0; + + if (affine) + { + det = GetDeterminant(); + float s = 1 / det; + m[0][0] = (srcMatrix.m[1][1] * srcMatrix.m[2][2] - srcMatrix.m[1][2] * srcMatrix.m[2][1]) * s; + m[0][1] = (srcMatrix.m[2][1] * srcMatrix.m[0][2] - srcMatrix.m[2][2] * srcMatrix.m[0][1]) * s; + m[0][2] = (srcMatrix.m[0][1] * srcMatrix.m[1][2] - srcMatrix.m[0][2] * srcMatrix.m[1][1]) * s; + m[1][0] = (srcMatrix.m[1][2] * srcMatrix.m[2][0] - srcMatrix.m[1][0] * srcMatrix.m[2][2]) * s; + m[1][1] = (srcMatrix.m[2][2] * srcMatrix.m[0][0] - srcMatrix.m[2][0] * srcMatrix.m[0][2]) * s; + m[1][2] = (srcMatrix.m[0][2] * srcMatrix.m[1][0] - srcMatrix.m[0][0] * srcMatrix.m[1][2]) * s; + m[2][0] = (srcMatrix.m[1][0] * srcMatrix.m[2][1] - srcMatrix.m[1][1] * srcMatrix.m[2][0]) * s; + m[2][1] = (srcMatrix.m[2][0] * srcMatrix.m[0][1] - srcMatrix.m[2][1] * srcMatrix.m[0][0]) * s; + m[2][2] = (srcMatrix.m[0][0] * srcMatrix.m[1][1] - srcMatrix.m[0][1] * srcMatrix.m[1][0]) * s; + m[3][0] = -(m[0][0] * srcMatrix.m[3][0] + m[1][0] * srcMatrix.m[3][1] + m[2][0] * srcMatrix.m[3][2]); + m[3][1] = -(m[0][1] * srcMatrix.m[3][0] + m[1][1] * srcMatrix.m[3][1] + m[2][1] * srcMatrix.m[3][2]); + m[3][2] = -(m[0][2] * srcMatrix.m[3][0] + m[1][2] * srcMatrix.m[3][1] + m[2][2] * srcMatrix.m[3][2]); + } + else + { + // transpose matrix + float src[16]; + for (int i = 0; i < 4; ++i) + { + src[i] = srcMatrix.m16[i * 4]; + src[i + 4] = srcMatrix.m16[i * 4 + 1]; + src[i + 8] = srcMatrix.m16[i * 4 + 2]; + src[i + 12] = srcMatrix.m16[i * 4 + 3]; + } + + // calculate pairs for first 8 elements (cofactors) + float tmp[12]; // temp array for pairs + tmp[0] = src[10] * src[15]; + tmp[1] = src[11] * src[14]; + tmp[2] = src[9] * src[15]; + tmp[3] = src[11] * src[13]; + tmp[4] = src[9] * src[14]; + tmp[5] = src[10] * src[13]; + tmp[6] = src[8] * src[15]; + tmp[7] = src[11] * src[12]; + tmp[8] = src[8] * src[14]; + tmp[9] = src[10] * src[12]; + tmp[10] = src[8] * src[13]; + tmp[11] = src[9] * src[12]; + + // calculate first 8 elements (cofactors) + m16[0] = (tmp[0] * src[5] + tmp[3] * src[6] + tmp[4] * src[7]) - (tmp[1] * src[5] + tmp[2] * src[6] + tmp[5] * src[7]); + m16[1] = (tmp[1] * src[4] + tmp[6] * src[6] + tmp[9] * src[7]) - (tmp[0] * src[4] + tmp[7] * src[6] + tmp[8] * src[7]); + m16[2] = (tmp[2] * src[4] + tmp[7] * src[5] + tmp[10] * src[7]) - (tmp[3] * src[4] + tmp[6] * src[5] + tmp[11] * src[7]); + m16[3] = (tmp[5] * src[4] + tmp[8] * src[5] + tmp[11] * src[6]) - (tmp[4] * src[4] + tmp[9] * src[5] + tmp[10] * src[6]); + m16[4] = (tmp[1] * src[1] + tmp[2] * src[2] + tmp[5] * src[3]) - (tmp[0] * src[1] + tmp[3] * src[2] + tmp[4] * src[3]); + m16[5] = (tmp[0] * src[0] + tmp[7] * src[2] + tmp[8] * src[3]) - (tmp[1] * src[0] + tmp[6] * src[2] + tmp[9] * src[3]); + m16[6] = (tmp[3] * src[0] + tmp[6] * src[1] + 
tmp[11] * src[3]) - (tmp[2] * src[0] + tmp[7] * src[1] + tmp[10] * src[3]); + m16[7] = (tmp[4] * src[0] + tmp[9] * src[1] + tmp[10] * src[2]) - (tmp[5] * src[0] + tmp[8] * src[1] + tmp[11] * src[2]); + + // calculate pairs for second 8 elements (cofactors) + tmp[0] = src[2] * src[7]; + tmp[1] = src[3] * src[6]; + tmp[2] = src[1] * src[7]; + tmp[3] = src[3] * src[5]; + tmp[4] = src[1] * src[6]; + tmp[5] = src[2] * src[5]; + tmp[6] = src[0] * src[7]; + tmp[7] = src[3] * src[4]; + tmp[8] = src[0] * src[6]; + tmp[9] = src[2] * src[4]; + tmp[10] = src[0] * src[5]; + tmp[11] = src[1] * src[4]; + + // calculate second 8 elements (cofactors) + m16[8] = (tmp[0] * src[13] + tmp[3] * src[14] + tmp[4] * src[15]) - (tmp[1] * src[13] + tmp[2] * src[14] + tmp[5] * src[15]); + m16[9] = (tmp[1] * src[12] + tmp[6] * src[14] + tmp[9] * src[15]) - (tmp[0] * src[12] + tmp[7] * src[14] + tmp[8] * src[15]); + m16[10] = (tmp[2] * src[12] + tmp[7] * src[13] + tmp[10] * src[15]) - (tmp[3] * src[12] + tmp[6] * src[13] + tmp[11] * src[15]); + m16[11] = (tmp[5] * src[12] + tmp[8] * src[13] + tmp[11] * src[14]) - (tmp[4] * src[12] + tmp[9] * src[13] + tmp[10] * src[14]); + m16[12] = (tmp[2] * src[10] + tmp[5] * src[11] + tmp[1] * src[9]) - (tmp[4] * src[11] + tmp[0] * src[9] + tmp[3] * src[10]); + m16[13] = (tmp[8] * src[11] + tmp[0] * src[8] + tmp[7] * src[10]) - (tmp[6] * src[10] + tmp[9] * src[11] + tmp[1] * src[8]); + m16[14] = (tmp[6] * src[9] + tmp[11] * src[11] + tmp[3] * src[8]) - (tmp[10] * src[11] + tmp[2] * src[8] + tmp[7] * src[9]); + m16[15] = (tmp[10] * src[10] + tmp[4] * src[8] + tmp[9] * src[9]) - (tmp[8] * src[9] + tmp[11] * src[10] + tmp[5] * src[8]); + + // calculate determinant + det = src[0] * m16[0] + src[1] * m16[1] + src[2] * m16[2] + src[3] * m16[3]; + + // calculate matrix inverse + float invdet = 1 / det; + for (int j = 0; j < 16; ++j) + { + m16[j] *= invdet; + } + } + + return det; + } + + void matrix_t::RotationAxis(const vec_t& axis, float angle) + { + float length2 = axis.LengthSq(); + if (length2 < FLT_EPSILON) + { + SetToIdentity(); + return; + } + + vec_t n = axis * (1.f / sqrtf(length2)); + float s = sinf(angle); + float c = cosf(angle); + float k = 1.f - c; + + float xx = n.x * n.x * k + c; + float yy = n.y * n.y * k + c; + float zz = n.z * n.z * k + c; + float xy = n.x * n.y * k; + float yz = n.y * n.z * k; + float zx = n.z * n.x * k; + float xs = n.x * s; + float ys = n.y * s; + float zs = n.z * s; + + m[0][0] = xx; + m[0][1] = xy + zs; + m[0][2] = zx - ys; + m[0][3] = 0.f; + m[1][0] = xy - zs; + m[1][1] = yy; + m[1][2] = yz + xs; + m[1][3] = 0.f; + m[2][0] = zx + ys; + m[2][1] = yz - xs; + m[2][2] = zz; + m[2][3] = 0.f; + m[3][0] = 0.f; + m[3][1] = 0.f; + m[3][2] = 0.f; + m[3][3] = 1.f; + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + + enum MOVETYPE + { + MT_NONE, + MT_MOVE_X, + MT_MOVE_Y, + MT_MOVE_Z, + MT_MOVE_YZ, + MT_MOVE_ZX, + MT_MOVE_XY, + MT_MOVE_SCREEN, + MT_ROTATE_X, + MT_ROTATE_Y, + MT_ROTATE_Z, + MT_ROTATE_SCREEN, + MT_SCALE_X, + MT_SCALE_Y, + MT_SCALE_Z, + MT_SCALE_XYZ + }; + + static bool IsTranslateType(int type) + { + return type >= MT_MOVE_X && type <= MT_MOVE_SCREEN; + } + + static bool IsRotateType(int type) + { + return type >= MT_ROTATE_X && type <= MT_ROTATE_SCREEN; + } + + static bool IsScaleType(int type) + { + return type >= MT_SCALE_X && type <= MT_SCALE_XYZ; + } + + // Matches MT_MOVE_AB order + static const OPERATION 
TRANSLATE_PLANS[3] = { TRANSLATE_Y | TRANSLATE_Z, TRANSLATE_X | TRANSLATE_Z, TRANSLATE_X | TRANSLATE_Y }; + + struct Context + { + Context() : mbUsing(false), mbEnable(true), mbUsingBounds(false) + { + } + + ImDrawList* mDrawList; + + MODE mMode; + matrix_t mViewMat; + matrix_t mProjectionMat; + matrix_t mModel; + matrix_t mModelLocal; // orthonormalized model + matrix_t mModelInverse; + matrix_t mModelSource; + matrix_t mModelSourceInverse; + matrix_t mMVP; + matrix_t mMVPLocal; // MVP with full model matrix whereas mMVP's model matrix might only be translation in case of World space edition + matrix_t mViewProjection; + + vec_t mModelScaleOrigin; + vec_t mCameraEye; + vec_t mCameraRight; + vec_t mCameraDir; + vec_t mCameraUp; + vec_t mRayOrigin; + vec_t mRayVector; + + float mRadiusSquareCenter; + ImVec2 mScreenSquareCenter; + ImVec2 mScreenSquareMin; + ImVec2 mScreenSquareMax; + + float mScreenFactor; + vec_t mRelativeOrigin; + + bool mbUsing; + bool mbEnable; + + bool mReversed; // reversed projection matrix + + // translation + vec_t mTranslationPlan; + vec_t mTranslationPlanOrigin; + vec_t mMatrixOrigin; + vec_t mTranslationLastDelta; + + // rotation + vec_t mRotationVectorSource; + float mRotationAngle; + float mRotationAngleOrigin; + //vec_t mWorldToLocalAxis; + + // scale + vec_t mScale; + vec_t mScaleValueOrigin; + vec_t mScaleLast; + float mSaveMousePosx; + + // save axis factor when using gizmo + bool mBelowAxisLimit[3]; + bool mBelowPlaneLimit[3]; + float mAxisFactor[3]; + + // bounds stretching + vec_t mBoundsPivot; + vec_t mBoundsAnchor; + vec_t mBoundsPlan; + vec_t mBoundsLocalPivot; + int mBoundsBestAxis; + int mBoundsAxis[2]; + bool mbUsingBounds; + matrix_t mBoundsMatrix; + + // + int mCurrentOperation; + + float mX = 0.f; + float mY = 0.f; + float mWidth = 0.f; + float mHeight = 0.f; + float mXMax = 0.f; + float mYMax = 0.f; + float mDisplayRatio = 1.f; + + bool mIsOrthographic = false; + + int mActualID = -1; + int mEditingID = -1; + OPERATION mOperation = OPERATION(-1); + + bool mAllowAxisFlip = true; + float mGizmoSizeClipSpace = 0.1f; + }; + + static Context gContext; + + static const vec_t directionUnary[3] = { makeVect(1.f, 0.f, 0.f), makeVect(0.f, 1.f, 0.f), makeVect(0.f, 0.f, 1.f) }; + static const ImU32 directionColor[3] = { IM_COL32(0xAA, 0, 0, 0xFF), IM_COL32(0, 0xAA, 0, 0xFF), IM_COL32(0, 0, 0xAA, 0XFF) }; + + // Alpha: 100%: FF, 87%: DE, 70%: B3, 54%: 8A, 50%: 80, 38%: 61, 12%: 1F + static const ImU32 planeColor[3] = { IM_COL32(0xAA, 0, 0, 0x61), IM_COL32(0, 0xAA, 0, 0x61), IM_COL32(0, 0, 0xAA, 0x61) }; + static const ImU32 selectionColor = IM_COL32(0xFF, 0x80, 0x10, 0x8A); + static const ImU32 inactiveColor = IM_COL32(0x99, 0x99, 0x99, 0x99); + static const ImU32 translationLineColor = IM_COL32(0xAA, 0xAA, 0xAA, 0xAA); + static const char* translationInfoMask[] = { "X : %5.3f", "Y : %5.3f", "Z : %5.3f", + "Y : %5.3f Z : %5.3f", "X : %5.3f Z : %5.3f", "X : %5.3f Y : %5.3f", + "X : %5.3f Y : %5.3f Z : %5.3f" }; + static const char* scaleInfoMask[] = { "X : %5.2f", "Y : %5.2f", "Z : %5.2f", "XYZ : %5.2f" }; + static const char* rotationInfoMask[] = { "X : %5.2f deg %5.2f rad", "Y : %5.2f deg %5.2f rad", "Z : %5.2f deg %5.2f rad", "Screen : %5.2f deg %5.2f rad" }; + static const int translationInfoIndex[] = { 0,0,0, 1,0,0, 2,0,0, 1,2,0, 0,2,0, 0,1,0, 0,1,2 }; + static const float quadMin = 0.5f; + static const float quadMax = 0.8f; + static const float quadUV[8] = { quadMin, quadMin, quadMin, quadMax, quadMax, quadMax, quadMax, quadMin }; + static const int 
halfCircleSegmentCount = 64; + static const float snapTension = 0.5f; + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + static int GetMoveType(OPERATION op, vec_t* gizmoHitProportion); + static int GetRotateType(OPERATION op); + static int GetScaleType(OPERATION op); + + static ImVec2 worldToPos(const vec_t& worldPos, const matrix_t& mat, ImVec2 position = ImVec2(gContext.mX, gContext.mY), ImVec2 size = ImVec2(gContext.mWidth, gContext.mHeight)) + { + vec_t trans; + trans.TransformPoint(worldPos, mat); + trans *= 0.5f / trans.w; + trans += makeVect(0.5f, 0.5f); + trans.y = 1.f - trans.y; + trans.x *= size.x; + trans.y *= size.y; + trans.x += position.x; + trans.y += position.y; + return ImVec2(trans.x, trans.y); + } + + static void ComputeCameraRay(vec_t& rayOrigin, vec_t& rayDir, ImVec2 position = ImVec2(gContext.mX, gContext.mY), ImVec2 size = ImVec2(gContext.mWidth, gContext.mHeight)) + { + ImGuiIO& io = ImGui::GetIO(); + + matrix_t mViewProjInverse; + mViewProjInverse.Inverse(gContext.mViewMat * gContext.mProjectionMat); + + const float mox = ((io.MousePos.x - position.x) / size.x) * 2.f - 1.f; + const float moy = (1.f - ((io.MousePos.y - position.y) / size.y)) * 2.f - 1.f; + + const float zNear = gContext.mReversed ? (1.f - FLT_EPSILON) : 0.f; + const float zFar = gContext.mReversed ? 0.f : (1.f - FLT_EPSILON); + + rayOrigin.Transform(makeVect(mox, moy, zNear, 1.f), mViewProjInverse); + rayOrigin *= 1.f / rayOrigin.w; + vec_t rayEnd; + rayEnd.Transform(makeVect(mox, moy, zFar, 1.f), mViewProjInverse); + rayEnd *= 1.f / rayEnd.w; + rayDir = Normalized(rayEnd - rayOrigin); + } + + static float GetSegmentLengthClipSpace(const vec_t& start, const vec_t& end, const bool localCoordinates = false) + { + vec_t startOfSegment = start; + const matrix_t& mvp = localCoordinates ? 
gContext.mMVPLocal : gContext.mMVP; + startOfSegment.TransformPoint(mvp); + if (fabsf(startOfSegment.w) > FLT_EPSILON) // check for axis aligned with camera direction + { + startOfSegment *= 1.f / startOfSegment.w; + } + + vec_t endOfSegment = end; + endOfSegment.TransformPoint(mvp); + if (fabsf(endOfSegment.w) > FLT_EPSILON) // check for axis aligned with camera direction + { + endOfSegment *= 1.f / endOfSegment.w; + } + + vec_t clipSpaceAxis = endOfSegment - startOfSegment; + clipSpaceAxis.y /= gContext.mDisplayRatio; + float segmentLengthInClipSpace = sqrtf(clipSpaceAxis.x * clipSpaceAxis.x + clipSpaceAxis.y * clipSpaceAxis.y); + return segmentLengthInClipSpace; + } + + static float GetParallelogram(const vec_t& ptO, const vec_t& ptA, const vec_t& ptB) + { + vec_t pts[] = { ptO, ptA, ptB }; + for (unsigned int i = 0; i < 3; i++) + { + pts[i].TransformPoint(gContext.mMVP); + if (fabsf(pts[i].w) > FLT_EPSILON) // check for axis aligned with camera direction + { + pts[i] *= 1.f / pts[i].w; + } + } + vec_t segA = pts[1] - pts[0]; + vec_t segB = pts[2] - pts[0]; + segA.y /= gContext.mDisplayRatio; + segB.y /= gContext.mDisplayRatio; + vec_t segAOrtho = makeVect(-segA.y, segA.x); + segAOrtho.Normalize(); + float dt = segAOrtho.Dot3(segB); + float surface = sqrtf(segA.x * segA.x + segA.y * segA.y) * fabsf(dt); + return surface; + } + + inline vec_t PointOnSegment(const vec_t& point, const vec_t& vertPos1, const vec_t& vertPos2) + { + vec_t c = point - vertPos1; + vec_t V; + + V.Normalize(vertPos2 - vertPos1); + float d = (vertPos2 - vertPos1).Length(); + float t = V.Dot3(c); + + if (t < 0.f) + { + return vertPos1; + } + + if (t > d) + { + return vertPos2; + } + + return vertPos1 + V * t; + } + + static float IntersectRayPlane(const vec_t& rOrigin, const vec_t& rVector, const vec_t& plan) + { + const float numer = plan.Dot3(rOrigin) - plan.w; + const float denom = plan.Dot3(rVector); + + if (fabsf(denom) < FLT_EPSILON) // normal is orthogonal to vector, cant intersect + { + return -1.0f; + } + + return -(numer / denom); + } + + static float DistanceToPlane(const vec_t& point, const vec_t& plan) + { + return plan.Dot3(point) + plan.w; + } + + static bool IsInContextRect(ImVec2 p) + { + return IsWithin(p.x, gContext.mX, gContext.mXMax) && IsWithin(p.y, gContext.mY, gContext.mYMax); + } + + void SetRect(float x, float y, float width, float height) + { + gContext.mX = x; + gContext.mY = y; + gContext.mWidth = width; + gContext.mHeight = height; + gContext.mXMax = gContext.mX + gContext.mWidth; + gContext.mYMax = gContext.mY + gContext.mXMax; + gContext.mDisplayRatio = width / height; + } + + void SetOrthographic(bool isOrthographic) + { + gContext.mIsOrthographic = isOrthographic; + } + + void SetDrawlist(ImDrawList* drawlist) + { + gContext.mDrawList = drawlist ? 
drawlist : ImGui::GetWindowDrawList(); + } + + void SetImGuiContext(ImGuiContext* ctx) + { + ImGui::SetCurrentContext(ctx); + } + + void BeginFrame() + { + const ImU32 flags = ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoScrollbar | ImGuiWindowFlags_NoInputs | ImGuiWindowFlags_NoSavedSettings | ImGuiWindowFlags_NoFocusOnAppearing | ImGuiWindowFlags_NoBringToFrontOnFocus; + +#ifdef IMGUI_HAS_VIEWPORT + ImGui::SetNextWindowSize(ImGui::GetMainViewport()->Size); + ImGui::SetNextWindowPos(ImGui::GetMainViewport()->Pos); +#else + ImGuiIO& io = ImGui::GetIO(); + ImGui::SetNextWindowSize(io.DisplaySize); + ImGui::SetNextWindowPos(ImVec2(0, 0)); +#endif + + ImGui::PushStyleColor(ImGuiCol_WindowBg, 0); + ImGui::PushStyleColor(ImGuiCol_Border, 0); + ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f); + + ImGui::Begin("gizmo", NULL, flags); + gContext.mDrawList = ImGui::GetWindowDrawList(); + ImGui::End(); + ImGui::PopStyleVar(); + ImGui::PopStyleColor(2); + } + + bool IsUsing() + { + return gContext.mbUsing || gContext.mbUsingBounds; + } + + bool IsOver() + { + return (Intersects(gContext.mOperation, TRANSLATE) && GetMoveType(gContext.mOperation, NULL) != MT_NONE) || + (Intersects(gContext.mOperation, ROTATE) && GetRotateType(gContext.mOperation) != MT_NONE) || + (Intersects(gContext.mOperation, SCALE) && GetScaleType(gContext.mOperation) != MT_NONE) || IsUsing(); + } + + bool IsOver(OPERATION op) + { + if(IsUsing()) + { + return true; + } + if(Intersects(op, SCALE) && GetScaleType(op) != MT_NONE) + { + return true; + } + if(Intersects(op, ROTATE) && GetRotateType(op) != MT_NONE) + { + return true; + } + if(Intersects(op, TRANSLATE) && GetMoveType(op, NULL) != MT_NONE) + { + return true; + } + return false; + } + + void Enable(bool enable) + { + gContext.mbEnable = enable; + if (!enable) + { + gContext.mbUsing = false; + gContext.mbUsingBounds = false; + } + } + + static void ComputeContext(const float* view, const float* projection, float* matrix, MODE mode) + { + gContext.mMode = mode; + gContext.mViewMat = *(matrix_t*)view; + gContext.mProjectionMat = *(matrix_t*)projection; + + gContext.mModelLocal = *(matrix_t*)matrix; + gContext.mModelLocal.OrthoNormalize(); + + if (mode == LOCAL) + { + gContext.mModel = gContext.mModelLocal; + } + else + { + gContext.mModel.Translation(((matrix_t*)matrix)->v.position); + } + gContext.mModelSource = *(matrix_t*)matrix; + gContext.mModelScaleOrigin.Set(gContext.mModelSource.v.right.Length(), gContext.mModelSource.v.up.Length(), gContext.mModelSource.v.dir.Length()); + + gContext.mModelInverse.Inverse(gContext.mModel); + gContext.mModelSourceInverse.Inverse(gContext.mModelSource); + gContext.mViewProjection = gContext.mViewMat * gContext.mProjectionMat; + gContext.mMVP = gContext.mModel * gContext.mViewProjection; + gContext.mMVPLocal = gContext.mModelLocal * gContext.mViewProjection; + + matrix_t viewInverse; + viewInverse.Inverse(gContext.mViewMat); + gContext.mCameraDir = viewInverse.v.dir; + gContext.mCameraEye = viewInverse.v.position; + gContext.mCameraRight = viewInverse.v.right; + gContext.mCameraUp = viewInverse.v.up; + + // projection reverse + vec_t nearPos, farPos; + nearPos.Transform(makeVect(0, 0, 1.f, 1.f), gContext.mProjectionMat); + farPos.Transform(makeVect(0, 0, 2.f, 1.f), gContext.mProjectionMat); + + gContext.mReversed = (nearPos.z/nearPos.w) > (farPos.z / farPos.w); + + // compute scale from the size of camera right vector projected on screen at the matrix position + vec_t pointRight = 
viewInverse.v.right; + pointRight.TransformPoint(gContext.mViewProjection); + gContext.mScreenFactor = gContext.mGizmoSizeClipSpace / (pointRight.x / pointRight.w - gContext.mMVP.v.position.x / gContext.mMVP.v.position.w); + + vec_t rightViewInverse = viewInverse.v.right; + rightViewInverse.TransformVector(gContext.mModelInverse); + float rightLength = GetSegmentLengthClipSpace(makeVect(0.f, 0.f), rightViewInverse); + gContext.mScreenFactor = gContext.mGizmoSizeClipSpace / rightLength; + + ImVec2 centerSSpace = worldToPos(makeVect(0.f, 0.f), gContext.mMVP); + gContext.mScreenSquareCenter = centerSSpace; + gContext.mScreenSquareMin = ImVec2(centerSSpace.x - 10.f, centerSSpace.y - 10.f); + gContext.mScreenSquareMax = ImVec2(centerSSpace.x + 10.f, centerSSpace.y + 10.f); + + ComputeCameraRay(gContext.mRayOrigin, gContext.mRayVector); + } + + static void ComputeColors(ImU32* colors, int type, OPERATION operation) + { + if (gContext.mbEnable) + { + switch (operation) + { + case TRANSLATE: + colors[0] = (type == MT_MOVE_SCREEN) ? selectionColor : IM_COL32_WHITE; + for (int i = 0; i < 3; i++) + { + colors[i + 1] = (type == (int)(MT_MOVE_X + i)) ? selectionColor : directionColor[i]; + colors[i + 4] = (type == (int)(MT_MOVE_YZ + i)) ? selectionColor : planeColor[i]; + colors[i + 4] = (type == MT_MOVE_SCREEN) ? selectionColor : colors[i + 4]; + } + break; + case ROTATE: + colors[0] = (type == MT_ROTATE_SCREEN) ? selectionColor : IM_COL32_WHITE; + for (int i = 0; i < 3; i++) + { + colors[i + 1] = (type == (int)(MT_ROTATE_X + i)) ? selectionColor : directionColor[i]; + } + break; + case SCALEU: + case SCALE: + colors[0] = (type == MT_SCALE_XYZ) ? selectionColor : IM_COL32_WHITE; + for (int i = 0; i < 3; i++) + { + colors[i + 1] = (type == (int)(MT_SCALE_X + i)) ? 
selectionColor : directionColor[i]; + } + break; + // note: this internal function is only called with three possible values for operation + default: + break; + } + } + else + { + for (int i = 0; i < 7; i++) + { + colors[i] = inactiveColor; + } + } + } + + static void ComputeTripodAxisAndVisibility(const int axisIndex, vec_t& dirAxis, vec_t& dirPlaneX, vec_t& dirPlaneY, bool& belowAxisLimit, bool& belowPlaneLimit, const bool localCoordinates = false) + { + dirAxis = directionUnary[axisIndex]; + dirPlaneX = directionUnary[(axisIndex + 1) % 3]; + dirPlaneY = directionUnary[(axisIndex + 2) % 3]; + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + // when using, use stored factors so the gizmo doesn't flip when we translate + belowAxisLimit = gContext.mBelowAxisLimit[axisIndex]; + belowPlaneLimit = gContext.mBelowPlaneLimit[axisIndex]; + + dirAxis *= gContext.mAxisFactor[axisIndex]; + dirPlaneX *= gContext.mAxisFactor[(axisIndex + 1) % 3]; + dirPlaneY *= gContext.mAxisFactor[(axisIndex + 2) % 3]; + } + else + { + // new method + float lenDir = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), dirAxis, localCoordinates); + float lenDirMinus = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), -dirAxis, localCoordinates); + + float lenDirPlaneX = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), dirPlaneX, localCoordinates); + float lenDirMinusPlaneX = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), -dirPlaneX, localCoordinates); + + float lenDirPlaneY = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), dirPlaneY, localCoordinates); + float lenDirMinusPlaneY = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), -dirPlaneY, localCoordinates); + + // For readability + bool & allowFlip = gContext.mAllowAxisFlip; + float mulAxis = (allowFlip && lenDir < lenDirMinus&& fabsf(lenDir - lenDirMinus) > FLT_EPSILON) ? -1.f : 1.f; + float mulAxisX = (allowFlip && lenDirPlaneX < lenDirMinusPlaneX&& fabsf(lenDirPlaneX - lenDirMinusPlaneX) > FLT_EPSILON) ? -1.f : 1.f; + float mulAxisY = (allowFlip && lenDirPlaneY < lenDirMinusPlaneY&& fabsf(lenDirPlaneY - lenDirMinusPlaneY) > FLT_EPSILON) ? -1.f : 1.f; + dirAxis *= mulAxis; + dirPlaneX *= mulAxisX; + dirPlaneY *= mulAxisY; + + // for axis + float axisLengthInClipSpace = GetSegmentLengthClipSpace(makeVect(0.f, 0.f, 0.f), dirAxis * gContext.mScreenFactor, localCoordinates); + + float paraSurf = GetParallelogram(makeVect(0.f, 0.f, 0.f), dirPlaneX * gContext.mScreenFactor, dirPlaneY * gContext.mScreenFactor); + belowPlaneLimit = (paraSurf > 0.0025f); + belowAxisLimit = (axisLengthInClipSpace > 0.02f); + + // and store values + gContext.mAxisFactor[axisIndex] = mulAxis; + gContext.mAxisFactor[(axisIndex + 1) % 3] = mulAxisX; + gContext.mAxisFactor[(axisIndex + 2) % 3] = mulAxisY; + gContext.mBelowAxisLimit[axisIndex] = belowAxisLimit; + gContext.mBelowPlaneLimit[axisIndex] = belowPlaneLimit; + } + } + + static void ComputeSnap(float* value, float snap) + { + if (snap <= FLT_EPSILON) + { + return; + } + + float modulo = fmodf(*value, snap); + float moduloRatio = fabsf(modulo) / snap; + if (moduloRatio < snapTension) + { + *value -= modulo; + } + else if (moduloRatio > (1.f - snapTension)) + { + *value = *value - modulo + snap * ((*value < 0.f) ? 
-1.f : 1.f); + } + } + static void ComputeSnap(vec_t& value, const float* snap) + { + for (int i = 0; i < 3; i++) + { + ComputeSnap(&value[i], snap[i]); + } + } + + static float ComputeAngleOnPlan() + { + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + vec_t localPos = Normalized(gContext.mRayOrigin + gContext.mRayVector * len - gContext.mModel.v.position); + + vec_t perpendicularVector; + perpendicularVector.Cross(gContext.mRotationVectorSource, gContext.mTranslationPlan); + perpendicularVector.Normalize(); + float acosAngle = Clamp(Dot(localPos, gContext.mRotationVectorSource), -1.f, 1.f); + float angle = acosf(acosAngle); + angle *= (Dot(localPos, perpendicularVector) < 0.f) ? 1.f : -1.f; + return angle; + } + + static void DrawRotationGizmo(OPERATION op, int type) + { + if(!Intersects(op, ROTATE)) + { + return; + } + ImDrawList* drawList = gContext.mDrawList; + + // colors + ImU32 colors[7]; + ComputeColors(colors, type, ROTATE); + + vec_t cameraToModelNormalized; + if (gContext.mIsOrthographic) + { + matrix_t viewInverse; + viewInverse.Inverse(*(matrix_t*)&gContext.mViewMat); + cameraToModelNormalized = viewInverse.v.dir; + } + else + { + cameraToModelNormalized = Normalized(gContext.mModel.v.position - gContext.mCameraEye); + } + + cameraToModelNormalized.TransformVector(gContext.mModelInverse); + + gContext.mRadiusSquareCenter = screenRotateSize * gContext.mHeight; + + bool hasRSC = Intersects(op, ROTATE_SCREEN); + for (int axis = 0; axis < 3; axis++) + { + if(!Intersects(op, static_cast<OPERATION>(ROTATE_Z >> axis))) + { + continue; + } + const bool usingAxis = (gContext.mbUsing && type == MT_ROTATE_Z - axis); + const int circleMul = (hasRSC && !usingAxis ) ? 1 : 2; + + ImVec2* circlePos = (ImVec2*)alloca(sizeof(ImVec2) * (circleMul * halfCircleSegmentCount + 1)); + + float angleStart = atan2f(cameraToModelNormalized[(4 - axis) % 3], cameraToModelNormalized[(3 - axis) % 3]) + ZPI * 0.5f; + + for (int i = 0; i < circleMul * halfCircleSegmentCount + 1; i++) + { + float ng = angleStart + circleMul * ZPI * ((float)i / (float)halfCircleSegmentCount); + vec_t axisPos = makeVect(cosf(ng), sinf(ng), 0.f); + vec_t pos = makeVect(axisPos[axis], axisPos[(axis + 1) % 3], axisPos[(axis + 2) % 3]) * gContext.mScreenFactor * rotationDisplayFactor; + circlePos[i] = worldToPos(pos, gContext.mMVP); + } + if (!gContext.mbUsing || usingAxis) + { + drawList->AddPolyline(circlePos, circleMul* halfCircleSegmentCount + 1, colors[3 - axis], false, 2); + } + + float radiusAxis = sqrtf((ImLengthSqr(worldToPos(gContext.mModel.v.position, gContext.mViewProjection) - circlePos[0]))); + if (radiusAxis > gContext.mRadiusSquareCenter) + { + gContext.mRadiusSquareCenter = radiusAxis; + } + } + if(hasRSC && (!gContext.mbUsing || type == MT_ROTATE_SCREEN)) + { + drawList->AddCircle(worldToPos(gContext.mModel.v.position, gContext.mViewProjection), gContext.mRadiusSquareCenter, colors[0], 64, 3.f); + } + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsRotateType(type)) + { + ImVec2 circlePos[halfCircleSegmentCount + 1]; + + circlePos[0] = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + for (unsigned int i = 1; i < halfCircleSegmentCount; i++) + { + float ng = gContext.mRotationAngle * ((float)(i - 1) / (float)(halfCircleSegmentCount - 1)); + matrix_t rotateVectorMatrix; + rotateVectorMatrix.RotationAxis(gContext.mTranslationPlan, ng); + vec_t pos; +
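// each arc vertex is the drag-start vector rotated by a fraction of the current angle, then projected to screen space +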
pos.TransformPoint(gContext.mRotationVectorSource, rotateVectorMatrix); + pos *= gContext.mScreenFactor * rotationDisplayFactor; + circlePos[i] = worldToPos(pos + gContext.mModel.v.position, gContext.mViewProjection); + } + drawList->AddConvexPolyFilled(circlePos, halfCircleSegmentCount, IM_COL32(0xFF, 0x80, 0x10, 0x80)); + drawList->AddPolyline(circlePos, halfCircleSegmentCount, IM_COL32(0xFF, 0x80, 0x10, 0xFF), true, 2); + + ImVec2 destinationPosOnScreen = circlePos[1]; + char tmps[512]; + ImFormatString(tmps, sizeof(tmps), rotationInfoMask[type - MT_ROTATE_X], (gContext.mRotationAngle / ZPI) * 180.f, gContext.mRotationAngle); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 15, destinationPosOnScreen.y + 15), IM_COL32_BLACK, tmps); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 14, destinationPosOnScreen.y + 14), IM_COL32_WHITE, tmps); + } + } + + static void DrawHatchedAxis(const vec_t& axis) + { + for (int j = 1; j < 10; j++) + { + ImVec2 baseSSpace2 = worldToPos(axis * 0.05f * (float)(j * 2) * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpace2 = worldToPos(axis * 0.05f * (float)(j * 2 + 1) * gContext.mScreenFactor, gContext.mMVP); + gContext.mDrawList->AddLine(baseSSpace2, worldDirSSpace2, IM_COL32(0, 0, 0, 0x80), 6.f); + } + } + + static void DrawScaleGizmo(OPERATION op, int type) + { + ImDrawList* drawList = gContext.mDrawList; + + if(!Intersects(op, SCALE)) + { + return; + } + + // colors + ImU32 colors[7]; + ComputeColors(colors, type, SCALE); + + // draw + vec_t scaleDisplay = { 1.f, 1.f, 1.f, 1.f }; + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + scaleDisplay = gContext.mScale; + } + + for (unsigned int i = 0; i < 3; i++) + { + if(!Intersects(op, static_cast<OPERATION>(SCALE_X << i))) + { + continue; + } + const bool usingAxis = (gContext.mbUsing && type == MT_SCALE_X + i); + if (!gContext.mbUsing || usingAxis) + { + vec_t dirPlaneX, dirPlaneY, dirAxis; + bool belowAxisLimit, belowPlaneLimit; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit, true); + + // draw axis + if (belowAxisLimit) + { + bool hasTranslateOnAxis = Contains(op, static_cast<OPERATION>(TRANSLATE_X << i)); + float markerScale = hasTranslateOnAxis ?
1.4f : 1.0f; + ImVec2 baseSSpace = worldToPos(dirAxis * 0.1f * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpaceNoScale = worldToPos(dirAxis * markerScale * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpace = worldToPos((dirAxis * markerScale * scaleDisplay[i]) * gContext.mScreenFactor, gContext.mMVP); + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + drawList->AddLine(baseSSpace, worldDirSSpaceNoScale, IM_COL32(0x40, 0x40, 0x40, 0xFF), 3.f); + drawList->AddCircleFilled(worldDirSSpaceNoScale, 6.f, IM_COL32(0x40, 0x40, 0x40, 0xFF)); + } + + if (!hasTranslateOnAxis || gContext.mbUsing) + { + drawList->AddLine(baseSSpace, worldDirSSpace, colors[i + 1], 3.f); + } + drawList->AddCircleFilled(worldDirSSpace, 6.f, colors[i + 1]); + + if (gContext.mAxisFactor[i] < 0.f) + { + DrawHatchedAxis(dirAxis * scaleDisplay[i]); + } + } + } + } + + // draw screen circle + drawList->AddCircleFilled(gContext.mScreenSquareCenter, 6.f, colors[0], 32); + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsScaleType(type)) + { + //ImVec2 sourcePosOnScreen = worldToPos(gContext.mMatrixOrigin, gContext.mViewProjection); + ImVec2 destinationPosOnScreen = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + /*vec_t dif(destinationPosOnScreen.x - sourcePosOnScreen.x, destinationPosOnScreen.y - sourcePosOnScreen.y); + dif.Normalize(); + dif *= 5.f; + drawList->AddCircle(sourcePosOnScreen, 6.f, translationLineColor); + drawList->AddCircle(destinationPosOnScreen, 6.f, translationLineColor); + drawList->AddLine(ImVec2(sourcePosOnScreen.x + dif.x, sourcePosOnScreen.y + dif.y), ImVec2(destinationPosOnScreen.x - dif.x, destinationPosOnScreen.y - dif.y), translationLineColor, 2.f); + */ + char tmps[512]; + //vec_t deltaInfo = gContext.mModel.v.position - gContext.mMatrixOrigin; + int componentInfoIndex = (type - MT_SCALE_X) * 3; + ImFormatString(tmps, sizeof(tmps), scaleInfoMask[type - MT_SCALE_X], scaleDisplay[translationInfoIndex[componentInfoIndex]]); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 15, destinationPosOnScreen.y + 15), IM_COL32_BLACK, tmps); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 14, destinationPosOnScreen.y + 14), IM_COL32_WHITE, tmps); + } + } + + + static void DrawScaleUniveralGizmo(OPERATION op, int type) + { + ImDrawList* drawList = gContext.mDrawList; + + if (!Intersects(op, SCALEU)) + { + return; + } + + // colors + ImU32 colors[7]; + ComputeColors(colors, type, SCALEU); + + // draw + vec_t scaleDisplay = { 1.f, 1.f, 1.f, 1.f }; + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + scaleDisplay = gContext.mScale; + } + + for (unsigned int i = 0; i < 3; i++) + { + if (!Intersects(op, static_cast<OPERATION>(SCALE_XU << i))) + { + continue; + } + const bool usingAxis = (gContext.mbUsing && type == MT_SCALE_X + i); + if (!gContext.mbUsing || usingAxis) + { + vec_t dirPlaneX, dirPlaneY, dirAxis; + bool belowAxisLimit, belowPlaneLimit; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit, true); + + // draw axis + if (belowAxisLimit) + { + bool hasTranslateOnAxis = Contains(op, static_cast<OPERATION>(TRANSLATE_X << i)); + float markerScale = hasTranslateOnAxis ?
1.4f : 1.0f; + ImVec2 baseSSpace = worldToPos(dirAxis * 0.1f * gContext.mScreenFactor, gContext.mMVPLocal); + //ImVec2 worldDirSSpaceNoScale = worldToPos(dirAxis * markerScale * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpace = worldToPos((dirAxis * markerScale * scaleDisplay[i]) * gContext.mScreenFactor, gContext.mMVPLocal); + + /*if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + drawList->AddLine(baseSSpace, worldDirSSpaceNoScale, IM_COL32(0x40, 0x40, 0x40, 0xFF), 3.f); + drawList->AddCircleFilled(worldDirSSpaceNoScale, 6.f, IM_COL32(0x40, 0x40, 0x40, 0xFF)); + } + */ + /* + if (!hasTranslateOnAxis || gContext.mbUsing) + { + drawList->AddLine(baseSSpace, worldDirSSpace, colors[i + 1], 3.f); + } + */ + drawList->AddCircleFilled(worldDirSSpace, 12.f, colors[i + 1]); + } + } + } + + // draw screen circle + drawList->AddCircle(gContext.mScreenSquareCenter, 20.f, colors[0], 32, 3.f); + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsScaleType(type)) + { + //ImVec2 sourcePosOnScreen = worldToPos(gContext.mMatrixOrigin, gContext.mViewProjection); + ImVec2 destinationPosOnScreen = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + /*vec_t dif(destinationPosOnScreen.x - sourcePosOnScreen.x, destinationPosOnScreen.y - sourcePosOnScreen.y); + dif.Normalize(); + dif *= 5.f; + drawList->AddCircle(sourcePosOnScreen, 6.f, translationLineColor); + drawList->AddCircle(destinationPosOnScreen, 6.f, translationLineColor); + drawList->AddLine(ImVec2(sourcePosOnScreen.x + dif.x, sourcePosOnScreen.y + dif.y), ImVec2(destinationPosOnScreen.x - dif.x, destinationPosOnScreen.y - dif.y), translationLineColor, 2.f); + */ + char tmps[512]; + //vec_t deltaInfo = gContext.mModel.v.position - gContext.mMatrixOrigin; + int componentInfoIndex = (type - MT_SCALE_X) * 3; + ImFormatString(tmps, sizeof(tmps), scaleInfoMask[type - MT_SCALE_X], scaleDisplay[translationInfoIndex[componentInfoIndex]]); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 15, destinationPosOnScreen.y + 15), IM_COL32_BLACK, tmps); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 14, destinationPosOnScreen.y + 14), IM_COL32_WHITE, tmps); + } + } + + static void DrawTranslationGizmo(OPERATION op, int type) + { + ImDrawList* drawList = gContext.mDrawList; + if (!drawList) + { + return; + } + + if(!Intersects(op, TRANSLATE)) + { + return; + } + + // colors + ImU32 colors[7]; + ComputeColors(colors, type, TRANSLATE); + + const ImVec2 origin = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + + // draw + bool belowAxisLimit = false; + bool belowPlaneLimit = false; + for (unsigned int i = 0; i < 3; ++i) + { + vec_t dirPlaneX, dirPlaneY, dirAxis; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit); + + if (!gContext.mbUsing || (gContext.mbUsing && type == MT_MOVE_X + i)) + { + // draw axis + if (belowAxisLimit && Intersects(op, static_cast<OPERATION>(TRANSLATE_X << i))) + { + ImVec2 baseSSpace = worldToPos(dirAxis * 0.1f * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpace = worldToPos(dirAxis * gContext.mScreenFactor, gContext.mMVP); + + drawList->AddLine(baseSSpace, worldDirSSpace, colors[i + 1], 3.f); + + // Arrow head begin + ImVec2 dir(origin - worldDirSSpace); + + float d = sqrtf(ImLengthSqr(dir)); + dir /= d; // Normalize + dir *= 6.0f; + + ImVec2 ortogonalDir(dir.y, -dir.x); // Perpendicular vector + ImVec2 a(worldDirSSpace + dir); +
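// the arrow head is a small triangle built from the axis tip, offset along and perpendicular to the screen-space axis direction +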
drawList->AddTriangleFilled(worldDirSSpace - dir, a + ortogonalDir, a - ortogonalDir, colors[i + 1]); + // Arrow head end + + if (gContext.mAxisFactor[i] < 0.f) + { + DrawHatchedAxis(dirAxis); + } + } + } + // draw plane + if (!gContext.mbUsing || (gContext.mbUsing && type == MT_MOVE_YZ + i)) + { + if (belowPlaneLimit && Contains(op, TRANSLATE_PLANS[i])) + { + ImVec2 screenQuadPts[4]; + for (int j = 0; j < 4; ++j) + { + vec_t cornerWorldPos = (dirPlaneX * quadUV[j * 2] + dirPlaneY * quadUV[j * 2 + 1]) * gContext.mScreenFactor; + screenQuadPts[j] = worldToPos(cornerWorldPos, gContext.mMVP); + } + drawList->AddPolyline(screenQuadPts, 4, directionColor[i], true, 1.0f); + drawList->AddConvexPolyFilled(screenQuadPts, 4, colors[i + 4]); + } + } + } + + drawList->AddCircleFilled(gContext.mScreenSquareCenter, 6.f, colors[0], 32); + + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsTranslateType(type)) + { + ImVec2 sourcePosOnScreen = worldToPos(gContext.mMatrixOrigin, gContext.mViewProjection); + ImVec2 destinationPosOnScreen = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + vec_t dif = { destinationPosOnScreen.x - sourcePosOnScreen.x, destinationPosOnScreen.y - sourcePosOnScreen.y, 0.f, 0.f }; + dif.Normalize(); + dif *= 5.f; + drawList->AddCircle(sourcePosOnScreen, 6.f, translationLineColor); + drawList->AddCircle(destinationPosOnScreen, 6.f, translationLineColor); + drawList->AddLine(ImVec2(sourcePosOnScreen.x + dif.x, sourcePosOnScreen.y + dif.y), ImVec2(destinationPosOnScreen.x - dif.x, destinationPosOnScreen.y - dif.y), translationLineColor, 2.f); + + char tmps[512]; + vec_t deltaInfo = gContext.mModel.v.position - gContext.mMatrixOrigin; + int componentInfoIndex = (type - MT_MOVE_X) * 3; + ImFormatString(tmps, sizeof(tmps), translationInfoMask[type - MT_MOVE_X], deltaInfo[translationInfoIndex[componentInfoIndex]], deltaInfo[translationInfoIndex[componentInfoIndex + 1]], deltaInfo[translationInfoIndex[componentInfoIndex + 2]]); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 15, destinationPosOnScreen.y + 15), IM_COL32_BLACK, tmps); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 14, destinationPosOnScreen.y + 14), IM_COL32_WHITE, tmps); + } + } + + static bool CanActivate() + { + if (ImGui::IsMouseClicked(0) && !ImGui::IsAnyItemHovered() && !ImGui::IsAnyItemActive()) + { + return true; + } + return false; + } + + static void HandleAndDrawLocalBounds(const float* bounds, matrix_t* matrix, const float* snapValues, OPERATION operation) + { + ImGuiIO& io = ImGui::GetIO(); + ImDrawList* drawList = gContext.mDrawList; + + // compute best projection axis + vec_t axesWorldDirections[3]; + vec_t bestAxisWorldDirection = { 0.0f, 0.0f, 0.0f, 0.0f }; + int axes[3]; + unsigned int numAxes = 1; + axes[0] = gContext.mBoundsBestAxis; + int bestAxis = axes[0]; + if (!gContext.mbUsingBounds) + { + numAxes = 0; + float bestDot = 0.f; + for (unsigned int i = 0; i < 3; i++) + { + vec_t dirPlaneNormalWorld; + dirPlaneNormalWorld.TransformVector(directionUnary[i], gContext.mModelSource); + dirPlaneNormalWorld.Normalize(); + + float dt = fabsf(Dot(Normalized(gContext.mCameraEye - gContext.mModelSource.v.position), dirPlaneNormalWorld)); + if (dt >= bestDot) + { + bestDot = dt; + bestAxis = i; + bestAxisWorldDirection = dirPlaneNormalWorld; + } + + if (dt >= 0.1f) + { + axes[numAxes] = i; + axesWorldDirections[numAxes] = dirPlaneNormalWorld; + ++numAxes; + } + } + } + + if (numAxes == 0) + { + axes[0] = bestAxis; + 
axesWorldDirections[0] = bestAxisWorldDirection; + numAxes = 1; + } + + else if (bestAxis != axes[0]) + { + unsigned int bestIndex = 0; + for (unsigned int i = 0; i < numAxes; i++) + { + if (axes[i] == bestAxis) + { + bestIndex = i; + break; + } + } + int tempAxis = axes[0]; + axes[0] = axes[bestIndex]; + axes[bestIndex] = tempAxis; + vec_t tempDirection = axesWorldDirections[0]; + axesWorldDirections[0] = axesWorldDirections[bestIndex]; + axesWorldDirections[bestIndex] = tempDirection; + } + + for (unsigned int axisIndex = 0; axisIndex < numAxes; ++axisIndex) + { + bestAxis = axes[axisIndex]; + bestAxisWorldDirection = axesWorldDirections[axisIndex]; + + // corners + vec_t aabb[4]; + + int secondAxis = (bestAxis + 1) % 3; + int thirdAxis = (bestAxis + 2) % 3; + + for (int i = 0; i < 4; i++) + { + aabb[i][3] = aabb[i][bestAxis] = 0.f; + aabb[i][secondAxis] = bounds[secondAxis + 3 * (i >> 1)]; + aabb[i][thirdAxis] = bounds[thirdAxis + 3 * ((i >> 1) ^ (i & 1))]; + } + + // draw bounds + unsigned int anchorAlpha = gContext.mbEnable ? IM_COL32_BLACK : IM_COL32(0, 0, 0, 0x80); + + matrix_t boundsMVP = gContext.mModelSource * gContext.mViewProjection; + for (int i = 0; i < 4; i++) + { + ImVec2 worldBound1 = worldToPos(aabb[i], boundsMVP); + ImVec2 worldBound2 = worldToPos(aabb[(i + 1) % 4], boundsMVP); + if (!IsInContextRect(worldBound1) || !IsInContextRect(worldBound2)) + { + continue; + } + float boundDistance = sqrtf(ImLengthSqr(worldBound1 - worldBound2)); + int stepCount = (int)(boundDistance / 10.f); + stepCount = min(stepCount, 1000); + float stepLength = 1.f / (float)stepCount; + for (int j = 0; j < stepCount; j++) + { + float t1 = (float)j * stepLength; + float t2 = (float)j * stepLength + stepLength * 0.5f; + ImVec2 worldBoundSS1 = ImLerp(worldBound1, worldBound2, ImVec2(t1, t1)); + ImVec2 worldBoundSS2 = ImLerp(worldBound1, worldBound2, ImVec2(t2, t2)); + //drawList->AddLine(worldBoundSS1, worldBoundSS2, IM_COL32(0, 0, 0, 0) + anchorAlpha, 3.f); + drawList->AddLine(worldBoundSS1, worldBoundSS2, IM_COL32(0xAA, 0xAA, 0xAA, 0) + anchorAlpha, 2.f); + } + vec_t midPoint = (aabb[i] + aabb[(i + 1) % 4]) * 0.5f; + ImVec2 midBound = worldToPos(midPoint, boundsMVP); + static const float AnchorBigRadius = 8.f; + static const float AnchorSmallRadius = 6.f; + bool overBigAnchor = ImLengthSqr(worldBound1 - io.MousePos) <= (AnchorBigRadius * AnchorBigRadius); + bool overSmallAnchor = ImLengthSqr(midBound - io.MousePos) <= (AnchorBigRadius * AnchorBigRadius); + + int type = MT_NONE; + vec_t gizmoHitProportion; + + if(Intersects(operation, TRANSLATE)) + { + type = GetMoveType(operation, &gizmoHitProportion); + } + if(Intersects(operation, ROTATE) && type == MT_NONE) + { + type = GetRotateType(operation); + } + if(Intersects(operation, SCALE) && type == MT_NONE) + { + type = GetScaleType(operation); + } + + if (type != MT_NONE) + { + overBigAnchor = false; + overSmallAnchor = false; + } + + unsigned int bigAnchorColor = overBigAnchor ? selectionColor : (IM_COL32(0xAA, 0xAA, 0xAA, 0) + anchorAlpha); + unsigned int smallAnchorColor = overSmallAnchor ? 
selectionColor : (IM_COL32(0xAA, 0xAA, 0xAA, 0) + anchorAlpha); + + drawList->AddCircleFilled(worldBound1, AnchorBigRadius, IM_COL32_BLACK); + drawList->AddCircleFilled(worldBound1, AnchorBigRadius - 1.2f, bigAnchorColor); + + drawList->AddCircleFilled(midBound, AnchorSmallRadius, IM_COL32_BLACK); + drawList->AddCircleFilled(midBound, AnchorSmallRadius - 1.2f, smallAnchorColor); + int oppositeIndex = (i + 2) % 4; + // big anchor on corners + if (!gContext.mbUsingBounds && gContext.mbEnable && overBigAnchor && CanActivate()) + { + gContext.mBoundsPivot.TransformPoint(aabb[(i + 2) % 4], gContext.mModelSource); + gContext.mBoundsAnchor.TransformPoint(aabb[i], gContext.mModelSource); + gContext.mBoundsPlan = BuildPlan(gContext.mBoundsAnchor, bestAxisWorldDirection); + gContext.mBoundsBestAxis = bestAxis; + gContext.mBoundsAxis[0] = secondAxis; + gContext.mBoundsAxis[1] = thirdAxis; + + gContext.mBoundsLocalPivot.Set(0.f); + gContext.mBoundsLocalPivot[secondAxis] = aabb[oppositeIndex][secondAxis]; + gContext.mBoundsLocalPivot[thirdAxis] = aabb[oppositeIndex][thirdAxis]; + + gContext.mbUsingBounds = true; + gContext.mEditingID = gContext.mActualID; + gContext.mBoundsMatrix = gContext.mModelSource; + } + // small anchor on middle of segment + if (!gContext.mbUsingBounds && gContext.mbEnable && overSmallAnchor && CanActivate()) + { + vec_t midPointOpposite = (aabb[(i + 2) % 4] + aabb[(i + 3) % 4]) * 0.5f; + gContext.mBoundsPivot.TransformPoint(midPointOpposite, gContext.mModelSource); + gContext.mBoundsAnchor.TransformPoint(midPoint, gContext.mModelSource); + gContext.mBoundsPlan = BuildPlan(gContext.mBoundsAnchor, bestAxisWorldDirection); + gContext.mBoundsBestAxis = bestAxis; + int indices[] = { secondAxis , thirdAxis }; + gContext.mBoundsAxis[0] = indices[i % 2]; + gContext.mBoundsAxis[1] = -1; + + gContext.mBoundsLocalPivot.Set(0.f); + gContext.mBoundsLocalPivot[gContext.mBoundsAxis[0]] = aabb[oppositeIndex][indices[i % 2]];// bounds[gContext.mBoundsAxis[0]] * (((i + 1) & 2) ? 
1.f : -1.f); + + gContext.mbUsingBounds = true; + gContext.mEditingID = gContext.mActualID; + gContext.mBoundsMatrix = gContext.mModelSource; + } + } + + if (gContext.mbUsingBounds && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID)) + { + matrix_t scale; + scale.SetToIdentity(); + + // compute projected mouse position on plan + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mBoundsPlan); + vec_t newPos = gContext.mRayOrigin + gContext.mRayVector * len; + + // compute a reference and delta vectors based on mouse move + vec_t deltaVector = (newPos - gContext.mBoundsPivot).Abs(); + vec_t referenceVector = (gContext.mBoundsAnchor - gContext.mBoundsPivot).Abs(); + + // for 1 or 2 axes, compute a ratio that's used for scale and snap it based on resulting length + for (int i = 0; i < 2; i++) + { + int axisIndex1 = gContext.mBoundsAxis[i]; + if (axisIndex1 == -1) + { + continue; + } + + float ratioAxis = 1.f; + vec_t axisDir = gContext.mBoundsMatrix.component[axisIndex1].Abs(); + + float dtAxis = axisDir.Dot(referenceVector); + float boundSize = bounds[axisIndex1 + 3] - bounds[axisIndex1]; + if (dtAxis > FLT_EPSILON) + { + ratioAxis = axisDir.Dot(deltaVector) / dtAxis; + } + + if (snapValues) + { + float length = boundSize * ratioAxis; + ComputeSnap(&length, snapValues[axisIndex1]); + if (boundSize > FLT_EPSILON) + { + ratioAxis = length / boundSize; + } + } + scale.component[axisIndex1] *= ratioAxis; + } + + // transform matrix + matrix_t preScale, postScale; + preScale.Translation(-gContext.mBoundsLocalPivot); + postScale.Translation(gContext.mBoundsLocalPivot); + matrix_t res = preScale * scale * postScale * gContext.mBoundsMatrix; + *matrix = res; + + // info text + char tmps[512]; + ImVec2 destinationPosOnScreen = worldToPos(gContext.mModel.v.position, gContext.mViewProjection); + ImFormatString(tmps, sizeof(tmps), "X: %.2f Y: %.2f Z:%.2f" + , (bounds[3] - bounds[0]) * gContext.mBoundsMatrix.component[0].Length() * scale.component[0].Length() + , (bounds[4] - bounds[1]) * gContext.mBoundsMatrix.component[1].Length() * scale.component[1].Length() + , (bounds[5] - bounds[2]) * gContext.mBoundsMatrix.component[2].Length() * scale.component[2].Length() + ); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 15, destinationPosOnScreen.y + 15), IM_COL32_BLACK, tmps); + drawList->AddText(ImVec2(destinationPosOnScreen.x + 14, destinationPosOnScreen.y + 14), IM_COL32_WHITE, tmps); + } + + if (!io.MouseDown[0]) { + gContext.mbUsingBounds = false; + gContext.mEditingID = -1; + } + if (gContext.mbUsingBounds) + { + break; + } + } + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + + static int GetScaleType(OPERATION op) + { + if (gContext.mbUsing) + { + return MT_NONE; + } + ImGuiIO& io = ImGui::GetIO(); + int type = MT_NONE; + + // screen + if (io.MousePos.x >= gContext.mScreenSquareMin.x && io.MousePos.x <= gContext.mScreenSquareMax.x && + io.MousePos.y >= gContext.mScreenSquareMin.y && io.MousePos.y <= gContext.mScreenSquareMax.y && + Contains(op, SCALE)) + { + type = MT_SCALE_XYZ; + } + + // compute + for (unsigned int i = 0; i < 3 && type == MT_NONE; i++) + { + if(!Intersects(op, static_cast<OPERATION>(SCALE_X << i))) + { + continue; + } + vec_t dirPlaneX, dirPlaneY, dirAxis; + bool belowAxisLimit, belowPlaneLimit; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit, true); +
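// bring the tripod axes into world space so the mouse ray can be tested against each scale axis segment in screen space +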
dirAxis.TransformVector(gContext.mModelLocal); + dirPlaneX.TransformVector(gContext.mModelLocal); + dirPlaneY.TransformVector(gContext.mModelLocal); + + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, BuildPlan(gContext.mModelLocal.v.position, dirAxis)); + vec_t posOnPlan = gContext.mRayOrigin + gContext.mRayVector * len; + + const float startOffset = Contains(op, static_cast<OPERATION>(TRANSLATE_X << i)) ? 1.0f : 0.1f; + const float endOffset = Contains(op, static_cast<OPERATION>(TRANSLATE_X << i)) ? 1.4f : 1.0f; + const ImVec2 posOnPlanScreen = worldToPos(posOnPlan, gContext.mViewProjection); + const ImVec2 axisStartOnScreen = worldToPos(gContext.mModelLocal.v.position + dirAxis * gContext.mScreenFactor * startOffset, gContext.mViewProjection); + const ImVec2 axisEndOnScreen = worldToPos(gContext.mModelLocal.v.position + dirAxis * gContext.mScreenFactor * endOffset, gContext.mViewProjection); + + vec_t closestPointOnAxis = PointOnSegment(makeVect(posOnPlanScreen), makeVect(axisStartOnScreen), makeVect(axisEndOnScreen)); + + if ((closestPointOnAxis - makeVect(posOnPlanScreen)).Length() < 12.f) // pixel size + { + type = MT_SCALE_X + i; + } + } + + // universal + + vec_t deltaScreen = { io.MousePos.x - gContext.mScreenSquareCenter.x, io.MousePos.y - gContext.mScreenSquareCenter.y, 0.f, 0.f }; + float dist = deltaScreen.Length(); + if (Contains(op, SCALEU) && dist >= 17.0f && dist < 23.0f) + { + type = MT_SCALE_XYZ; + } + + for (unsigned int i = 0; i < 3 && type == MT_NONE; i++) + { + if (!Intersects(op, static_cast<OPERATION>(SCALE_XU << i))) + { + continue; + } + + vec_t dirPlaneX, dirPlaneY, dirAxis; + bool belowAxisLimit, belowPlaneLimit; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit, true); + + // draw axis + if (belowAxisLimit) + { + bool hasTranslateOnAxis = Contains(op, static_cast<OPERATION>(TRANSLATE_X << i)); + float markerScale = hasTranslateOnAxis ?
1.4f : 1.0f; + ImVec2 baseSSpace = worldToPos(dirAxis * 0.1f * gContext.mScreenFactor, gContext.mMVPLocal); + //ImVec2 worldDirSSpaceNoScale = worldToPos(dirAxis * markerScale * gContext.mScreenFactor, gContext.mMVP); + ImVec2 worldDirSSpace = worldToPos((dirAxis * markerScale) * gContext.mScreenFactor, gContext.mMVPLocal); + + float distance = sqrtf(ImLengthSqr(worldDirSSpace - io.MousePos)); + if (distance < 12.f) + { + type = MT_SCALE_X + i; + } + } + } + return type; + } + + static int GetRotateType(OPERATION op) + { + if (gContext.mbUsing) + { + return MT_NONE; + } + ImGuiIO& io = ImGui::GetIO(); + int type = MT_NONE; + + vec_t deltaScreen = { io.MousePos.x - gContext.mScreenSquareCenter.x, io.MousePos.y - gContext.mScreenSquareCenter.y, 0.f, 0.f }; + float dist = deltaScreen.Length(); + if (Intersects(op, ROTATE_SCREEN) && dist >= (gContext.mRadiusSquareCenter - 4.0f) && dist < (gContext.mRadiusSquareCenter + 4.0f)) + { + type = MT_ROTATE_SCREEN; + } + + const vec_t planNormals[] = { gContext.mModel.v.right, gContext.mModel.v.up, gContext.mModel.v.dir }; + + vec_t modelViewPos; + modelViewPos.TransformPoint(gContext.mModel.v.position, gContext.mViewMat); + + for (unsigned int i = 0; i < 3 && type == MT_NONE; i++) + { + if(!Intersects(op, static_cast<OPERATION>(ROTATE_X << i))) + { + continue; + } + // pickup plan + vec_t pickupPlan = BuildPlan(gContext.mModel.v.position, planNormals[i]); + + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, pickupPlan); + const vec_t intersectWorldPos = gContext.mRayOrigin + gContext.mRayVector * len; + vec_t intersectViewPos; + intersectViewPos.TransformPoint(intersectWorldPos, gContext.mViewMat); + + if (ImAbs(modelViewPos.z) - ImAbs(intersectViewPos.z) < -FLT_EPSILON) + { + continue; + } + + const vec_t localPos = intersectWorldPos - gContext.mModel.v.position; + vec_t idealPosOnCircle = Normalized(localPos); + idealPosOnCircle.TransformVector(gContext.mModelInverse); + const ImVec2 idealPosOnCircleScreen = worldToPos(idealPosOnCircle * rotationDisplayFactor * gContext.mScreenFactor, gContext.mMVP); + + //gContext.mDrawList->AddCircle(idealPosOnCircleScreen, 5.f, IM_COL32_WHITE); + const ImVec2 distanceOnScreen = idealPosOnCircleScreen - io.MousePos; + + const float distance = makeVect(distanceOnScreen).Length(); + if (distance < 8.f) // pixel size + { + type = MT_ROTATE_X + i; + } + } + + return type; + } + + static int GetMoveType(OPERATION op, vec_t* gizmoHitProportion) + { + if(!Intersects(op, TRANSLATE) || gContext.mbUsing) + { + return MT_NONE; + } + ImGuiIO& io = ImGui::GetIO(); + int type = MT_NONE; + + // screen + if (io.MousePos.x >= gContext.mScreenSquareMin.x && io.MousePos.x <= gContext.mScreenSquareMax.x && + io.MousePos.y >= gContext.mScreenSquareMin.y && io.MousePos.y <= gContext.mScreenSquareMax.y && + Contains(op, TRANSLATE)) + { + type = MT_MOVE_SCREEN; + } + + const vec_t screenCoord = makeVect(io.MousePos - ImVec2(gContext.mX, gContext.mY)); + + // compute + for (unsigned int i = 0; i < 3 && type == MT_NONE; i++) + { + vec_t dirPlaneX, dirPlaneY, dirAxis; + bool belowAxisLimit, belowPlaneLimit; + ComputeTripodAxisAndVisibility(i, dirAxis, dirPlaneX, dirPlaneY, belowAxisLimit, belowPlaneLimit); + dirAxis.TransformVector(gContext.mModel); + dirPlaneX.TransformVector(gContext.mModel); + dirPlaneY.TransformVector(gContext.mModel); + + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, BuildPlan(gContext.mModel.v.position, dirAxis)); + vec_t posOnPlan = gContext.mRayOrigin +
gContext.mRayVector * len; + + const ImVec2 axisStartOnScreen = worldToPos(gContext.mModel.v.position + dirAxis * gContext.mScreenFactor * 0.1f, gContext.mViewProjection) - ImVec2(gContext.mX, gContext.mY); + const ImVec2 axisEndOnScreen = worldToPos(gContext.mModel.v.position + dirAxis * gContext.mScreenFactor, gContext.mViewProjection) - ImVec2(gContext.mX, gContext.mY); + + vec_t closestPointOnAxis = PointOnSegment(screenCoord, makeVect(axisStartOnScreen), makeVect(axisEndOnScreen)); + if ((closestPointOnAxis - screenCoord).Length() < 12.f && Intersects(op, static_cast<OPERATION>(TRANSLATE_X << i))) // pixel size + { + type = MT_MOVE_X + i; + } + + const float dx = dirPlaneX.Dot3((posOnPlan - gContext.mModel.v.position) * (1.f / gContext.mScreenFactor)); + const float dy = dirPlaneY.Dot3((posOnPlan - gContext.mModel.v.position) * (1.f / gContext.mScreenFactor)); + if (belowPlaneLimit && dx >= quadUV[0] && dx <= quadUV[4] && dy >= quadUV[1] && dy <= quadUV[3] && Contains(op, TRANSLATE_PLANS[i])) + { + type = MT_MOVE_YZ + i; + } + + if (gizmoHitProportion) + { + *gizmoHitProportion = makeVect(dx, dy, 0.f); + } + } + return type; + } + + static bool HandleTranslation(float* matrix, float* deltaMatrix, OPERATION op, int& type, const float* snap) + { + if(!Intersects(op, TRANSLATE) || type != MT_NONE) + { + return false; + } + const ImGuiIO& io = ImGui::GetIO(); + const bool applyRotationLocaly = gContext.mMode == LOCAL || type == MT_MOVE_SCREEN; + bool modified = false; + + // move + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsTranslateType(gContext.mCurrentOperation)) + { + ImGui::CaptureMouseFromApp(); + const float signedLength = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + const float len = fabsf(signedLength); // near plan + const vec_t newPos = gContext.mRayOrigin + gContext.mRayVector * len; + + // compute delta + const vec_t newOrigin = newPos - gContext.mRelativeOrigin * gContext.mScreenFactor; + vec_t delta = newOrigin - gContext.mModel.v.position; + + // 1 axis constraint + if (gContext.mCurrentOperation >= MT_MOVE_X && gContext.mCurrentOperation <= MT_MOVE_Z) + { + const int axisIndex = gContext.mCurrentOperation - MT_MOVE_X; + const vec_t& axisValue = *(vec_t*)&gContext.mModel.m[axisIndex]; + const float lengthOnAxis = Dot(axisValue, delta); + delta = axisValue * lengthOnAxis; + } + + // snap + if (snap) + { + vec_t cumulativeDelta = gContext.mModel.v.position + delta - gContext.mMatrixOrigin; + if (applyRotationLocaly) + { + matrix_t modelSourceNormalized = gContext.mModelSource; + modelSourceNormalized.OrthoNormalize(); + matrix_t modelSourceNormalizedInverse; + modelSourceNormalizedInverse.Inverse(modelSourceNormalized); + cumulativeDelta.TransformVector(modelSourceNormalizedInverse); + ComputeSnap(cumulativeDelta, snap); + cumulativeDelta.TransformVector(modelSourceNormalized); + } + else + { + ComputeSnap(cumulativeDelta, snap); + } + delta = gContext.mMatrixOrigin + cumulativeDelta - gContext.mModel.v.position; + + } + + if (delta != gContext.mTranslationLastDelta) + { + modified = true; + } + gContext.mTranslationLastDelta = delta; + + // compute matrix & delta + matrix_t deltaMatrixTranslation; + deltaMatrixTranslation.Translation(delta); + if (deltaMatrix) + { + memcpy(deltaMatrix, deltaMatrixTranslation.m16, sizeof(float) * 16); + } + + const matrix_t res = gContext.mModelSource * deltaMatrixTranslation; + *(matrix_t*)matrix = res; + + if (!io.MouseDown[0]) + { +
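// release the translation gizmo once the left mouse button is up +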
gContext.mbUsing = false; + } + + type = gContext.mCurrentOperation; + } + else + { + // find new possible way to move + vec_t gizmoHitProportion; + type = GetMoveType(op, &gizmoHitProportion); + if (type != MT_NONE) + { + ImGui::CaptureMouseFromApp(); + } + if (CanActivate() && type != MT_NONE) + { + gContext.mbUsing = true; + gContext.mEditingID = gContext.mActualID; + gContext.mCurrentOperation = type; + vec_t movePlanNormal[] = { gContext.mModel.v.right, gContext.mModel.v.up, gContext.mModel.v.dir, + gContext.mModel.v.right, gContext.mModel.v.up, gContext.mModel.v.dir, + -gContext.mCameraDir }; + + vec_t cameraToModelNormalized = Normalized(gContext.mModel.v.position - gContext.mCameraEye); + for (unsigned int i = 0; i < 3; i++) + { + vec_t orthoVector = Cross(movePlanNormal[i], cameraToModelNormalized); + movePlanNormal[i].Cross(orthoVector); + movePlanNormal[i].Normalize(); + } + // pickup plan + gContext.mTranslationPlan = BuildPlan(gContext.mModel.v.position, movePlanNormal[type - MT_MOVE_X]); + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + gContext.mTranslationPlanOrigin = gContext.mRayOrigin + gContext.mRayVector * len; + gContext.mMatrixOrigin = gContext.mModel.v.position; + + gContext.mRelativeOrigin = (gContext.mTranslationPlanOrigin - gContext.mModel.v.position) * (1.f / gContext.mScreenFactor); + } + } + return modified; + } + + static bool HandleScale(float* matrix, float* deltaMatrix, OPERATION op, int& type, const float* snap) + { + if((!Intersects(op, SCALE) && !Intersects(op, SCALEU)) || type != MT_NONE) + { + return false; + } + ImGuiIO& io = ImGui::GetIO(); + bool modified = false; + + if (!gContext.mbUsing) + { + // find new possible way to scale + type = GetScaleType(op); + if (type != MT_NONE) + { + ImGui::CaptureMouseFromApp(); + } + if (CanActivate() && type != MT_NONE) + { + gContext.mbUsing = true; + gContext.mEditingID = gContext.mActualID; + gContext.mCurrentOperation = type; + const vec_t movePlanNormal[] = { gContext.mModel.v.up, gContext.mModel.v.dir, gContext.mModel.v.right, gContext.mModel.v.dir, gContext.mModel.v.up, gContext.mModel.v.right, -gContext.mCameraDir }; + // pickup plan + + gContext.mTranslationPlan = BuildPlan(gContext.mModel.v.position, movePlanNormal[type - MT_SCALE_X]); + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + gContext.mTranslationPlanOrigin = gContext.mRayOrigin + gContext.mRayVector * len; + gContext.mMatrixOrigin = gContext.mModel.v.position; + gContext.mScale.Set(1.f, 1.f, 1.f); + gContext.mRelativeOrigin = (gContext.mTranslationPlanOrigin - gContext.mModel.v.position) * (1.f / gContext.mScreenFactor); + gContext.mScaleValueOrigin = makeVect(gContext.mModelSource.v.right.Length(), gContext.mModelSource.v.up.Length(), gContext.mModelSource.v.dir.Length()); + gContext.mSaveMousePosx = io.MousePos.x; + } + } + // scale + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsScaleType(gContext.mCurrentOperation)) + { + ImGui::CaptureMouseFromApp(); + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + vec_t newPos = gContext.mRayOrigin + gContext.mRayVector * len; + vec_t newOrigin = newPos - gContext.mRelativeOrigin * gContext.mScreenFactor; + vec_t delta = newOrigin - gContext.mModelLocal.v.position; + + // 1 axis constraint + if (gContext.mCurrentOperation >= MT_SCALE_X && gContext.mCurrentOperation <= 
MT_SCALE_Z) + { + int axisIndex = gContext.mCurrentOperation - MT_SCALE_X; + const vec_t& axisValue = *(vec_t*)&gContext.mModelLocal.m[axisIndex]; + float lengthOnAxis = Dot(axisValue, delta); + delta = axisValue * lengthOnAxis; + + vec_t baseVector = gContext.mTranslationPlanOrigin - gContext.mModelLocal.v.position; + float ratio = Dot(axisValue, baseVector + delta) / Dot(axisValue, baseVector); + + gContext.mScale[axisIndex] = max(ratio, 0.001f); + } + else + { + float scaleDelta = (io.MousePos.x - gContext.mSaveMousePosx) * 0.01f; + gContext.mScale.Set(max(1.f + scaleDelta, 0.001f)); + } + + // snap + if (snap) + { + float scaleSnap[] = { snap[0], snap[0], snap[0] }; + ComputeSnap(gContext.mScale, scaleSnap); + } + + // no 0 allowed + for (int i = 0; i < 3; i++) + gContext.mScale[i] = max(gContext.mScale[i], 0.001f); + + if (gContext.mScaleLast != gContext.mScale) + { + modified = true; + } + gContext.mScaleLast = gContext.mScale; + + // compute matrix & delta + matrix_t deltaMatrixScale; + deltaMatrixScale.Scale(gContext.mScale * gContext.mScaleValueOrigin); + + matrix_t res = deltaMatrixScale * gContext.mModelLocal; + *(matrix_t*)matrix = res; + + if (deltaMatrix) + { + vec_t deltaScale = gContext.mScale * gContext.mScaleValueOrigin; + + vec_t originalScaleDivider; + originalScaleDivider.x = 1 / gContext.mModelScaleOrigin.x; + originalScaleDivider.y = 1 / gContext.mModelScaleOrigin.y; + originalScaleDivider.z = 1 / gContext.mModelScaleOrigin.z; + + deltaScale = deltaScale * originalScaleDivider; + + deltaMatrixScale.Scale(deltaScale); + memcpy(deltaMatrix, deltaMatrixScale.m16, sizeof(float) * 16); + } + + if (!io.MouseDown[0]) + { + gContext.mbUsing = false; + gContext.mScale.Set(1.f, 1.f, 1.f); + } + + type = gContext.mCurrentOperation; + } + return modified; + } + + static bool HandleRotation(float* matrix, float* deltaMatrix, OPERATION op, int& type, const float* snap) + { + if(!Intersects(op, ROTATE) || type != MT_NONE) + { + return false; + } + ImGuiIO& io = ImGui::GetIO(); + bool applyRotationLocaly = gContext.mMode == LOCAL; + bool modified = false; + + if (!gContext.mbUsing) + { + type = GetRotateType(op); + + if (type != MT_NONE) + { + ImGui::CaptureMouseFromApp(); + } + + if (type == MT_ROTATE_SCREEN) + { + applyRotationLocaly = true; + } + + if (CanActivate() && type != MT_NONE) + { + gContext.mbUsing = true; + gContext.mEditingID = gContext.mActualID; + gContext.mCurrentOperation = type; + const vec_t rotatePlanNormal[] = { gContext.mModel.v.right, gContext.mModel.v.up, gContext.mModel.v.dir, -gContext.mCameraDir }; + // pickup plan + if (applyRotationLocaly) + { + gContext.mTranslationPlan = BuildPlan(gContext.mModel.v.position, rotatePlanNormal[type - MT_ROTATE_X]); + } + else + { + gContext.mTranslationPlan = BuildPlan(gContext.mModelSource.v.position, directionUnary[type - MT_ROTATE_X]); + } + + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, gContext.mTranslationPlan); + vec_t localPos = gContext.mRayOrigin + gContext.mRayVector * len - gContext.mModel.v.position; + gContext.mRotationVectorSource = Normalized(localPos); + gContext.mRotationAngleOrigin = ComputeAngleOnPlan(); + } + } + + // rotation + if (gContext.mbUsing && (gContext.mActualID == -1 || gContext.mActualID == gContext.mEditingID) && IsRotateType(gContext.mCurrentOperation)) + { + ImGui::CaptureMouseFromApp(); + gContext.mRotationAngle = ComputeAngleOnPlan(); + if (snap) + { + float snapInRadian = snap[0] * DEG2RAD; + ComputeSnap(&gContext.mRotationAngle, snapInRadian); + } 
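+ // express the rotation plane normal in model space before building the incremental rotation matrix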
+ vec_t rotationAxisLocalSpace; + + rotationAxisLocalSpace.TransformVector(makeVect(gContext.mTranslationPlan.x, gContext.mTranslationPlan.y, gContext.mTranslationPlan.z, 0.f), gContext.mModelInverse); + rotationAxisLocalSpace.Normalize(); + + matrix_t deltaRotation; + deltaRotation.RotationAxis(rotationAxisLocalSpace, gContext.mRotationAngle - gContext.mRotationAngleOrigin); + if (gContext.mRotationAngle != gContext.mRotationAngleOrigin) + { + modified = true; + } + gContext.mRotationAngleOrigin = gContext.mRotationAngle; + + matrix_t scaleOrigin; + scaleOrigin.Scale(gContext.mModelScaleOrigin); + + if (applyRotationLocaly) + { + *(matrix_t*)matrix = scaleOrigin * deltaRotation * gContext.mModelLocal; + } + else + { + matrix_t res = gContext.mModelSource; + res.v.position.Set(0.f); + + *(matrix_t*)matrix = res * deltaRotation; + ((matrix_t*)matrix)->v.position = gContext.mModelSource.v.position; + } + + if (deltaMatrix) + { + *(matrix_t*)deltaMatrix = gContext.mModelInverse * deltaRotation * gContext.mModel; + } + + if (!io.MouseDown[0]) + { + gContext.mbUsing = false; + gContext.mEditingID = -1; + } + type = gContext.mCurrentOperation; + } + return modified; + } + + void DecomposeMatrixToComponents(const float* matrix, float* translation, float* rotation, float* scale) + { + matrix_t mat = *(matrix_t*)matrix; + + scale[0] = mat.v.right.Length(); + scale[1] = mat.v.up.Length(); + scale[2] = mat.v.dir.Length(); + + mat.OrthoNormalize(); + + rotation[0] = RAD2DEG * atan2f(mat.m[1][2], mat.m[2][2]); + rotation[1] = RAD2DEG * atan2f(-mat.m[0][2], sqrtf(mat.m[1][2] * mat.m[1][2] + mat.m[2][2] * mat.m[2][2])); + rotation[2] = RAD2DEG * atan2f(mat.m[0][1], mat.m[0][0]); + + translation[0] = mat.v.position.x; + translation[1] = mat.v.position.y; + translation[2] = mat.v.position.z; + } + + void RecomposeMatrixFromComponents(const float* translation, const float* rotation, const float* scale, float* matrix) + { + matrix_t& mat = *(matrix_t*)matrix; + + matrix_t rot[3]; + for (int i = 0; i < 3; i++) + { + rot[i].RotationAxis(directionUnary[i], rotation[i] * DEG2RAD); + } + + mat = rot[0] * rot[1] * rot[2]; + + float validScale[3]; + for (int i = 0; i < 3; i++) + { + if (fabsf(scale[i]) < FLT_EPSILON) + { + validScale[i] = 0.001f; + } + else + { + validScale[i] = scale[i]; + } + } + mat.v.right *= validScale[0]; + mat.v.up *= validScale[1]; + mat.v.dir *= validScale[2]; + mat.v.position.Set(translation[0], translation[1], translation[2], 1.f); + } + + void SetID(int id) + { + gContext.mActualID = id; + } + + void AllowAxisFlip(bool value) + { + gContext.mAllowAxisFlip = value; + } + + bool Manipulate(const float* view, const float* projection, OPERATION operation, MODE mode, float* matrix, float* deltaMatrix, const float* snap, const float* localBounds, const float* boundsSnap) + { + // Scale is always local or matrix will be skewed when applying world scale or oriented matrix + ComputeContext(view, projection, matrix, (operation & SCALE) ? 
LOCAL : mode); + + // set delta to identity + if (deltaMatrix) + { + ((matrix_t*)deltaMatrix)->SetToIdentity(); + } + + // behind camera + vec_t camSpacePosition; + camSpacePosition.TransformPoint(makeVect(0.f, 0.f, 0.f), gContext.mMVP); + if (!gContext.mIsOrthographic && camSpacePosition.z < 0.001f) + { + return false; + } + + // -- + int type = MT_NONE; + bool manipulated = false; + if (gContext.mbEnable) + { + if (!gContext.mbUsingBounds) + { + manipulated = HandleTranslation(matrix, deltaMatrix, operation, type, snap) || + HandleScale(matrix, deltaMatrix, operation, type, snap) || + HandleRotation(matrix, deltaMatrix, operation, type, snap); + } + } + + if (localBounds && !gContext.mbUsing) + { + HandleAndDrawLocalBounds(localBounds, (matrix_t*)matrix, boundsSnap, operation); + } + + gContext.mOperation = operation; + if (!gContext.mbUsingBounds) + { + DrawRotationGizmo(operation, type); + DrawTranslationGizmo(operation, type); + DrawScaleGizmo(operation, type); + DrawScaleUniveralGizmo(operation, type); + } + return manipulated; + } + + void SetGizmoSizeClipSpace(float value) + { + gContext.mGizmoSizeClipSpace = value; + } + + /////////////////////////////////////////////////////////////////////////////////////////////////// + void ComputeFrustumPlanes(vec_t* frustum, const float* clip) + { + frustum[0].x = clip[3] - clip[0]; + frustum[0].y = clip[7] - clip[4]; + frustum[0].z = clip[11] - clip[8]; + frustum[0].w = clip[15] - clip[12]; + + frustum[1].x = clip[3] + clip[0]; + frustum[1].y = clip[7] + clip[4]; + frustum[1].z = clip[11] + clip[8]; + frustum[1].w = clip[15] + clip[12]; + + frustum[2].x = clip[3] + clip[1]; + frustum[2].y = clip[7] + clip[5]; + frustum[2].z = clip[11] + clip[9]; + frustum[2].w = clip[15] + clip[13]; + + frustum[3].x = clip[3] - clip[1]; + frustum[3].y = clip[7] - clip[5]; + frustum[3].z = clip[11] - clip[9]; + frustum[3].w = clip[15] - clip[13]; + + frustum[4].x = clip[3] - clip[2]; + frustum[4].y = clip[7] - clip[6]; + frustum[4].z = clip[11] - clip[10]; + frustum[4].w = clip[15] - clip[14]; + + frustum[5].x = clip[3] + clip[2]; + frustum[5].y = clip[7] + clip[6]; + frustum[5].z = clip[11] + clip[10]; + frustum[5].w = clip[15] + clip[14]; + + for (int i = 0; i < 6; i++) + { + frustum[i].Normalize(); + } + } + + void DrawCubes(const float* view, const float* projection, const float* matrices, int matrixCount) + { + matrix_t viewInverse; + viewInverse.Inverse(*(matrix_t*)view); + + struct CubeFace + { + float z; + ImVec2 faceCoordsScreen[4]; + ImU32 color; + }; + CubeFace* faces = (CubeFace*)_malloca(sizeof(CubeFace) * matrixCount * 6); + + if (!faces) + { + return; + } + + vec_t frustum[6]; + matrix_t viewProjection = *(matrix_t*)view * *(matrix_t*)projection; + ComputeFrustumPlanes(frustum, viewProjection.m16); + + int cubeFaceCount = 0; + for (int cube = 0; cube < matrixCount; cube++) + { + const float* matrix = &matrices[cube * 16]; + + matrix_t res = *(matrix_t*)matrix * *(matrix_t*)view * *(matrix_t*)projection; + + for (int iFace = 0; iFace < 6; iFace++) + { + const int normalIndex = (iFace % 3); + const int perpXIndex = (normalIndex + 1) % 3; + const int perpYIndex = (normalIndex + 2) % 3; + const float invert = (iFace > 2) ? 
-1.f : 1.f; + + const vec_t faceCoords[4] = { directionUnary[normalIndex] + directionUnary[perpXIndex] + directionUnary[perpYIndex], + directionUnary[normalIndex] + directionUnary[perpXIndex] - directionUnary[perpYIndex], + directionUnary[normalIndex] - directionUnary[perpXIndex] - directionUnary[perpYIndex], + directionUnary[normalIndex] - directionUnary[perpXIndex] + directionUnary[perpYIndex], + }; + + // clipping + /* + bool skipFace = false; + for (unsigned int iCoord = 0; iCoord < 4; iCoord++) + { + vec_t camSpacePosition; + camSpacePosition.TransformPoint(faceCoords[iCoord] * 0.5f * invert, res); + if (camSpacePosition.z < 0.001f) + { + skipFace = true; + break; + } + } + if (skipFace) + { + continue; + } + */ + vec_t centerPosition, centerPositionVP; + centerPosition.TransformPoint(directionUnary[normalIndex] * 0.5f * invert, *(matrix_t*)matrix); + centerPositionVP.TransformPoint(directionUnary[normalIndex] * 0.5f * invert, res); + + bool inFrustum = true; + for (int iFrustum = 0; iFrustum < 6; iFrustum++) + { + float dist = DistanceToPlane(centerPosition, frustum[iFrustum]); + if (dist < 0.f) + { + inFrustum = false; + break; + } + } + + if (!inFrustum) + { + continue; + } + CubeFace& cubeFace = faces[cubeFaceCount]; + + // 3D->2D + //ImVec2 faceCoordsScreen[4]; + for (unsigned int iCoord = 0; iCoord < 4; iCoord++) + { + cubeFace.faceCoordsScreen[iCoord] = worldToPos(faceCoords[iCoord] * 0.5f * invert, res); + } + cubeFace.color = directionColor[normalIndex] | IM_COL32(0x80, 0x80, 0x80, 0); + + cubeFace.z = centerPositionVP.z / centerPositionVP.w; + cubeFaceCount++; + } + } + qsort(faces, cubeFaceCount, sizeof(CubeFace), [](void const* _a, void const* _b) { + CubeFace* a = (CubeFace*)_a; + CubeFace* b = (CubeFace*)_b; + if (a->z < b->z) + { + return 1; + } + return -1; + }); + // draw face with lighter color + for (int iFace = 0; iFace < cubeFaceCount; iFace++) + { + const CubeFace& cubeFace = faces[iFace]; + gContext.mDrawList->AddConvexPolyFilled(cubeFace.faceCoordsScreen, 4, cubeFace.color); + } + + _freea(faces); + } + + void DrawGrid(const float* view, const float* projection, const float* matrix, const float gridSize) + { + matrix_t viewProjection = *(matrix_t*)view * *(matrix_t*)projection; + vec_t frustum[6]; + ComputeFrustumPlanes(frustum, viewProjection.m16); + matrix_t res = *(matrix_t*)matrix * viewProjection; + + for (float f = -gridSize; f <= gridSize; f += 1.f) + { + for (int dir = 0; dir < 2; dir++) + { + vec_t ptA = makeVect(dir ? -gridSize : f, 0.f, dir ? f : -gridSize); + vec_t ptB = makeVect(dir ? gridSize : f, 0.f, dir ? f : gridSize); + bool visible = true; + for (int i = 0; i < 6; i++) + { + float dA = DistanceToPlane(ptA, frustum[i]); + float dB = DistanceToPlane(ptB, frustum[i]); + if (dA < 0.f && dB < 0.f) + { + visible = false; + break; + } + if (dA > 0.f && dB > 0.f) + { + continue; + } + if (dA < 0.f) + { + float len = fabsf(dA - dB); + float t = fabsf(dA) / len; + ptA.Lerp(ptB, t); + } + if (dB < 0.f) + { + float len = fabsf(dB - dA); + float t = fabsf(dB) / len; + ptB.Lerp(ptA, t); + } + } + if (visible) + { + ImU32 col = IM_COL32(0x80, 0x80, 0x80, 0xFF); + col = (fmodf(fabsf(f), 10.f) < FLT_EPSILON) ? IM_COL32(0x90, 0x90, 0x90, 0xFF) : col; + col = (fabsf(f) < FLT_EPSILON) ? IM_COL32(0x40, 0x40, 0x40, 0xFF): col; + + float thickness = 1.f; + thickness = (fmodf(fabsf(f), 10.f) < FLT_EPSILON) ? 1.5f : thickness; + thickness = (fabsf(f) < FLT_EPSILON) ? 
2.3f : thickness; + + gContext.mDrawList->AddLine(worldToPos(ptA, res), worldToPos(ptB, res), col, thickness); + } + } + } + } + + void ViewManipulate(float* view, float length, ImVec2 position, ImVec2 size, ImU32 backgroundColor) + { + static bool isDraging = false; + static bool isClicking = false; + static bool isInside = false; + static vec_t interpolationUp; + static vec_t interpolationDir; + static int interpolationFrames = 0; + const vec_t referenceUp = makeVect(0.f, 1.f, 0.f); + + matrix_t svgView, svgProjection; + svgView = gContext.mViewMat; + svgProjection = gContext.mProjectionMat; + + ImGuiIO& io = ImGui::GetIO(); + gContext.mDrawList->AddRectFilled(position, position + size, backgroundColor); + matrix_t viewInverse; + viewInverse.Inverse(*(matrix_t*)view); + + const vec_t camTarget = viewInverse.v.position - viewInverse.v.dir * length; + + // view/projection matrices + const float distance = 3.f; + matrix_t cubeProjection, cubeView; + float fov = acosf(distance / (sqrtf(distance * distance + 3.f))) * RAD2DEG; + Perspective(fov / sqrtf(2.f), size.x / size.y, 0.01f, 1000.f, cubeProjection.m16); + + vec_t dir = makeVect(viewInverse.m[2][0], viewInverse.m[2][1], viewInverse.m[2][2]); + vec_t up = makeVect(viewInverse.m[1][0], viewInverse.m[1][1], viewInverse.m[1][2]); + vec_t eye = dir * distance; + vec_t zero = makeVect(0.f, 0.f); + LookAt(&eye.x, &zero.x, &up.x, cubeView.m16); + + // set context + gContext.mViewMat = cubeView; + gContext.mProjectionMat = cubeProjection; + ComputeCameraRay(gContext.mRayOrigin, gContext.mRayVector, position, size); + + const matrix_t res = cubeView * cubeProjection; + + // panels + static const ImVec2 panelPosition[9] = { ImVec2(0.75f,0.75f), ImVec2(0.25f, 0.75f), ImVec2(0.f, 0.75f), + ImVec2(0.75f, 0.25f), ImVec2(0.25f, 0.25f), ImVec2(0.f, 0.25f), + ImVec2(0.75f, 0.f), ImVec2(0.25f, 0.f), ImVec2(0.f, 0.f) }; + + static const ImVec2 panelSize[9] = { ImVec2(0.25f,0.25f), ImVec2(0.5f, 0.25f), ImVec2(0.25f, 0.25f), + ImVec2(0.25f, 0.5f), ImVec2(0.5f, 0.5f), ImVec2(0.25f, 0.5f), + ImVec2(0.25f, 0.25f), ImVec2(0.5f, 0.25f), ImVec2(0.25f, 0.25f) }; + + // tag faces + bool boxes[27]{}; + for (int iPass = 0; iPass < 2; iPass++) + { + for (int iFace = 0; iFace < 6; iFace++) + { + const int normalIndex = (iFace % 3); + const int perpXIndex = (normalIndex + 1) % 3; + const int perpYIndex = (normalIndex + 2) % 3; + const float invert = (iFace > 2) ? 
-1.f : 1.f; + const vec_t indexVectorX = directionUnary[perpXIndex] * invert; + const vec_t indexVectorY = directionUnary[perpYIndex] * invert; + const vec_t boxOrigin = directionUnary[normalIndex] * -invert - indexVectorX - indexVectorY; + + // plan local space + const vec_t n = directionUnary[normalIndex] * invert; + vec_t viewSpaceNormal = n; + vec_t viewSpacePoint = n * 0.5f; + viewSpaceNormal.TransformVector(cubeView); + viewSpaceNormal.Normalize(); + viewSpacePoint.TransformPoint(cubeView); + const vec_t viewSpaceFacePlan = BuildPlan(viewSpacePoint, viewSpaceNormal); + + // back face culling + if (viewSpaceFacePlan.w > 0.f) + { + continue; + } + + const vec_t facePlan = BuildPlan(n * 0.5f, n); + + const float len = IntersectRayPlane(gContext.mRayOrigin, gContext.mRayVector, facePlan); + vec_t posOnPlan = gContext.mRayOrigin + gContext.mRayVector * len - (n * 0.5f); + + float localx = Dot(directionUnary[perpXIndex], posOnPlan) * invert + 0.5f; + float localy = Dot(directionUnary[perpYIndex], posOnPlan) * invert + 0.5f; + + // panels + const vec_t dx = directionUnary[perpXIndex]; + const vec_t dy = directionUnary[perpYIndex]; + const vec_t origin = directionUnary[normalIndex] - dx - dy; + for (int iPanel = 0; iPanel < 9; iPanel++) + { + vec_t boxCoord = boxOrigin + indexVectorX * float(iPanel % 3) + indexVectorY * float(iPanel / 3) + makeVect(1.f, 1.f, 1.f); + const ImVec2 p = panelPosition[iPanel] * 2.f; + const ImVec2 s = panelSize[iPanel] * 2.f; + ImVec2 faceCoordsScreen[4]; + vec_t panelPos[4] = { dx * p.x + dy * p.y, + dx * p.x + dy * (p.y + s.y), + dx * (p.x + s.x) + dy * (p.y + s.y), + dx * (p.x + s.x) + dy * p.y }; + + for (unsigned int iCoord = 0; iCoord < 4; iCoord++) + { + faceCoordsScreen[iCoord] = worldToPos((panelPos[iCoord] + origin) * 0.5f * invert, res, position, size); + } + + const ImVec2 panelCorners[2] = { panelPosition[iPanel], panelPosition[iPanel] + panelSize[iPanel] }; + bool insidePanel = localx > panelCorners[0].x && localx < panelCorners[1].x&& localy > panelCorners[0].y && localy < panelCorners[1].y; + int boxCoordInt = int(boxCoord.x * 9.f + boxCoord.y * 3.f + boxCoord.z); + assert(boxCoordInt < 27); + boxes[boxCoordInt] |= insidePanel && (!isDraging); + + // draw face with lighter color + if (iPass) + { + gContext.mDrawList->AddConvexPolyFilled(faceCoordsScreen, 4, (directionColor[normalIndex] | IM_COL32(0x80, 0x80, 0x80, 0x80)) | (isInside ? 
IM_COL32(0x08, 0x08, 0x08, 0) : 0)); + if (boxes[boxCoordInt]) + { + gContext.mDrawList->AddConvexPolyFilled(faceCoordsScreen, 4, IM_COL32(0xF0, 0xA0, 0x60, 0x80)); + + if (!io.MouseDown[0] && !isDraging && isClicking) + { + // apply new view direction + int cx = boxCoordInt / 9; + int cy = (boxCoordInt - cx * 9) / 3; + int cz = boxCoordInt % 3; + interpolationDir = makeVect(1.f - cx, 1.f - cy, 1.f - cz); + interpolationDir.Normalize(); + + if (fabsf(Dot(interpolationDir, referenceUp)) > 1.0f - 0.01f) + { + vec_t right = viewInverse.v.right; + if (fabsf(right.x) > fabsf(right.z)) + { + right.z = 0.f; + } + else + { + right.x = 0.f; + } + right.Normalize(); + interpolationUp = Cross(interpolationDir, right); + interpolationUp.Normalize(); + } + else + { + interpolationUp = referenceUp; + } + interpolationFrames = 40; + isClicking = false; + } + if (io.MouseDown[0] && !isDraging) + { + isClicking = true; + } + } + } + } + } + } + if (interpolationFrames) + { + interpolationFrames--; + vec_t newDir = viewInverse.v.dir; + newDir.Lerp(interpolationDir, 0.2f); + newDir.Normalize(); + + vec_t newUp = viewInverse.v.up; + newUp.Lerp(interpolationUp, 0.3f); + newUp.Normalize(); + newUp = interpolationUp; + vec_t newEye = camTarget + newDir * length; + LookAt(&newEye.x, &camTarget.x, &newUp.x, view); + } + isInside = ImRect(position, position + size).Contains(io.MousePos); + + // drag view + if (!isDraging && io.MouseDown[0] && isInside && (fabsf(io.MouseDelta.x) > 0.f || fabsf(io.MouseDelta.y) > 0.f)) + { + isDraging = true; + isClicking = false; + } + else if (isDraging && !io.MouseDown[0]) + { + isDraging = false; + } + + if (isDraging) + { + matrix_t rx, ry, roll; + + rx.RotationAxis(referenceUp, -io.MouseDelta.x * 0.01f); + ry.RotationAxis(viewInverse.v.right, -io.MouseDelta.y * 0.01f); + + roll = rx * ry; + + vec_t newDir = viewInverse.v.dir; + newDir.TransformVector(roll); + newDir.Normalize(); + + // clamp + vec_t planDir = Cross(viewInverse.v.right, referenceUp); + planDir.y = 0.f; + planDir.Normalize(); + float dt = Dot(planDir, newDir); + if (dt < 0.0f) + { + newDir += planDir * dt; + newDir.Normalize(); + } + + vec_t newEye = camTarget + newDir * length; + LookAt(&newEye.x, &camTarget.x, &referenceUp.x, view); + } + + // restore view/projection because it was used to compute ray + ComputeContext(svgView.m16, svgProjection.m16, gContext.mModelSource.m16, gContext.mMode); + } +}; diff --git a/gui/dependencies/imguizmo/ImGuizmo.h b/gui/dependencies/imguizmo/ImGuizmo.h new file mode 100644 index 0000000000000000000000000000000000000000..8859c61bba1acfc8cc4e639d30380e29a85f4a30 --- /dev/null +++ b/gui/dependencies/imguizmo/ImGuizmo.h @@ -0,0 +1,223 @@ +// https://github.com/CedricGuillemet/ImGuizmo +// v 1.84 WIP +// +// The MIT License(MIT) +// +// Copyright(c) 2021 Cedric Guillemet +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files(the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions : +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// ------------------------------------------------------------------------------------------- +// History : +// 2019/11/03 View gizmo +// 2016/09/11 Behind camera culling. Scaling Delta matrix not multiplied by source matrix scales. local/world rotation and translation fixed. Display message is incorrect (X: ... Y:...) in local mode. +// 2016/09/09 Hatched negative axis. Snapping. Documentation update. +// 2016/09/04 Axis switch and translation plan autohiding. Scale transform stability improved +// 2016/09/01 Mogwai changed to Manipulate. Draw debug cube. Fixed inverted scale. Mixing scale and translation/rotation gives bad results. +// 2016/08/31 First version +// +// ------------------------------------------------------------------------------------------- +// Future (no order): +// +// - Multi view +// - display rotation/translation/scale infos in local/world space and not only local +// - finish local/world matrix application +// - OPERATION as bitmask +// +// ------------------------------------------------------------------------------------------- +// Example +#if 0 +void EditTransform(const Camera& camera, matrix_t& matrix) +{ + static ImGuizmo::OPERATION mCurrentGizmoOperation(ImGuizmo::ROTATE); + static ImGuizmo::MODE mCurrentGizmoMode(ImGuizmo::WORLD); + if (ImGui::IsKeyPressed(90)) + mCurrentGizmoOperation = ImGuizmo::TRANSLATE; + if (ImGui::IsKeyPressed(69)) + mCurrentGizmoOperation = ImGuizmo::ROTATE; + if (ImGui::IsKeyPressed(82)) // r Key + mCurrentGizmoOperation = ImGuizmo::SCALE; + if (ImGui::RadioButton("Translate", mCurrentGizmoOperation == ImGuizmo::TRANSLATE)) + mCurrentGizmoOperation = ImGuizmo::TRANSLATE; + ImGui::SameLine(); + if (ImGui::RadioButton("Rotate", mCurrentGizmoOperation == ImGuizmo::ROTATE)) + mCurrentGizmoOperation = ImGuizmo::ROTATE; + ImGui::SameLine(); + if (ImGui::RadioButton("Scale", mCurrentGizmoOperation == ImGuizmo::SCALE)) + mCurrentGizmoOperation = ImGuizmo::SCALE; + float matrixTranslation[3], matrixRotation[3], matrixScale[3]; + ImGuizmo::DecomposeMatrixToComponents(matrix.m16, matrixTranslation, matrixRotation, matrixScale); + ImGui::InputFloat3("Tr", matrixTranslation, 3); + ImGui::InputFloat3("Rt", matrixRotation, 3); + ImGui::InputFloat3("Sc", matrixScale, 3); + ImGuizmo::RecomposeMatrixFromComponents(matrixTranslation, matrixRotation, matrixScale, matrix.m16); + + if (mCurrentGizmoOperation != ImGuizmo::SCALE) + { + if (ImGui::RadioButton("Local", mCurrentGizmoMode == ImGuizmo::LOCAL)) + mCurrentGizmoMode = ImGuizmo::LOCAL; + ImGui::SameLine(); + if (ImGui::RadioButton("World", mCurrentGizmoMode == ImGuizmo::WORLD)) + mCurrentGizmoMode = ImGuizmo::WORLD; + } + static bool useSnap(false); + if (ImGui::IsKeyPressed(83)) + useSnap = !useSnap; + ImGui::Checkbox("", &useSnap); + ImGui::SameLine(); + vec_t snap; + switch (mCurrentGizmoOperation) + { + case ImGuizmo::TRANSLATE: + snap = config.mSnapTranslation; + ImGui::InputFloat3("Snap", &snap.x); + break; + case ImGuizmo::ROTATE: + snap = config.mSnapRotation; + ImGui::InputFloat("Angle Snap", 
&snap.x); + break; + case ImGuizmo::SCALE: + snap = config.mSnapScale; + ImGui::InputFloat("Scale Snap", &snap.x); + break; + } + ImGuiIO& io = ImGui::GetIO(); + ImGuizmo::SetRect(0, 0, io.DisplaySize.x, io.DisplaySize.y); + ImGuizmo::Manipulate(camera.mView.m16, camera.mProjection.m16, mCurrentGizmoOperation, mCurrentGizmoMode, matrix.m16, NULL, useSnap ? &snap.x : NULL); +} +#endif +#pragma once + +#ifdef USE_IMGUI_API +#include "imconfig.h" +#endif +#ifndef IMGUI_API +#define IMGUI_API +#endif + +#ifndef IMGUIZMO_NAMESPACE +#define IMGUIZMO_NAMESPACE ImGuizmo +#endif + +namespace IMGUIZMO_NAMESPACE +{ + // call inside your own window and before Manipulate() in order to draw gizmo to that window. + // Or pass a specific ImDrawList to draw to (e.g. ImGui::GetForegroundDrawList()). + IMGUI_API void SetDrawlist(ImDrawList* drawlist = nullptr); + + // call BeginFrame right after ImGui_XXXX_NewFrame(); + IMGUI_API void BeginFrame(); + + // this is necessary because when imguizmo is compiled into a dll, and imgui into another + // globals are not shared between them. + // More details at https://stackoverflow.com/questions/19373061/what-happens-to-global-and-static-variables-in-a-shared-library-when-it-is-dynam + // expose method to set imgui context + IMGUI_API void SetImGuiContext(ImGuiContext* ctx); + + // return true if mouse cursor is over any gizmo control (axis, plan or screen component) + IMGUI_API bool IsOver(); + + // return true if mouse IsOver or if the gizmo is in moving state + IMGUI_API bool IsUsing(); + + // enable/disable the gizmo. Stay in the state until next call to Enable. + // gizmo is rendered with gray half transparent color when disabled + IMGUI_API void Enable(bool enable); + + // helper functions for manualy editing translation/rotation/scale with an input float + // translation, rotation and scale float points to 3 floats each + // Angles are in degrees (more suitable for human editing) + // example: + // float matrixTranslation[3], matrixRotation[3], matrixScale[3]; + // ImGuizmo::DecomposeMatrixToComponents(gizmoMatrix.m16, matrixTranslation, matrixRotation, matrixScale); + // ImGui::InputFloat3("Tr", matrixTranslation, 3); + // ImGui::InputFloat3("Rt", matrixRotation, 3); + // ImGui::InputFloat3("Sc", matrixScale, 3); + // ImGuizmo::RecomposeMatrixFromComponents(matrixTranslation, matrixRotation, matrixScale, gizmoMatrix.m16); + // + // These functions have some numerical stability issues for now. Use with caution. + IMGUI_API void DecomposeMatrixToComponents(const float* matrix, float* translation, float* rotation, float* scale); + IMGUI_API void RecomposeMatrixFromComponents(const float* translation, const float* rotation, const float* scale, float* matrix); + + IMGUI_API void SetRect(float x, float y, float width, float height); + // default is false + IMGUI_API void SetOrthographic(bool isOrthographic); + + // Render a cube with face color corresponding to face normal. Usefull for debug/tests + IMGUI_API void DrawCubes(const float* view, const float* projection, const float* matrices, int matrixCount); + IMGUI_API void DrawGrid(const float* view, const float* projection, const float* matrix, const float gridSize); + + // call it when you want a gizmo + // Needs view and projection matrices. + // matrix parameter is the source matrix (where will be gizmo be drawn) and might be transformed by the function. 
Return deltaMatrix is optional + translation is applied in world space + enum OPERATION + { + TRANSLATE_X = (1u << 0), + TRANSLATE_Y = (1u << 1), + TRANSLATE_Z = (1u << 2), + ROTATE_X = (1u << 3), + ROTATE_Y = (1u << 4), + ROTATE_Z = (1u << 5), + ROTATE_SCREEN = (1u << 6), + SCALE_X = (1u << 7), + SCALE_Y = (1u << 8), + SCALE_Z = (1u << 9), + BOUNDS = (1u << 10), + SCALE_XU = (1u << 11), + SCALE_YU = (1u << 12), + SCALE_ZU = (1u << 13), + + TRANSLATE = TRANSLATE_X | TRANSLATE_Y | TRANSLATE_Z, + ROTATE = ROTATE_X | ROTATE_Y | ROTATE_Z | ROTATE_SCREEN, + SCALE = SCALE_X | SCALE_Y | SCALE_Z, + SCALEU = SCALE_XU | SCALE_YU | SCALE_ZU, // universal + UNIVERSAL = TRANSLATE | ROTATE | SCALEU + }; + + inline OPERATION operator|(OPERATION lhs, OPERATION rhs) + { + return static_cast<OPERATION>(static_cast<int>(lhs) | static_cast<int>(rhs)); + } + + enum MODE + { + LOCAL, + WORLD + }; + + IMGUI_API bool Manipulate(const float* view, const float* projection, OPERATION operation, MODE mode, float* matrix, float* deltaMatrix = NULL, const float* snap = NULL, const float* localBounds = NULL, const float* boundsSnap = NULL); + // + // Please note that this cubeview is patented by Autodesk : https://patents.google.com/patent/US7782319B2/en + // It seems to be a defensive patent in the US. I don't think it will bring troubles using it as + // other software are using the same mechanics. But just in case, you are now warned! + // + IMGUI_API void ViewManipulate(float* view, float length, ImVec2 position, ImVec2 size, ImU32 backgroundColor); + + IMGUI_API void SetID(int id); + + // return true if the cursor is over the operation's gizmo + IMGUI_API bool IsOver(OPERATION op); + IMGUI_API void SetGizmoSizeClipSpace(float value); + + // Allow axes to flip + // When true (default), the gizmo axes flip for better visibility + // When false, they always stay along the positive world/local axis + IMGUI_API void AllowAxisFlip(bool value); +} diff --git a/gui/dependencies/json/LICENSE.MIT b/gui/dependencies/json/LICENSE.MIT new file mode 100644 index 0000000000000000000000000000000000000000..f0622d6dc24895e62ca570419872fdb19e064f79 --- /dev/null +++ b/gui/dependencies/json/LICENSE.MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2013-2021 Niels Lohmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.
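The ImGuizmo header above documents the per-frame call order (BeginFrame right after the ImGui new-frame call, SetRect to define the viewport, then Manipulate and optionally ViewManipulate) and declares OPERATION as a bitmask. A minimal usage sketch of that wiring follows; it assumes the application keeps its camera and object transforms as column-major float[16] arrays, and the names DrawGizmos, cameraView, cameraProjection, and objectMatrix are illustrative only, not part of this repository.

// Minimal per-frame sketch using the API declared in ImGuizmo.h (assumed names, not repository code).
#include "imgui.h"
#include "ImGuizmo.h"

void DrawGizmos(float* cameraView, const float* cameraProjection, float* objectMatrix)
{
    ImGuizmo::BeginFrame();                                   // once per frame, right after ImGui::NewFrame()
    ImGuiIO& io = ImGui::GetIO();
    ImGuizmo::SetRect(0.f, 0.f, io.DisplaySize.x, io.DisplaySize.y);

    // OPERATION values are bit flags, so several gizmo parts can be active in a single call.
    ImGuizmo::OPERATION op = ImGuizmo::TRANSLATE | ImGuizmo::ROTATE;
    ImGuizmo::Manipulate(cameraView, cameraProjection, op, ImGuizmo::LOCAL, objectMatrix);

    // View cube in the top-right corner; edits cameraView in place (camera distance of 8 is arbitrary).
    ImGuizmo::ViewManipulate(cameraView, 8.f, ImVec2(io.DisplaySize.x - 128.f, 0.f), ImVec2(128.f, 128.f), IM_COL32(0, 0, 0, 0));
}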
diff --git a/gui/dependencies/json/json.hpp b/gui/dependencies/json/json.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58c7a1d478c350a3cc1d6a5b68fe85b46d4f8b95 --- /dev/null +++ b/gui/dependencies/json/json.hpp @@ -0,0 +1,26713 @@ +/* + __ _____ _____ _____ + __| | __| | | | JSON for Modern C++ +| | |__ | | | | | | version 3.10.4 +|_____|_____|_____|_|___| https://github.com/nlohmann/json + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2013-2019 Niels Lohmann . + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#ifndef INCLUDE_NLOHMANN_JSON_HPP_ +#define INCLUDE_NLOHMANN_JSON_HPP_ + +#define NLOHMANN_JSON_VERSION_MAJOR 3 +#define NLOHMANN_JSON_VERSION_MINOR 10 +#define NLOHMANN_JSON_VERSION_PATCH 4 + +#include // all_of, find, for_each +#include // nullptr_t, ptrdiff_t, size_t +#include // hash, less +#include // initializer_list +#ifndef JSON_NO_IO + #include // istream, ostream +#endif // JSON_NO_IO +#include // random_access_iterator_tag +#include // unique_ptr +#include // accumulate +#include // string, stoi, to_string +#include // declval, forward, move, pair, swap +#include // vector + +// #include + + +#include +#include + +// #include + + +#include // transform +#include // array +#include // forward_list +#include // inserter, front_inserter, end +#include // map +#include // string +#include // tuple, make_tuple +#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible +#include // unordered_map +#include // pair, declval +#include // valarray + +// #include + + +#include // exception +#include // runtime_error +#include // to_string +#include // vector + +// #include + + +#include // array +#include // size_t +#include // uint8_t +#include // string + +namespace nlohmann +{ +namespace detail +{ +/////////////////////////// +// JSON type enumeration // +/////////////////////////// + +/*! +@brief the JSON type enumeration + +This enumeration collects the different JSON types. It is internally used to +distinguish the stored values, and the functions @ref basic_json::is_null(), +@ref basic_json::is_object(), @ref basic_json::is_array(), +@ref basic_json::is_string(), @ref basic_json::is_boolean(), +@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), +@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), +@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and +@ref basic_json::is_structured() rely on it. 
+ +@note There are three enumeration entries (number_integer, number_unsigned, and +number_float), because the library distinguishes these three types for numbers: +@ref basic_json::number_unsigned_t is used for unsigned integers, +@ref basic_json::number_integer_t is used for signed integers, and +@ref basic_json::number_float_t is used for floating-point numbers or to +approximate integers which do not fit in the limits of their respective type. + +@sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON +value with the default value for a given type + +@since version 1.0.0 +*/ +enum class value_t : std::uint8_t +{ + null, ///< null value + object, ///< object (unordered set of name/value pairs) + array, ///< array (ordered collection of values) + string, ///< string value + boolean, ///< boolean value + number_integer, ///< number value (signed integer) + number_unsigned, ///< number value (unsigned integer) + number_float, ///< number value (floating-point) + binary, ///< binary array (ordered collection of bytes) + discarded ///< discarded by the parser callback function +}; + +/*! +@brief comparison operator for JSON types + +Returns an ordering that is similar to Python: +- order: null < boolean < number < object < array < string < binary +- furthermore, each type is not smaller than itself +- discarded values are not comparable +- binary is represented as a b"" string in python and directly comparable to a + string; however, making a binary array directly comparable with a string would + be surprising behavior in a JSON file. + +@since version 1.0.0 +*/ +inline bool operator<(const value_t lhs, const value_t rhs) noexcept +{ + static constexpr std::array order = {{ + 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, + 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, + 6 /* binary */ + } + }; + + const auto l_index = static_cast(lhs); + const auto r_index = static_cast(rhs); + return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; +} +} // namespace detail +} // namespace nlohmann + +// #include + + +#include +// #include + + +#include // declval, pair +// #include + + +/* Hedley - https://nemequ.github.io/hedley + * Created by Evan Nemerson + * + * To the extent possible under law, the author(s) have dedicated all + * copyright and related and neighboring rights to this software to + * the public domain worldwide. This software is distributed without + * any warranty. + * + * For details, see . 
+ * SPDX-License-Identifier: CC0-1.0 + */ + +#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15) +#if defined(JSON_HEDLEY_VERSION) + #undef JSON_HEDLEY_VERSION +#endif +#define JSON_HEDLEY_VERSION 15 + +#if defined(JSON_HEDLEY_STRINGIFY_EX) + #undef JSON_HEDLEY_STRINGIFY_EX +#endif +#define JSON_HEDLEY_STRINGIFY_EX(x) #x + +#if defined(JSON_HEDLEY_STRINGIFY) + #undef JSON_HEDLEY_STRINGIFY +#endif +#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) + +#if defined(JSON_HEDLEY_CONCAT_EX) + #undef JSON_HEDLEY_CONCAT_EX +#endif +#define JSON_HEDLEY_CONCAT_EX(a,b) a##b + +#if defined(JSON_HEDLEY_CONCAT) + #undef JSON_HEDLEY_CONCAT +#endif +#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) + +#if defined(JSON_HEDLEY_CONCAT3_EX) + #undef JSON_HEDLEY_CONCAT3_EX +#endif +#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c + +#if defined(JSON_HEDLEY_CONCAT3) + #undef JSON_HEDLEY_CONCAT3 +#endif +#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) + +#if defined(JSON_HEDLEY_VERSION_ENCODE) + #undef JSON_HEDLEY_VERSION_ENCODE +#endif +#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) + #undef JSON_HEDLEY_VERSION_DECODE_MAJOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) + #undef JSON_HEDLEY_VERSION_DECODE_MINOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) + #undef JSON_HEDLEY_VERSION_DECODE_REVISION +#endif +#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) + +#if defined(JSON_HEDLEY_GNUC_VERSION) + #undef JSON_HEDLEY_GNUC_VERSION +#endif +#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#elif defined(__GNUC__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) +#endif + +#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) + #undef JSON_HEDLEY_GNUC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GNUC_VERSION) + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION) + #undef JSON_HEDLEY_MSVC_VERSION +#endif +#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) +#elif defined(_MSC_FULL_VER) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) +#elif defined(_MSC_VER) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) + #undef JSON_HEDLEY_MSVC_VERSION_CHECK +#endif +#if !defined(JSON_HEDLEY_MSVC_VERSION) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) +#elif defined(_MSC_VER) && (_MSC_VER >= 1400) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) +#elif defined(_MSC_VER) && (_MSC_VER >= 1200) + #define 
JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) +#else + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION) + #undef JSON_HEDLEY_INTEL_VERSION +#endif +#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) +#elif defined(__INTEL_COMPILER) && !defined(__ICL) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) + #undef JSON_HEDLEY_INTEL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_INTEL_VERSION) + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_INTEL_CL_VERSION) + #undef JSON_HEDLEY_INTEL_CL_VERSION +#endif +#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL) + #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0) +#endif + +#if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK) + #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_INTEL_CL_VERSION) + #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION) + #undef JSON_HEDLEY_PGI_VERSION +#endif +#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) + #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION_CHECK) + #undef JSON_HEDLEY_PGI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PGI_VERSION) + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #undef JSON_HEDLEY_SUNPRO_VERSION +#endif +#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) +#elif defined(__SUNPRO_C) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) +#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) +#elif defined(__SUNPRO_CC) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) + #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) 
(JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION +#endif +#if defined(__EMSCRIPTEN__) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION) + #undef JSON_HEDLEY_ARM_VERSION +#endif +#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) +#elif defined(__CC_ARM) && defined(__ARMCC_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION_CHECK) + #undef JSON_HEDLEY_ARM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_ARM_VERSION) + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION) + #undef JSON_HEDLEY_IBM_VERSION +#endif +#if defined(__ibmxl__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) +#elif defined(__xlC__) && defined(__xlC_ver__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) +#elif defined(__xlC__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION_CHECK) + #undef JSON_HEDLEY_IBM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IBM_VERSION) + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_VERSION) + #undef JSON_HEDLEY_TI_VERSION +#endif +#if \ + defined(__TI_COMPILER_VERSION__) && \ + ( \ + defined(__TMS470__) || defined(__TI_ARM__) || \ + defined(__MSP430__) || \ + defined(__TMS320C2000__) \ + ) +#if (__TI_COMPILER_VERSION__ >= 16000000) + #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif +#endif + +#if defined(JSON_HEDLEY_TI_VERSION_CHECK) + #undef JSON_HEDLEY_TI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_VERSION) + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #undef JSON_HEDLEY_TI_CL2000_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) + #define 
JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #undef JSON_HEDLEY_TI_CL430_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) + #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #undef JSON_HEDLEY_TI_ARMCL_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) + #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) + #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #undef JSON_HEDLEY_TI_CL6X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) + #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #undef JSON_HEDLEY_TI_CL7X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) + #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #undef JSON_HEDLEY_TI_CLPRU_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) 
&& defined(__PRU__) + #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION) + #undef JSON_HEDLEY_CRAY_VERSION +#endif +#if defined(_CRAYC) + #if defined(_RELEASE_PATCHLEVEL) + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) + #else + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) + #undef JSON_HEDLEY_CRAY_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_CRAY_VERSION) + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION) + #undef JSON_HEDLEY_IAR_VERSION +#endif +#if defined(__IAR_SYSTEMS_ICC__) + #if __VER__ > 1000 + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) + #else + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION_CHECK) + #undef JSON_HEDLEY_IAR_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IAR_VERSION) + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION) + #undef JSON_HEDLEY_TINYC_VERSION +#endif +#if defined(__TINYC__) + #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK) + #undef JSON_HEDLEY_TINYC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION) + #undef JSON_HEDLEY_DMC_VERSION +#endif +#if defined(__DMC__) + #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION_CHECK) + #undef JSON_HEDLEY_DMC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_DMC_VERSION) + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #undef JSON_HEDLEY_COMPCERT_VERSION +#endif +#if defined(__COMPCERT_VERSION__) + #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) + #undef 
JSON_HEDLEY_COMPCERT_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION) + #undef JSON_HEDLEY_PELLES_VERSION +#endif +#if defined(__POCC__) + #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) + #undef JSON_HEDLEY_PELLES_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PELLES_VERSION) + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_MCST_LCC_VERSION) + #undef JSON_HEDLEY_MCST_LCC_VERSION +#endif +#if defined(__LCC__) && defined(__LCC_MINOR__) + #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__) +#endif + +#if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK) + #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_MCST_LCC_VERSION) + #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION) + #undef JSON_HEDLEY_GCC_VERSION +#endif +#if \ + defined(JSON_HEDLEY_GNUC_VERSION) && \ + !defined(__clang__) && \ + !defined(JSON_HEDLEY_INTEL_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_ARM_VERSION) && \ + !defined(JSON_HEDLEY_CRAY_VERSION) && \ + !defined(JSON_HEDLEY_TI_VERSION) && \ + !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ + !defined(__COMPCERT__) && \ + !defined(JSON_HEDLEY_MCST_LCC_VERSION) + #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION_CHECK) + #undef JSON_HEDLEY_GCC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GCC_VERSION) + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_ATTRIBUTE +#endif +#if \ + defined(__has_attribute) && \ + ( \ + (!defined(JSON_HEDLEY_IAR_VERSION) || JSON_HEDLEY_IAR_VERSION_CHECK(8,5,9)) \ + ) +# define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute) +#else +# define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define 
JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE +#endif +#if \ + defined(__has_cpp_attribute) && \ + defined(__cplusplus) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS +#endif +#if !defined(__cplusplus) || !defined(__has_cpp_attribute) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#elif \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \ + (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_BUILTIN) + #undef JSON_HEDLEY_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else + #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN) + #undef JSON_HEDLEY_GNUC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN) + #undef JSON_HEDLEY_GCC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_FEATURE) + #undef JSON_HEDLEY_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature) +#else + #define JSON_HEDLEY_HAS_FEATURE(feature) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE) + #undef JSON_HEDLEY_GNUC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) 
JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_FEATURE) + #undef JSON_HEDLEY_GCC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_EXTENSION) + #undef JSON_HEDLEY_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension) +#else + #define JSON_HEDLEY_HAS_EXTENSION(extension) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION) + #undef JSON_HEDLEY_GNUC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION) + #undef JSON_HEDLEY_GCC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_WARNING) + #undef JSON_HEDLEY_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning) +#else + #define JSON_HEDLEY_HAS_WARNING(warning) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_WARNING) + #undef JSON_HEDLEY_GNUC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_WARNING) + #undef JSON_HEDLEY_GCC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + defined(__clang__) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + 
JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ + (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) + #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_PRAGMA(value) __pragma(value) +#else + #define JSON_HEDLEY_PRAGMA(value) +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH) + #undef JSON_HEDLEY_DIAGNOSTIC_PUSH +#endif +#if defined(JSON_HEDLEY_DIAGNOSTIC_POP) + #undef JSON_HEDLEY_DIAGNOSTIC_POP +#endif +#if defined(__clang__) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) + #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) +#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP +#endif + +/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. 
*/ +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat") +# if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") +# if JSON_HEDLEY_HAS_WARNING("-Wc++1z-extensions") +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ + _Pragma("clang diagnostic ignored \"-Wc++1z-extensions\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# endif +# else +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# endif +# endif +#endif +#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x +#endif + +#if defined(JSON_HEDLEY_CONST_CAST) + #undef JSON_HEDLEY_CONST_CAST +#endif +#if defined(__cplusplus) +# define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast(expr)) +#elif \ + JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_REINTERPRET_CAST) + #undef JSON_HEDLEY_REINTERPRET_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast(expr)) +#else + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_STATIC_CAST) + #undef JSON_HEDLEY_STATIC_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast(expr)) +#else + #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_CPP_CAST) + #undef JSON_HEDLEY_CPP_CAST +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ + ((T) (expr)) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("diag_suppress=Pe137") \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) +# endif +#else +# define JSON_HEDLEY_CPP_CAST(T, expr) (expr) +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED 
__pragma(warning(disable:1478 1786)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif 
JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292)) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunused-function") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"") +#elif 
JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505)) +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION +#endif + +#if defined(JSON_HEDLEY_DEPRECATED) + #undef JSON_HEDLEY_DEPRECATED +#endif +#if defined(JSON_HEDLEY_DEPRECATED_FOR) + #undef JSON_HEDLEY_DEPRECATED_FOR +#endif +#if \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement)) +#elif \ + (JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since))) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement))) +#elif defined(__cplusplus) && (__cplusplus >= 201402L) + #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]]) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated") + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) 
_Pragma("deprecated") +#else + #define JSON_HEDLEY_DEPRECATED(since) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) +#endif + +#if defined(JSON_HEDLEY_UNAVAILABLE) + #undef JSON_HEDLEY_UNAVAILABLE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since))) +#else + #define JSON_HEDLEY_UNAVAILABLE(available_since) +#endif + +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT +#endif +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__)) +#elif (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) +#elif defined(_Check_return_) /* SAL */ + #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_ + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_ +#else + #define JSON_HEDLEY_WARN_UNUSED_RESULT + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) +#endif + +#if defined(JSON_HEDLEY_SENTINEL) + #undef JSON_HEDLEY_SENTINEL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position))) +#else + #define JSON_HEDLEY_SENTINEL(position) +#endif + +#if defined(JSON_HEDLEY_NO_RETURN) + #undef JSON_HEDLEY_NO_RETURN +#endif +#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NO_RETURN __noreturn +#elif \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_NO_RETURN 
__attribute__((__noreturn__)) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L + #define JSON_HEDLEY_NO_RETURN _Noreturn +#elif defined(__cplusplus) && (__cplusplus >= 201103L) + #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) + #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NO_RETURN __attribute((noreturn)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#else + #define JSON_HEDLEY_NO_RETURN +#endif + +#if defined(JSON_HEDLEY_NO_ESCAPE) + #undef JSON_HEDLEY_NO_ESCAPE +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape) + #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__)) +#else + #define JSON_HEDLEY_NO_ESCAPE +#endif + +#if defined(JSON_HEDLEY_UNREACHABLE) + #undef JSON_HEDLEY_UNREACHABLE +#endif +#if defined(JSON_HEDLEY_UNREACHABLE_RETURN) + #undef JSON_HEDLEY_UNREACHABLE_RETURN +#endif +#if defined(JSON_HEDLEY_ASSUME) + #undef JSON_HEDLEY_ASSUME +#endif +#if \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_ASSUME(expr) __assume(expr) +#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume) + #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr) +#elif \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #if defined(__cplusplus) + #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr) + #else + #define JSON_HEDLEY_ASSUME(expr) _nassert(expr) + #endif +#endif +#if \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(10,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable() +#elif defined(JSON_HEDLEY_ASSUME) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif +#if 
!defined(JSON_HEDLEY_ASSUME) + #if defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1))) + #else + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr) + #endif +#endif +#if defined(JSON_HEDLEY_UNREACHABLE) + #if \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value)) + #else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE() + #endif +#else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value) +#endif +#if !defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif + +JSON_HEDLEY_DIAGNOSTIC_PUSH +#if JSON_HEDLEY_HAS_WARNING("-Wpedantic") + #pragma clang diagnostic ignored "-Wpedantic" +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) + #pragma clang diagnostic ignored "-Wc++98-compat-pedantic" +#endif +#if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) + #if defined(__clang__) + #pragma clang diagnostic ignored "-Wvariadic-macros" + #elif defined(JSON_HEDLEY_GCC_VERSION) + #pragma GCC diagnostic ignored "-Wvariadic-macros" + #endif +#endif +#if defined(JSON_HEDLEY_NON_NULL) + #undef JSON_HEDLEY_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__))) +#else + #define JSON_HEDLEY_NON_NULL(...) +#endif +JSON_HEDLEY_DIAGNOSTIC_POP + +#if defined(JSON_HEDLEY_PRINTF_FORMAT) + #undef JSON_HEDLEY_PRINTF_FORMAT +#endif +#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check))) +#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check))) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(format) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check))) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) 
__declspec(vaformat(printf,string_idx,first_to_check)) +#else + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) +#endif + +#if defined(JSON_HEDLEY_CONSTEXPR) + #undef JSON_HEDLEY_CONSTEXPR +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr) + #endif +#endif +#if !defined(JSON_HEDLEY_CONSTEXPR) + #define JSON_HEDLEY_CONSTEXPR +#endif + +#if defined(JSON_HEDLEY_PREDICT) + #undef JSON_HEDLEY_PREDICT +#endif +#if defined(JSON_HEDLEY_LIKELY) + #undef JSON_HEDLEY_LIKELY +#endif +#if defined(JSON_HEDLEY_UNLIKELY) + #undef JSON_HEDLEY_UNLIKELY +#endif +#if defined(JSON_HEDLEY_UNPREDICTABLE) + #undef JSON_HEDLEY_UNPREDICTABLE +#endif +#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable) + #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr)) +#endif +#if \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) && !defined(JSON_HEDLEY_PGI_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability( (expr), (value), (probability)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) __builtin_expect_with_probability(!!(expr), 1 , (probability)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) __builtin_expect_with_probability(!!(expr), 0 , (probability)) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect (!!(expr), 1 ) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect (!!(expr), 0 ) +#elif \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PREDICT(expr, expected, probability) \ + (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \ + })) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? 
__builtin_expect(!!(expr), 1) : !!(expr))); \ + })) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect(!!(expr), 1) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#else +# define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_LIKELY(expr) (!!(expr)) +# define JSON_HEDLEY_UNLIKELY(expr) (!!(expr)) +#endif +#if !defined(JSON_HEDLEY_UNPREDICTABLE) + #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5) +#endif + +#if defined(JSON_HEDLEY_MALLOC) + #undef JSON_HEDLEY_MALLOC +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_MALLOC __attribute__((__malloc__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_MALLOC __declspec(restrict) +#else + #define JSON_HEDLEY_MALLOC +#endif + +#if defined(JSON_HEDLEY_PURE) + #undef JSON_HEDLEY_PURE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PURE __attribute__((__pure__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) +# define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data") +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) 
|| \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \ + ) +# define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;") +#else +# define JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_CONST) + #undef JSON_HEDLEY_CONST +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(const) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_CONST __attribute__((__const__)) +#elif \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_CONST _Pragma("no_side_effect") +#else + #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_RESTRICT) + #undef JSON_HEDLEY_RESTRICT +#endif +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT restrict +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + defined(__clang__) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_RESTRICT __restrict +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT _Restrict +#else + #define JSON_HEDLEY_RESTRICT +#endif + +#if defined(JSON_HEDLEY_INLINE) + #undef JSON_HEDLEY_INLINE +#endif +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + (defined(__cplusplus) && (__cplusplus >= 199711L)) + #define JSON_HEDLEY_INLINE inline +#elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0) + #define JSON_HEDLEY_INLINE __inline__ +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_INLINE __inline +#else + #define 
JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_ALWAYS_INLINE) + #undef JSON_HEDLEY_ALWAYS_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) +# define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_ALWAYS_INLINE __forceinline +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \ + ) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced") +#else +# define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_NEVER_INLINE) + #undef JSON_HEDLEY_NEVER_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline") +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && 
defined(__cplusplus) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#else + #define JSON_HEDLEY_NEVER_INLINE +#endif + +#if defined(JSON_HEDLEY_PRIVATE) + #undef JSON_HEDLEY_PRIVATE +#endif +#if defined(JSON_HEDLEY_PUBLIC) + #undef JSON_HEDLEY_PUBLIC +#endif +#if defined(JSON_HEDLEY_IMPORT) + #undef JSON_HEDLEY_IMPORT +#endif +#if defined(_WIN32) || defined(__CYGWIN__) +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC __declspec(dllexport) +# define JSON_HEDLEY_IMPORT __declspec(dllimport) +#else +# if \ + JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + ( \ + defined(__TI_EABI__) && \ + ( \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \ + ) \ + ) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden"))) +# define JSON_HEDLEY_PUBLIC __attribute__((__visibility__("default"))) +# else +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC +# endif +# define JSON_HEDLEY_IMPORT extern +#endif + +#if defined(JSON_HEDLEY_NO_THROW) + #undef JSON_HEDLEY_NO_THROW +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NO_THROW __declspec(nothrow) +#else + #define JSON_HEDLEY_NO_THROW +#endif + +#if defined(JSON_HEDLEY_FALL_THROUGH) + #undef JSON_HEDLEY_FALL_THROUGH +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__)) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]]) +#elif defined(__fallthrough) /* SAL */ + #define JSON_HEDLEY_FALL_THROUGH __fallthrough +#else + #define JSON_HEDLEY_FALL_THROUGH +#endif + +#if defined(JSON_HEDLEY_RETURNS_NON_NULL) + #undef JSON_HEDLEY_RETURNS_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__)) +#elif defined(_Ret_notnull_) /* SAL */ + #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_ +#else + #define JSON_HEDLEY_RETURNS_NON_NULL +#endif + +#if defined(JSON_HEDLEY_ARRAY_PARAM) + #undef JSON_HEDLEY_ARRAY_PARAM +#endif +#if \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) 
&& \ + !defined(__STDC_NO_VLA__) && \ + !defined(__cplusplus) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_ARRAY_PARAM(name) (name) +#else + #define JSON_HEDLEY_ARRAY_PARAM(name) +#endif + +#if defined(JSON_HEDLEY_IS_CONSTANT) + #undef JSON_HEDLEY_IS_CONSTANT +#endif +#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR) + #undef JSON_HEDLEY_REQUIRE_CONSTEXPR +#endif +/* JSON_HEDLEY_IS_CONSTEXPR_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #undef JSON_HEDLEY_IS_CONSTEXPR_ +#endif +#if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr) +#endif +#if !defined(__cplusplus) +# if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*) +#endif +# elif \ + ( \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \ + !defined(JSON_HEDLEY_SUNPRO_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION)) || \ + (JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0) +#endif +# elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + defined(JSON_HEDLEY_INTEL_VERSION) || \ + defined(JSON_HEDLEY_TINYC_VERSION) || \ + defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \ + defined(JSON_HEDLEY_TI_CL2000_VERSION) || \ + defined(JSON_HEDLEY_TI_CL6X_VERSION) || \ + defined(JSON_HEDLEY_TI_CL7X_VERSION) || \ + defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \ + defined(__clang__) +# define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \ + sizeof(void) != \ + sizeof(*( \ + 1 ? \ + ((void*) ((expr) * 0L) ) : \ +((struct { char v[sizeof(void) * 2]; } *) 1) \ + ) \ + ) \ + ) +# endif +#endif +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? 
(expr) : (-1)) +#else + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) (0) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr) +#endif + +#if defined(JSON_HEDLEY_BEGIN_C_DECLS) + #undef JSON_HEDLEY_BEGIN_C_DECLS +#endif +#if defined(JSON_HEDLEY_END_C_DECLS) + #undef JSON_HEDLEY_END_C_DECLS +#endif +#if defined(JSON_HEDLEY_C_DECL) + #undef JSON_HEDLEY_C_DECL +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" { + #define JSON_HEDLEY_END_C_DECLS } + #define JSON_HEDLEY_C_DECL extern "C" +#else + #define JSON_HEDLEY_BEGIN_C_DECLS + #define JSON_HEDLEY_END_C_DECLS + #define JSON_HEDLEY_C_DECL +#endif + +#if defined(JSON_HEDLEY_STATIC_ASSERT) + #undef JSON_HEDLEY_STATIC_ASSERT +#endif +#if \ + !defined(__cplusplus) && ( \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \ + (JSON_HEDLEY_HAS_FEATURE(c_static_assert) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + defined(_Static_assert) \ + ) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message) +#elif \ + (defined(__cplusplus) && (__cplusplus >= 201103L)) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message)) +#else +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) +#endif + +#if defined(JSON_HEDLEY_NULL) + #undef JSON_HEDLEY_NULL +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr) + #elif defined(NULL) + #define JSON_HEDLEY_NULL NULL + #else + #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0) + #endif +#elif defined(NULL) + #define JSON_HEDLEY_NULL NULL +#else + #define JSON_HEDLEY_NULL ((void*) 0) +#endif + +#if defined(JSON_HEDLEY_MESSAGE) + #undef JSON_HEDLEY_MESSAGE +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_MESSAGE(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(message msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg) +#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_WARNING) + #undef JSON_HEDLEY_WARNING +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_WARNING(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(clang warning msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define 
JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_REQUIRE) + #undef JSON_HEDLEY_REQUIRE +#endif +#if defined(JSON_HEDLEY_REQUIRE_MSG) + #undef JSON_HEDLEY_REQUIRE_MSG +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) +# if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") +# define JSON_HEDLEY_REQUIRE(expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), #expr, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), msg, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) +# endif +#else +# define JSON_HEDLEY_REQUIRE(expr) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) +#endif + +#if defined(JSON_HEDLEY_FLAGS) + #undef JSON_HEDLEY_FLAGS +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) && (!defined(__cplusplus) || JSON_HEDLEY_HAS_WARNING("-Wbitfield-enum-conversion")) + #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__)) +#else + #define JSON_HEDLEY_FLAGS +#endif + +#if defined(JSON_HEDLEY_FLAGS_CAST) + #undef JSON_HEDLEY_FLAGS_CAST +#endif +#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) +# define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("warning(disable:188)") \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) +#endif + +#if defined(JSON_HEDLEY_EMPTY_BASES) + #undef JSON_HEDLEY_EMPTY_BASES +#endif +#if \ + (JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0)) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases) +#else + #define JSON_HEDLEY_EMPTY_BASES +#endif + +/* Remaining macros are deprecated. 
*/
+
+#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK)
+    #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK
+#endif
+#if defined(__clang__)
+    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0)
+#else
+    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN)
+    #undef JSON_HEDLEY_CLANG_HAS_BUILTIN
+#endif
+#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE)
+    #undef JSON_HEDLEY_CLANG_HAS_FEATURE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION)
+    #undef JSON_HEDLEY_CLANG_HAS_EXTENSION
+#endif
+#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_WARNING)
+    #undef JSON_HEDLEY_CLANG_HAS_WARNING
+#endif
+#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning)
+
+#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */
+
+// #include <nlohmann/detail/meta/detected.hpp>
+
+
+#include <type_traits>
+
+// #include <nlohmann/detail/meta/void_t.hpp>
+
+
+namespace nlohmann
+{
+namespace detail
+{
+template<typename ...Ts> struct make_void
+{
+    using type = void;
+};
+template<typename ...Ts> using void_t = typename make_void<Ts...>::type;
+} // namespace detail
+} // namespace nlohmann
+
+
+// https://en.cppreference.com/w/cpp/experimental/is_detected
+namespace nlohmann
+{
+namespace detail
+{
+struct nonesuch
+{
+    nonesuch() = delete;
+    ~nonesuch() = delete;
+    nonesuch(nonesuch const&) = delete;
+    nonesuch(nonesuch const&&) = delete;
+    void operator=(nonesuch const&) = delete;
+    void operator=(nonesuch&&) = delete;
+};
+
+template<class Default,
+         class AlwaysVoid,
+         template<class...> class Op,
+         class... Args>
+struct detector
+{
+    using value_t = std::false_type;
+    using type = Default;
+};
+
+template<class Default, template<class...> class Op, class... Args>
+struct detector<Default, void_t<Op<Args...>>, Op, Args...>
+{
+    using value_t = std::true_type;
+    using type = Op<Args...>;
+};
+
+template<template<class...> class Op, class... Args>
+using is_detected = typename detector<nonesuch, void, Op, Args...>::value_t;
+
+template<template<class...> class Op, class... Args>
+struct is_detected_lazy : is_detected<Op, Args...> { };
+
+template<template<class...> class Op, class... Args>
+using detected_t = typename detector<nonesuch, void, Op, Args...>::type;
+
+template<class Default, template<class...> class Op, class... Args>
+using detected_or = detector<Default, void, Op, Args...>;
+
+template<class Default, template<class...> class Op, class... Args>
+using detected_or_t = typename detected_or<Default, Op, Args...>::type;
+
+template<class Expected, template<class...> class Op, class... Args>
+using is_detected_exact = std::is_same<Expected, detected_t<Op, Args...>>;
+
+template<class To, template<class...> class Op, class... Args>
+using is_detected_convertible =
+    std::is_convertible<detected_t<Op, Args...>, To>;
+} // namespace detail
+} // namespace nlohmann
+
+
+// This file contains all internal macro definitions
+// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them
+
+// exclude unsupported compilers
+#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK)
+    #if defined(__clang__)
+        #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400
+            #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers"
+        #endif
+    #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER))
+        #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800
+            #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers"
+        #endif
+    #endif
+#endif
+
+// C++ language standard detection
+// if the user manually specified the used c++ version this is skipped
+#if !defined(JSON_HAS_CPP_20) && !defined(JSON_HAS_CPP_17) && !defined(JSON_HAS_CPP_14) && !defined(JSON_HAS_CPP_11)
+    #if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L)
+        #define JSON_HAS_CPP_20
+        #define JSON_HAS_CPP_17
+        #define JSON_HAS_CPP_14
+    #elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464
+        #define JSON_HAS_CPP_17
+        #define JSON_HAS_CPP_14
+    #elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1)
+        #define JSON_HAS_CPP_14
+    #endif
+    // the cpp 11 flag is always specified because it is the minimal required version
+    #define JSON_HAS_CPP_11
+#endif
+
+// disable documentation warnings on clang
+#if defined(__clang__)
+    #pragma clang diagnostic push
+    #pragma clang diagnostic ignored "-Wdocumentation"
+    #pragma clang diagnostic ignored "-Wdocumentation-unknown-command"
+#endif
+
+// allow to disable exceptions
+#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION)
+    #define JSON_THROW(exception) throw exception
+    #define JSON_TRY try
+    #define JSON_CATCH(exception) catch(exception)
+    #define JSON_INTERNAL_CATCH(exception) catch(exception)
+#else
+    #include <cstdlib>
+    #define JSON_THROW(exception) std::abort()
+    #define JSON_TRY if(true)
+    #define JSON_CATCH(exception) if(false)
+    #define JSON_INTERNAL_CATCH(exception) if(false)
+#endif
+
+// override exception macros
+#if defined(JSON_THROW_USER)
+    #undef JSON_THROW
+    #define JSON_THROW JSON_THROW_USER
+#endif
+#if defined(JSON_TRY_USER)
+    #undef JSON_TRY
+    #define JSON_TRY JSON_TRY_USER
+#endif
+#if defined(JSON_CATCH_USER)
+    #undef JSON_CATCH
+    #define JSON_CATCH JSON_CATCH_USER
+    #undef JSON_INTERNAL_CATCH
+    #define JSON_INTERNAL_CATCH JSON_CATCH_USER
+#endif
+#if defined(JSON_INTERNAL_CATCH_USER)
+    #undef JSON_INTERNAL_CATCH
+    #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER
+#endif
+
+// allow to override assert
+#if !defined(JSON_ASSERT)
+    #include <cassert> // assert
+    #define JSON_ASSERT(x) assert(x)
+#endif
+
+// allow to access some private functions (needed by the test suite)
+#if defined(JSON_TESTS_PRIVATE)
+    #define JSON_PRIVATE_UNLESS_TESTED public
+#else
+    #define JSON_PRIVATE_UNLESS_TESTED private
+#endif
+
+/*!
+@brief macro to briefly define a mapping between an enum and JSON
+@def NLOHMANN_JSON_SERIALIZE_ENUM
+@since version 3.4.0
+*/
+#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...)                                            \
+    template<typename BasicJsonType>                                                            \
+    inline void to_json(BasicJsonType& j, const ENUM_TYPE& e)                                   \
+    {                                                                                           \
+        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
+        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
+        auto it = std::find_if(std::begin(m), std::end(m),                                      \
+                               [e](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool  \
+        {                                                                                       \
+            return ej_pair.first == e;                                                          \
+        });                                                                                     \
+        j = ((it != std::end(m)) ? it : std::begin(m))->second;                                 \
+    }                                                                                           \
+    template<typename BasicJsonType>                                                            \
+    inline void from_json(const BasicJsonType& j, ENUM_TYPE& e)                                 \
+    {                                                                                           \
+        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
+        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
+        auto it = std::find_if(std::begin(m), std::end(m),                                      \
+                               [&j](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool \
+        {                                                                                       \
+            return ej_pair.second == j;                                                         \
+        });                                                                                     \
+        e = ((it != std::end(m)) ? it : std::begin(m))->first;                                  \
+    }
+
+// Ugly macros to avoid uglier copy-paste when specializing basic_json. They
+// may be removed in the future once the class is split.
+
+#define NLOHMANN_BASIC_JSON_TPL_DECLARATION                                \
+    template<template<typename, typename, typename...> class ObjectType,  \
+             template<typename, typename...> class ArrayType,             \
+             class StringType, class BooleanType, class NumberIntegerType, \
+             class NumberUnsignedType, class NumberFloatType,              \
+             template<typename> class AllocatorType,                       \
+             template<typename, typename = void> class JSONSerializer,     \
+             class BinaryType>
+
+#define NLOHMANN_BASIC_JSON_TPL                                            \
+    basic_json<ObjectType, ArrayType, StringType, BooleanType, NumberIntegerType, NumberUnsignedType, NumberFloatType, AllocatorType, JSONSerializer, BinaryType>
+
+// Macros to simplify conversion from/to types
+
+#define NLOHMANN_JSON_EXPAND( x ) x
+#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME
+#define NLOHMANN_JSON_PASTE(...) 
NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ + NLOHMANN_JSON_PASTE64, \ + NLOHMANN_JSON_PASTE63, \ + NLOHMANN_JSON_PASTE62, \ + NLOHMANN_JSON_PASTE61, \ + NLOHMANN_JSON_PASTE60, \ + NLOHMANN_JSON_PASTE59, \ + NLOHMANN_JSON_PASTE58, \ + NLOHMANN_JSON_PASTE57, \ + NLOHMANN_JSON_PASTE56, \ + NLOHMANN_JSON_PASTE55, \ + NLOHMANN_JSON_PASTE54, \ + NLOHMANN_JSON_PASTE53, \ + NLOHMANN_JSON_PASTE52, \ + NLOHMANN_JSON_PASTE51, \ + NLOHMANN_JSON_PASTE50, \ + NLOHMANN_JSON_PASTE49, \ + NLOHMANN_JSON_PASTE48, \ + NLOHMANN_JSON_PASTE47, \ + NLOHMANN_JSON_PASTE46, \ + NLOHMANN_JSON_PASTE45, \ + NLOHMANN_JSON_PASTE44, \ + NLOHMANN_JSON_PASTE43, \ + NLOHMANN_JSON_PASTE42, \ + NLOHMANN_JSON_PASTE41, \ + NLOHMANN_JSON_PASTE40, \ + NLOHMANN_JSON_PASTE39, \ + NLOHMANN_JSON_PASTE38, \ + NLOHMANN_JSON_PASTE37, \ + NLOHMANN_JSON_PASTE36, \ + NLOHMANN_JSON_PASTE35, \ + NLOHMANN_JSON_PASTE34, \ + NLOHMANN_JSON_PASTE33, \ + NLOHMANN_JSON_PASTE32, \ + NLOHMANN_JSON_PASTE31, \ + NLOHMANN_JSON_PASTE30, \ + NLOHMANN_JSON_PASTE29, \ + NLOHMANN_JSON_PASTE28, \ + NLOHMANN_JSON_PASTE27, \ + NLOHMANN_JSON_PASTE26, \ + NLOHMANN_JSON_PASTE25, \ + NLOHMANN_JSON_PASTE24, \ + NLOHMANN_JSON_PASTE23, \ + NLOHMANN_JSON_PASTE22, \ + NLOHMANN_JSON_PASTE21, \ + NLOHMANN_JSON_PASTE20, \ + NLOHMANN_JSON_PASTE19, \ + NLOHMANN_JSON_PASTE18, \ + NLOHMANN_JSON_PASTE17, \ + NLOHMANN_JSON_PASTE16, \ + NLOHMANN_JSON_PASTE15, \ + NLOHMANN_JSON_PASTE14, \ + NLOHMANN_JSON_PASTE13, \ + NLOHMANN_JSON_PASTE12, \ + NLOHMANN_JSON_PASTE11, \ + NLOHMANN_JSON_PASTE10, \ + NLOHMANN_JSON_PASTE9, \ + NLOHMANN_JSON_PASTE8, \ + NLOHMANN_JSON_PASTE7, \ + NLOHMANN_JSON_PASTE6, \ + NLOHMANN_JSON_PASTE5, \ + NLOHMANN_JSON_PASTE4, \ + NLOHMANN_JSON_PASTE3, \ + NLOHMANN_JSON_PASTE2, \ + NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) +#define NLOHMANN_JSON_PASTE2(func, v1) func(v1) +#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) +#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) +#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) +#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) +#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) +#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) +#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) +#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) +#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) +#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) +#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, 
v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) +#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) +#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) +#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) +#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) +#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) +#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) +#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) +#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) +#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) +#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) +#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) +#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) +#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) 
NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) +#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) +#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) +#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) +#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) +#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) +#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) +#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) +#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) +#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, 
v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) +#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) +#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) +#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) +#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) +#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) +#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) +#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) +#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) 
NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) +#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) +#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) +#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) +#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) +#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) +#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) +#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, 
v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) +#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) +#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) +#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) +#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) +#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) +#define 
NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) +#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) +#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) +#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) +#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) +#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, 
v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) +#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) +#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) + +#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; +#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + + +// inspired from https://stackoverflow.com/a/26745591 +// allows to call any std function as if (e.g. 
with begin): +// using std::begin; begin(x); +// +// it allows using the detected idiom to retrieve the return type +// of such an expression +#define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name) \ + namespace detail { \ + using std::std_name; \ + \ + template \ + using result_of_##std_name = decltype(std_name(std::declval()...)); \ + } \ + \ + namespace detail2 { \ + struct std_name##_tag \ + { \ + }; \ + \ + template \ + std_name##_tag std_name(T&&...); \ + \ + template \ + using result_of_##std_name = decltype(std_name(std::declval()...)); \ + \ + template \ + struct would_call_std_##std_name \ + { \ + static constexpr auto const value = ::nlohmann::detail:: \ + is_detected_exact::value; \ + }; \ + } /* namespace detail2 */ \ + \ + template \ + struct would_call_std_##std_name : detail2::would_call_std_##std_name \ + { \ + } + +#ifndef JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_USE_IMPLICIT_CONVERSIONS 1 +#endif + +#if JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_EXPLICIT +#else + #define JSON_EXPLICIT explicit +#endif + +#ifndef JSON_DIAGNOSTICS + #define JSON_DIAGNOSTICS 0 +#endif + + +namespace nlohmann +{ +namespace detail +{ + +/*! +@brief replace all occurrences of a substring by another string + +@param[in,out] s the string to manipulate; changed so that all + occurrences of @a f are replaced with @a t +@param[in] f the substring to replace with @a t +@param[in] t the string to replace @a f + +@pre The search string @a f must not be empty. **This precondition is +enforced with an assertion.** + +@since version 2.0.0 +*/ +inline void replace_substring(std::string& s, const std::string& f, + const std::string& t) +{ + JSON_ASSERT(!f.empty()); + for (auto pos = s.find(f); // find first occurrence of f + pos != std::string::npos; // make sure f was found + s.replace(pos, f.size(), t), // replace with t, and + pos = s.find(f, pos + t.size())) // find next occurrence of f + {} +} + +/*! + * @brief string escaping as described in RFC 6901 (Sect. 4) + * @param[in] s string to escape + * @return escaped string + * + * Note the order of escaping "~" to "~0" and "/" to "~1" is important. + */ +inline std::string escape(std::string s) +{ + replace_substring(s, "~", "~0"); + replace_substring(s, "/", "~1"); + return s; +} + +/*! + * @brief string unescaping as described in RFC 6901 (Sect. 4) + * @param[in] s string to unescape + * @return unescaped string + * + * Note the order of escaping "~1" to "/" and "~0" to "~" is important. + */ +static void unescape(std::string& s) +{ + replace_substring(s, "~1", "/"); + replace_substring(s, "~0", "~"); +} + +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // size_t + +namespace nlohmann +{ +namespace detail +{ +/// struct to capture the start position of the current token +struct position_t +{ + /// the total number of characters read + std::size_t chars_read_total = 0; + /// the number of characters read in the current line + std::size_t chars_read_current_line = 0; + /// the number of lines read + std::size_t lines_read = 0; + + /// conversion to size_t to preserve SAX interface + constexpr operator size_t() const + { + return chars_read_total; + } +}; + +} // namespace detail +} // namespace nlohmann + +// #include + + +namespace nlohmann +{ +namespace detail +{ +//////////////// +// exceptions // +//////////////// + +/*! +@brief general exception of the @ref basic_json class + +This class is an extension of `std::exception` objects with a member @a id for +exception ids. 
It is used as the base class for all exceptions thrown by the +@ref basic_json class. This class can hence be used as "wildcard" to catch +exceptions. + +Subclasses: +- @ref parse_error for exceptions indicating a parse error +- @ref invalid_iterator for exceptions indicating errors with iterators +- @ref type_error for exceptions indicating executing a member function with + a wrong type +- @ref out_of_range for exceptions indicating access out of the defined range +- @ref other_error for exceptions indicating other library errors + +@internal +@note To have nothrow-copy-constructible exceptions, we internally use + `std::runtime_error` which can cope with arbitrary-length error messages. + Intermediate strings are built with static functions and then passed to + the actual constructor. +@endinternal + +@liveexample{The following code shows how arbitrary library exceptions can be +caught.,exception} + +@since version 3.0.0 +*/ +class exception : public std::exception +{ + public: + /// returns the explanatory string + const char* what() const noexcept override + { + return m.what(); + } + + /// the id of the exception + const int id; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes) + + protected: + JSON_HEDLEY_NON_NULL(3) + exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} // NOLINT(bugprone-throw-keyword-missing) + + static std::string name(const std::string& ename, int id_) + { + return "[json.exception." + ename + "." + std::to_string(id_) + "] "; + } + + template + static std::string diagnostics(const BasicJsonType& leaf_element) + { +#if JSON_DIAGNOSTICS + std::vector tokens; + for (const auto* current = &leaf_element; current->m_parent != nullptr; current = current->m_parent) + { + switch (current->m_parent->type()) + { + case value_t::array: + { + for (std::size_t i = 0; i < current->m_parent->m_value.array->size(); ++i) + { + if (¤t->m_parent->m_value.array->operator[](i) == current) + { + tokens.emplace_back(std::to_string(i)); + break; + } + } + break; + } + + case value_t::object: + { + for (const auto& element : *current->m_parent->m_value.object) + { + if (&element.second == current) + { + tokens.emplace_back(element.first.c_str()); + break; + } + } + break; + } + + case value_t::null: // LCOV_EXCL_LINE + case value_t::string: // LCOV_EXCL_LINE + case value_t::boolean: // LCOV_EXCL_LINE + case value_t::number_integer: // LCOV_EXCL_LINE + case value_t::number_unsigned: // LCOV_EXCL_LINE + case value_t::number_float: // LCOV_EXCL_LINE + case value_t::binary: // LCOV_EXCL_LINE + case value_t::discarded: // LCOV_EXCL_LINE + default: // LCOV_EXCL_LINE + break; // LCOV_EXCL_LINE + } + } + + if (tokens.empty()) + { + return ""; + } + + return "(" + std::accumulate(tokens.rbegin(), tokens.rend(), std::string{}, + [](const std::string & a, const std::string & b) + { + return a + "/" + detail::escape(b); + }) + ") "; +#else + static_cast(leaf_element); + return ""; +#endif + } + + private: + /// an exception object as storage for error messages + std::runtime_error m; +}; + +/*! +@brief exception indicating a parse error + +This exception is thrown by the library when a parse error occurs. Parse errors +can occur during the deserialization of JSON text, CBOR, MessagePack, as well +as when using JSON Patch. + +Member @a byte holds the byte index of the last read character in the input +file. + +Exceptions have ids 1xx. 
+ +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.parse_error.101 | parse error at 2: unexpected end of input; expected string literal | This error indicates a syntax error while deserializing a JSON text. The error message describes that an unexpected token (character) was encountered, and the member @a byte indicates the error position. +json.exception.parse_error.102 | parse error at 14: missing or wrong low surrogate | JSON uses the `\uxxxx` format to describe Unicode characters. Code points above above 0xFFFF are split into two `\uxxxx` entries ("surrogate pairs"). This error indicates that the surrogate pair is incomplete or contains an invalid code point. +json.exception.parse_error.103 | parse error: code points above 0x10FFFF are invalid | Unicode supports code points up to 0x10FFFF. Code points above 0x10FFFF are invalid. +json.exception.parse_error.104 | parse error: JSON patch must be an array of objects | [RFC 6902](https://tools.ietf.org/html/rfc6902) requires a JSON Patch document to be a JSON document that represents an array of objects. +json.exception.parse_error.105 | parse error: operation must have string member 'op' | An operation of a JSON Patch document must contain exactly one "op" member, whose value indicates the operation to perform. Its value must be one of "add", "remove", "replace", "move", "copy", or "test"; other values are errors. +json.exception.parse_error.106 | parse error: array index '01' must not begin with '0' | An array index in a JSON Pointer ([RFC 6901](https://tools.ietf.org/html/rfc6901)) may be `0` or any number without a leading `0`. +json.exception.parse_error.107 | parse error: JSON pointer must be empty or begin with '/' - was: 'foo' | A JSON Pointer must be a Unicode string containing a sequence of zero or more reference tokens, each prefixed by a `/` character. +json.exception.parse_error.108 | parse error: escape character '~' must be followed with '0' or '1' | In a JSON Pointer, only `~0` and `~1` are valid escape sequences. +json.exception.parse_error.109 | parse error: array index 'one' is not a number | A JSON Pointer array index must be a number. +json.exception.parse_error.110 | parse error at 1: cannot read 2 bytes from vector | When parsing CBOR or MessagePack, the byte vector ends before the complete value has been read. +json.exception.parse_error.112 | parse error at 1: error reading CBOR; last byte: 0xF8 | Not all types of CBOR or MessagePack are supported. This exception occurs if an unsupported byte was read. +json.exception.parse_error.113 | parse error at 2: expected a CBOR string; last byte: 0x98 | While parsing a map key, a value that is not a string has been read. +json.exception.parse_error.114 | parse error: Unsupported BSON record type 0x0F | The parsing of the corresponding BSON record type is not implemented (yet). +json.exception.parse_error.115 | parse error at byte 5: syntax error while parsing UBJSON high-precision number: invalid number text: 1A | A UBJSON high-precision number could not be parsed. + +@note For an input with n bytes, 1 is the index of the first character and n+1 + is the index of the terminating null byte or the end of file. This also + holds true when reading a byte vector (CBOR or MessagePack). 
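+
+As an illustrative sketch only (assuming the usual `using json = nlohmann::json;`
+shorthand and `<iostream>` for the output), a parse error can be handled like
+this; @a id and @a byte carry the values listed in the table above:
+
+@code{.cpp}
+try
+{
+    json j = json::parse(R"({"incomplete": )");  // malformed input (sketch)
+}
+catch (const json::parse_error& e)
+{
+    // e.what() holds the full "[json.exception.parse_error.xxx] ..." message,
+    // e.id the numeric id (1xx), e.byte the byte index of the last read character
+    std::cerr << e.what() << " (id: " << e.id << ", byte: " << e.byte << ")\n";
+}
+@endcode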
+ +@liveexample{The following code shows how a `parse_error` exception can be +caught.,parse_error} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class parse_error : public exception +{ + public: + /*! + @brief create a parse error exception + @param[in] id_ the id of the exception + @param[in] pos the position where the error occurred (or with + chars_read_total=0 if the position cannot be + determined) + @param[in] what_arg the explanatory string + @return parse_error object + */ + template + static parse_error create(int id_, const position_t& pos, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + position_string(pos) + ": " + exception::diagnostics(context) + what_arg; + return {id_, pos.chars_read_total, w.c_str()}; + } + + template + static parse_error create(int id_, std::size_t byte_, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + (byte_ != 0 ? (" at byte " + std::to_string(byte_)) : "") + + ": " + exception::diagnostics(context) + what_arg; + return {id_, byte_, w.c_str()}; + } + + /*! + @brief byte index of the parse error + + The byte index of the last read character in the input file. + + @note For an input with n bytes, 1 is the index of the first character and + n+1 is the index of the terminating null byte or the end of file. + This also holds true when reading a byte vector (CBOR or MessagePack). + */ + const std::size_t byte; + + private: + parse_error(int id_, std::size_t byte_, const char* what_arg) + : exception(id_, what_arg), byte(byte_) {} + + static std::string position_string(const position_t& pos) + { + return " at line " + std::to_string(pos.lines_read + 1) + + ", column " + std::to_string(pos.chars_read_current_line); + } +}; + +/*! +@brief exception indicating errors with iterators + +This exception is thrown if iterators passed to a library function do not match +the expected semantics. + +Exceptions have ids 2xx. + +name / id | example message | description +----------------------------------- | --------------- | ------------------------- +json.exception.invalid_iterator.201 | iterators are not compatible | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.202 | iterator does not fit current value | In an erase or insert function, the passed iterator @a pos does not belong to the JSON value for which the function was called. It hence does not define a valid position for the deletion/insertion. +json.exception.invalid_iterator.203 | iterators do not fit current value | Either iterator passed to function @ref erase(IteratorType first, IteratorType last) does not belong to the JSON value from which values shall be erased. It hence does not define a valid range to delete values from. 
+json.exception.invalid_iterator.204 | iterators out of range | When an iterator range for a primitive type (number, boolean, or string) is passed to a constructor or an erase function, this range has to be exactly (@ref begin(), @ref end()), because this is the only way the single stored value is expressed. All other ranges are invalid. +json.exception.invalid_iterator.205 | iterator out of range | When an iterator for a primitive type (number, boolean, or string) is passed to an erase function, the iterator has to be the @ref begin() iterator, because it is the only way to address the stored value. All other iterators are invalid. +json.exception.invalid_iterator.206 | cannot construct with iterators from null | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) belong to a JSON null value and hence to not define a valid range. +json.exception.invalid_iterator.207 | cannot use key() for non-object iterators | The key() member function can only be used on iterators belonging to a JSON object, because other types do not have a concept of a key. +json.exception.invalid_iterator.208 | cannot use operator[] for object iterators | The operator[] to specify a concrete offset cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.209 | cannot use offsets with object iterators | The offset operators (+, -, +=, -=) cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.210 | iterators do not fit | The iterator range passed to the insert function are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.211 | passed iterators may not belong to container | The iterator range passed to the insert function must not be a subrange of the container to insert to. +json.exception.invalid_iterator.212 | cannot compare iterators of different containers | When two iterators are compared, they must belong to the same container. +json.exception.invalid_iterator.213 | cannot compare order of object iterators | The order of object iterators cannot be compared, because JSON objects are unordered. +json.exception.invalid_iterator.214 | cannot get value | Cannot get value for iterator: Either the iterator belongs to a null value or it is an iterator to a primitive type (number, boolean, or string), but the iterator is different to @ref begin(). + +@liveexample{The following code shows how an `invalid_iterator` exception can be +caught.,invalid_iterator} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class invalid_iterator : public exception +{ + public: + template + static invalid_iterator create(int id_, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("invalid_iterator", id_) + exception::diagnostics(context) + what_arg; + return {id_, w.c_str()}; + } + + private: + JSON_HEDLEY_NON_NULL(3) + invalid_iterator(int id_, const char* what_arg) + : exception(id_, what_arg) {} +}; + +/*! 
+@brief exception indicating executing a member function with a wrong type + +This exception is thrown in case of a type error; that is, a library function is +executed on a JSON value whose type does not match the expected semantics. + +Exceptions have ids 3xx. + +name / id | example message | description +----------------------------- | --------------- | ------------------------- +json.exception.type_error.301 | cannot create object from initializer list | To create an object from an initializer list, the initializer list must consist only of a list of pairs whose first element is a string. When this constraint is violated, an array is created instead. +json.exception.type_error.302 | type must be object, but is array | During implicit or explicit value conversion, the JSON type must be compatible to the target type. For instance, a JSON string can only be converted into string types, but not into numbers or boolean types. +json.exception.type_error.303 | incompatible ReferenceType for get_ref, actual type is object | To retrieve a reference to a value stored in a @ref basic_json object with @ref get_ref, the type of the reference must match the value type. For instance, for a JSON array, the @a ReferenceType must be @ref array_t &. +json.exception.type_error.304 | cannot use at() with string | The @ref at() member functions can only be executed for certain JSON types. +json.exception.type_error.305 | cannot use operator[] with string | The @ref operator[] member functions can only be executed for certain JSON types. +json.exception.type_error.306 | cannot use value() with string | The @ref value() member functions can only be executed for certain JSON types. +json.exception.type_error.307 | cannot use erase() with string | The @ref erase() member functions can only be executed for certain JSON types. +json.exception.type_error.308 | cannot use push_back() with string | The @ref push_back() and @ref operator+= member functions can only be executed for certain JSON types. +json.exception.type_error.309 | cannot use insert() with | The @ref insert() member functions can only be executed for certain JSON types. +json.exception.type_error.310 | cannot use swap() with number | The @ref swap() member functions can only be executed for certain JSON types. +json.exception.type_error.311 | cannot use emplace_back() with string | The @ref emplace_back() member function can only be executed for certain JSON types. +json.exception.type_error.312 | cannot use update() with string | The @ref update() member functions can only be executed for certain JSON types. +json.exception.type_error.313 | invalid value to unflatten | The @ref unflatten function converts an object whose keys are JSON Pointers back into an arbitrary nested JSON value. The JSON Pointers must not overlap, because then the resulting value would not be well defined. +json.exception.type_error.314 | only objects can be unflattened | The @ref unflatten function only works for an object whose keys are JSON Pointers. +json.exception.type_error.315 | values in object must be primitive | The @ref unflatten function only works for an object whose keys are JSON Pointers and whose values are primitive. +json.exception.type_error.316 | invalid UTF-8 byte at index 10: 0x7E | The @ref dump function only works with UTF-8 encoded strings; that is, if you assign a `std::string` to a JSON value, make sure it is UTF-8 encoded. 
| +json.exception.type_error.317 | JSON value cannot be serialized to requested format | The dynamic type of the object cannot be represented in the requested serialization format (e.g. a raw `true` or `null` JSON object cannot be serialized to BSON) | + +@liveexample{The following code shows how a `type_error` exception can be +caught.,type_error} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class type_error : public exception +{ + public: + template + static type_error create(int id_, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("type_error", id_) + exception::diagnostics(context) + what_arg; + return {id_, w.c_str()}; + } + + private: + JSON_HEDLEY_NON_NULL(3) + type_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating access out of the defined range + +This exception is thrown in case a library function is called on an input +parameter that exceeds the expected range, for instance in case of array +indices or nonexisting object keys. + +Exceptions have ids 4xx. + +name / id | example message | description +------------------------------- | --------------- | ------------------------- +json.exception.out_of_range.401 | array index 3 is out of range | The provided array index @a i is larger than @a size-1. +json.exception.out_of_range.402 | array index '-' (3) is out of range | The special array index `-` in a JSON Pointer never describes a valid element of the array, but the index past the end. That is, it can only be used to add elements at this position, but not to read it. +json.exception.out_of_range.403 | key 'foo' not found | The provided key was not found in the JSON object. +json.exception.out_of_range.404 | unresolved reference token 'foo' | A reference token in a JSON Pointer could not be resolved. +json.exception.out_of_range.405 | JSON pointer has no parent | The JSON Patch operations 'remove' and 'add' can not be applied to the root element of the JSON value. +json.exception.out_of_range.406 | number overflow parsing '10E1000' | A parsed number could not be stored as without changing it to NaN or INF. +json.exception.out_of_range.407 | number overflow serializing '9223372036854775808' | UBJSON and BSON only support integer numbers up to 9223372036854775807. (until version 3.8.0) | +json.exception.out_of_range.408 | excessive array size: 8658170730974374167 | The size (following `#`) of an UBJSON array or object exceeds the maximal capacity. 
| +json.exception.out_of_range.409 | BSON key cannot contain code point U+0000 (at byte 2) | Key identifiers to be serialized to BSON cannot contain code point U+0000, since the key is stored as zero-terminated c-string | + +@liveexample{The following code shows how an `out_of_range` exception can be +caught.,out_of_range} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class out_of_range : public exception +{ + public: + template + static out_of_range create(int id_, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("out_of_range", id_) + exception::diagnostics(context) + what_arg; + return {id_, w.c_str()}; + } + + private: + JSON_HEDLEY_NON_NULL(3) + out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating other library errors + +This exception is thrown in case of errors that cannot be classified with the +other exception types. + +Exceptions have ids 5xx. + +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.other_error.501 | unsuccessful: {"op":"test","path":"/baz", "value":"bar"} | A JSON Patch operation 'test' failed. The unsuccessful operation is also printed. + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range + +@liveexample{The following code shows how an `other_error` exception can be +caught.,other_error} + +@since version 3.0.0 +*/ +class other_error : public exception +{ + public: + template + static other_error create(int id_, const std::string& what_arg, const BasicJsonType& context) + { + std::string w = exception::name("other_error", id_) + exception::diagnostics(context) + what_arg; + return {id_, w.c_str()}; + } + + private: + JSON_HEDLEY_NON_NULL(3) + other_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // size_t +#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type +#include // index_sequence, make_index_sequence, index_sequence_for + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +template +using uncvref_t = typename std::remove_cv::type>::type; + +#ifdef JSON_HAS_CPP_14 + +// the following utilities are natively available in C++14 +using std::enable_if_t; +using std::index_sequence; +using std::make_index_sequence; +using std::index_sequence_for; + +#else + +// alias templates to reduce boilerplate +template +using enable_if_t = typename std::enable_if::type; + +// The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h +// which is part of Google Abseil (https://github.com/abseil/abseil-cpp), 
licensed under the Apache License 2.0. + +//// START OF CODE FROM GOOGLE ABSEIL + +// integer_sequence +// +// Class template representing a compile-time integer sequence. An instantiation +// of `integer_sequence` has a sequence of integers encoded in its +// type through its template arguments (which is a common need when +// working with C++11 variadic templates). `absl::integer_sequence` is designed +// to be a drop-in replacement for C++14's `std::integer_sequence`. +// +// Example: +// +// template< class T, T... Ints > +// void user_function(integer_sequence); +// +// int main() +// { +// // user_function's `T` will be deduced to `int` and `Ints...` +// // will be deduced to `0, 1, 2, 3, 4`. +// user_function(make_integer_sequence()); +// } +template +struct integer_sequence +{ + using value_type = T; + static constexpr std::size_t size() noexcept + { + return sizeof...(Ints); + } +}; + +// index_sequence +// +// A helper template for an `integer_sequence` of `size_t`, +// `absl::index_sequence` is designed to be a drop-in replacement for C++14's +// `std::index_sequence`. +template +using index_sequence = integer_sequence; + +namespace utility_internal +{ + +template +struct Extend; + +// Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency. +template +struct Extend, SeqSize, 0> +{ + using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >; +}; + +template +struct Extend, SeqSize, 1> +{ + using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >; +}; + +// Recursion helper for 'make_integer_sequence'. +// 'Gen::type' is an alias for 'integer_sequence'. +template +struct Gen +{ + using type = + typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type; +}; + +template +struct Gen +{ + using type = integer_sequence; +}; + +} // namespace utility_internal + +// Compile-time sequences of integers + +// make_integer_sequence +// +// This template alias is equivalent to +// `integer_sequence`, and is designed to be a drop-in +// replacement for C++14's `std::make_integer_sequence`. +template +using make_integer_sequence = typename utility_internal::Gen::type; + +// make_index_sequence +// +// This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`, +// and is designed to be a drop-in replacement for C++14's +// `std::make_index_sequence`. 
+template +using make_index_sequence = make_integer_sequence; + +// index_sequence_for +// +// Converts a typename pack into an index sequence of the same length, and +// is designed to be a drop-in replacement for C++14's +// `std::index_sequence_for()` +template +using index_sequence_for = make_index_sequence; + +//// END OF CODE FROM GOOGLE ABSEIL + +#endif + +// dispatch utility (taken from ranges-v3) +template struct priority_tag : priority_tag < N - 1 > {}; +template<> struct priority_tag<0> {}; + +// taken from ranges-v3 +template +struct static_const +{ + static constexpr T value{}; +}; + +template +constexpr T static_const::value; + +} // namespace detail +} // namespace nlohmann + +// #include + + +namespace nlohmann +{ +namespace detail +{ +// dispatching helper struct +template struct identity_tag {}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // numeric_limits +#include // false_type, is_constructible, is_integral, is_same, true_type +#include // declval +#include // tuple + +// #include + + +// #include + + +#include // random_access_iterator_tag + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +struct iterator_types {}; + +template +struct iterator_types < + It, + void_t> +{ + using difference_type = typename It::difference_type; + using value_type = typename It::value_type; + using pointer = typename It::pointer; + using reference = typename It::reference; + using iterator_category = typename It::iterator_category; +}; + +// This is required as some compilers implement std::iterator_traits in a way that +// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. +template +struct iterator_traits +{ +}; + +template +struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> + : iterator_types +{ +}; + +template +struct iterator_traits::value>> +{ + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T*; + using reference = T&; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +// #include + + +namespace nlohmann +{ +NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin); +} // namespace nlohmann + +// #include + + +// #include + + +namespace nlohmann +{ +NLOHMANN_CAN_CALL_STD_FUNC_IMPL(end); +} // namespace nlohmann + +// #include + +// #include + +// #include +#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ +#define INCLUDE_NLOHMANN_JSON_FWD_HPP_ + +#include // int64_t, uint64_t +#include // map +#include // allocator +#include // string +#include // vector + +/*! +@brief namespace for Niels Lohmann +@see https://github.com/nlohmann +@since version 1.0.0 +*/ +namespace nlohmann +{ +/*! +@brief default JSONSerializer template argument + +This serializer ignores the template arguments and uses ADL +([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) +for serialization. +*/ +template +struct adl_serializer; + +template class ObjectType = + std::map, + template class ArrayType = std::vector, + class StringType = std::string, class BooleanType = bool, + class NumberIntegerType = std::int64_t, + class NumberUnsignedType = std::uint64_t, + class NumberFloatType = double, + template class AllocatorType = std::allocator, + template class JSONSerializer = + adl_serializer, + class BinaryType = std::vector> +class basic_json; + +/*! +@brief JSON Pointer + +A JSON pointer defines a string syntax for identifying a specific value +within a JSON document. 
It can be used with functions `at` and +`operator[]`. Furthermore, JSON pointers are the base for JSON patches. + +@sa [RFC 6901](https://tools.ietf.org/html/rfc6901) + +@since version 2.0.0 +*/ +template +class json_pointer; + +/*! +@brief default JSON class + +This type is the default specialization of the @ref basic_json class which +uses the standard template types. + +@since version 1.0.0 +*/ +using json = basic_json<>; + +template +struct ordered_map; + +/*! +@brief ordered JSON class + +This type preserves the insertion order of object keys. + +@since version 3.9.0 +*/ +using ordered_json = basic_json; + +} // namespace nlohmann + +#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ + + +namespace nlohmann +{ +/*! +@brief detail namespace with internal helper functions + +This namespace collects functions that should not be exposed, +implementations of some @ref basic_json methods, and meta-programming helpers. + +@since version 2.1.0 +*/ +namespace detail +{ +///////////// +// helpers // +///////////// + +// Note to maintainers: +// +// Every trait in this file expects a non CV-qualified type. +// The only exceptions are in the 'aliases for detected' section +// (i.e. those of the form: decltype(T::member_function(std::declval()))) +// +// In this case, T has to be properly CV-qualified to constraint the function arguments +// (e.g. to_json(BasicJsonType&, const T&)) + +template struct is_basic_json : std::false_type {}; + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +struct is_basic_json : std::true_type {}; + +////////////////////// +// json_ref helpers // +////////////////////// + +template +class json_ref; + +template +struct is_json_ref : std::false_type {}; + +template +struct is_json_ref> : std::true_type {}; + +////////////////////////// +// aliases for detected // +////////////////////////// + +template +using mapped_type_t = typename T::mapped_type; + +template +using key_type_t = typename T::key_type; + +template +using value_type_t = typename T::value_type; + +template +using difference_type_t = typename T::difference_type; + +template +using pointer_t = typename T::pointer; + +template +using reference_t = typename T::reference; + +template +using iterator_category_t = typename T::iterator_category; + +template +using to_json_function = decltype(T::to_json(std::declval()...)); + +template +using from_json_function = decltype(T::from_json(std::declval()...)); + +template +using get_template_function = decltype(std::declval().template get()); + +// trait checking if JSONSerializer::from_json(json const&, udt&) exists +template +struct has_from_json : std::false_type {}; + +// trait checking if j.get is valid +// use this trait instead of std::is_constructible or std::is_convertible, +// both rely on, or make use of implicit conversions, and thus fail when T +// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) +template +struct is_getable +{ + static constexpr bool value = is_detected::value; +}; + +template +struct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if JSONSerializer::from_json(json const&) exists +// this overload is used for non-default-constructible user-defined-types +template +struct has_non_default_from_json : std::false_type {}; + +template +struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using 
serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if BasicJsonType::json_serializer::to_json exists +// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. +template +struct has_to_json : std::false_type {}; + +template +struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + + +/////////////////// +// is_ functions // +/////////////////// + +// https://en.cppreference.com/w/cpp/types/conjunction +template struct conjunction : std::true_type { }; +template struct conjunction : B1 { }; +template +struct conjunction +: std::conditional, B1>::type {}; + +// https://en.cppreference.com/w/cpp/types/negation +template struct negation : std::integral_constant < bool, !B::value > { }; + +// Reimplementation of is_constructible and is_default_constructible, due to them being broken for +// std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367). +// This causes compile errors in e.g. clang 3.5 or gcc 4.9. +template +struct is_default_constructible : std::is_default_constructible {}; + +template +struct is_default_constructible> + : conjunction, is_default_constructible> {}; + +template +struct is_default_constructible> + : conjunction, is_default_constructible> {}; + +template +struct is_default_constructible> + : conjunction...> {}; + +template +struct is_default_constructible> + : conjunction...> {}; + + +template +struct is_constructible : std::is_constructible {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + + +template +struct is_iterator_traits : std::false_type {}; + +template +struct is_iterator_traits> +{ + private: + using traits = iterator_traits; + + public: + static constexpr auto value = + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value; +}; + +template +struct is_range +{ + private: + using t_ref = typename std::add_lvalue_reference::type; + + using iterator = detected_t; + using sentinel = detected_t; + + // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator + // and https://en.cppreference.com/w/cpp/iterator/sentinel_for + // but reimplementing these would be too much work, as a lot of other concepts are used underneath + static constexpr auto is_iterator_begin = + is_iterator_traits>::value; + + public: + static constexpr bool value = !std::is_same::value && !std::is_same::value && is_iterator_begin; +}; + +template +using iterator_t = enable_if_t::value, result_of_begin())>>; + +template +using range_value_t = value_type_t>>; + +// The following implementation of is_complete_type is taken from +// https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/ +// and is written by Xiang Fan who agreed to using it in this library. 
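+
+// Illustrative sketch only (placeholder names has_member_foo / foo; not used anywhere
+// in this header). The has_from_json / has_to_json / is_getable traits above are
+// instances of the C++17 "detection idiom" backported to C++11; reduced to its core,
+// the pattern is:
+//
+//     template<class T, class = void>
+//     struct has_member_foo : std::false_type {};               // primary: not detected
+//
+//     template<class T>
+//     struct has_member_foo<T, void_t<decltype(std::declval<T&>().foo())>>
+//         : std::true_type {};                                  // chosen iff T::foo() is well-formed
+//
+// The traits above feed serializer expressions such as JSONSerializer<T>::from_json
+// into the same mechanism through is_detected / is_detected_exact rather than
+// spelling out each partial specialization by hand.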
+ +template +struct is_complete_type : std::false_type {}; + +template +struct is_complete_type : std::true_type {}; + +template +struct is_compatible_object_type_impl : std::false_type {}; + +template +struct is_compatible_object_type_impl < + BasicJsonType, CompatibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + using object_t = typename BasicJsonType::object_t; + + // macOS's is_constructible does not play well with nonesuch... + static constexpr bool value = + is_constructible::value && + is_constructible::value; +}; + +template +struct is_compatible_object_type + : is_compatible_object_type_impl {}; + +template +struct is_constructible_object_type_impl : std::false_type {}; + +template +struct is_constructible_object_type_impl < + BasicJsonType, ConstructibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + using object_t = typename BasicJsonType::object_t; + + static constexpr bool value = + (is_default_constructible::value && + (std::is_move_assignable::value || + std::is_copy_assignable::value) && + (is_constructible::value && + std::is_same < + typename object_t::mapped_type, + typename ConstructibleObjectType::mapped_type >::value)) || + (has_from_json::value || + has_non_default_from_json < + BasicJsonType, + typename ConstructibleObjectType::mapped_type >::value); +}; + +template +struct is_constructible_object_type + : is_constructible_object_type_impl {}; + +template +struct is_compatible_string_type +{ + static constexpr auto value = + is_constructible::value; +}; + +template +struct is_constructible_string_type +{ + static constexpr auto value = + is_constructible::value; +}; + +template +struct is_compatible_array_type_impl : std::false_type {}; + +template +struct is_compatible_array_type_impl < + BasicJsonType, CompatibleArrayType, + enable_if_t < + is_detected::value&& + is_iterator_traits>>::value&& +// special case for types like std::filesystem::path whose iterator's value_type are themselves +// c.f. https://github.com/nlohmann/json/pull/3073 + !std::is_same>::value >> +{ + static constexpr bool value = + is_constructible>::value; +}; + +template +struct is_compatible_array_type + : is_compatible_array_type_impl {}; + +template +struct is_constructible_array_type_impl : std::false_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t::value >> + : std::true_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t < !std::is_same::value&& + !is_compatible_string_type::value&& + is_default_constructible::value&& +(std::is_move_assignable::value || + std::is_copy_assignable::value)&& +is_detected::value&& +is_iterator_traits>>::value&& +is_detected::value&& +// special case for types like std::filesystem::path whose iterator's value_type are themselves +// c.f. 
https://github.com/nlohmann/json/pull/3073 +!std::is_same>::value&& + is_complete_type < + detected_t>::value >> +{ + using value_type = range_value_t; + + static constexpr bool value = + std::is_same::value || + has_from_json::value || + has_non_default_from_json < + BasicJsonType, + value_type >::value; +}; + +template +struct is_constructible_array_type + : is_constructible_array_type_impl {}; + +template +struct is_compatible_integer_type_impl : std::false_type {}; + +template +struct is_compatible_integer_type_impl < + RealIntegerType, CompatibleNumberIntegerType, + enable_if_t < std::is_integral::value&& + std::is_integral::value&& + !std::is_same::value >> +{ + // is there an assert somewhere on overflows? + using RealLimits = std::numeric_limits; + using CompatibleLimits = std::numeric_limits; + + static constexpr auto value = + is_constructible::value && + CompatibleLimits::is_integer && + RealLimits::is_signed == CompatibleLimits::is_signed; +}; + +template +struct is_compatible_integer_type + : is_compatible_integer_type_impl {}; + +template +struct is_compatible_type_impl: std::false_type {}; + +template +struct is_compatible_type_impl < + BasicJsonType, CompatibleType, + enable_if_t::value >> +{ + static constexpr bool value = + has_to_json::value; +}; + +template +struct is_compatible_type + : is_compatible_type_impl {}; + +template +struct is_constructible_tuple : std::false_type {}; + +template +struct is_constructible_tuple> : conjunction...> {}; + +// a naive helper to check if a type is an ordered_map (exploits the fact that +// ordered_map inherits capacity() from std::vector) +template +struct is_ordered_map +{ + using one = char; + + struct two + { + char x[2]; // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) + }; + + template static one test( decltype(&C::capacity) ) ; + template static two test(...); + + enum { value = sizeof(test(nullptr)) == sizeof(char) }; // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg) +}; + +// to avoid useless casts (see https://github.com/nlohmann/json/issues/2893#issuecomment-889152324) +template < typename T, typename U, enable_if_t < !std::is_same::value, int > = 0 > +T conditional_static_cast(U value) +{ + return static_cast(value); +} + +template::value, int> = 0> +T conditional_static_cast(U value) +{ + return value; +} + +} // namespace detail +} // namespace nlohmann + +// #include + + +#ifdef JSON_HAS_CPP_17 + #include +#endif + +namespace nlohmann +{ +namespace detail +{ +template +void from_json(const BasicJsonType& j, typename std::nullptr_t& n) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_null())) + { + JSON_THROW(type_error::create(302, "type must be null, but is " + std::string(j.type_name()), j)); + } + n = nullptr; +} + +// overloads for basic_json template parameters +template < typename BasicJsonType, typename ArithmeticType, + enable_if_t < std::is_arithmetic::value&& + !std::is_same::value, + int > = 0 > +void get_arithmetic_value(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + + case value_t::null: + case value_t::object: + case value_t::array: + case value_t::string: + case value_t::boolean: + case value_t::binary: + case value_t::discarded: + default: + 
JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()), j)); + } +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::boolean_t& b) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_boolean())) + { + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(j.type_name()), j)); + } + b = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::string_t& s) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()), j)); + } + s = *j.template get_ptr(); +} + +template < + typename BasicJsonType, typename ConstructibleStringType, + enable_if_t < + is_constructible_string_type::value&& + !std::is_same::value, + int > = 0 > +void from_json(const BasicJsonType& j, ConstructibleStringType& s) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()), j)); + } + + s = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_float_t& val) +{ + get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_unsigned_t& val) +{ + get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_integer_t& val) +{ + get_arithmetic_value(j, val); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, EnumType& e) +{ + typename std::underlying_type::type val; + get_arithmetic_value(j, val); + e = static_cast(val); +} + +// forward_list doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::forward_list& l) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + l.clear(); + std::transform(j.rbegin(), j.rend(), + std::front_inserter(l), [](const BasicJsonType & i) + { + return i.template get(); + }); +} + +// valarray doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::valarray& l) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + l.resize(j.size()); + std::transform(j.begin(), j.end(), std::begin(l), + [](const BasicJsonType & elem) + { + return elem.template get(); + }); +} + +template +auto from_json(const BasicJsonType& j, T (&arr)[N]) // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) +-> decltype(j.template get(), void()) +{ + for (std::size_t i = 0; i < N; ++i) + { + arr[i] = j.at(i).template get(); + } +} + +template +void from_json_array_impl(const BasicJsonType& j, typename BasicJsonType::array_t& arr, priority_tag<3> /*unused*/) +{ + arr = *j.template get_ptr(); +} + +template +auto from_json_array_impl(const BasicJsonType& j, std::array& arr, + priority_tag<2> /*unused*/) +-> decltype(j.template get(), void()) +{ + for (std::size_t i = 0; i < N; ++i) + { + arr[i] = j.at(i).template get(); + } +} + +template::value, + int> = 0> +auto from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, priority_tag<1> /*unused*/) +-> decltype( + arr.reserve(std::declval()), + j.template get(), + void()) +{ + using std::end; + + ConstructibleArrayType ret; + ret.reserve(j.size()); + 
std::transform(j.begin(), j.end(), + std::inserter(ret, end(ret)), [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); + arr = std::move(ret); +} + +template::value, + int> = 0> +void from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, + priority_tag<0> /*unused*/) +{ + using std::end; + + ConstructibleArrayType ret; + std::transform( + j.begin(), j.end(), std::inserter(ret, end(ret)), + [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); + arr = std::move(ret); +} + +template < typename BasicJsonType, typename ConstructibleArrayType, + enable_if_t < + is_constructible_array_type::value&& + !is_constructible_object_type::value&& + !is_constructible_string_type::value&& + !std::is_same::value&& + !is_basic_json::value, + int > = 0 > +auto from_json(const BasicJsonType& j, ConstructibleArrayType& arr) +-> decltype(from_json_array_impl(j, arr, priority_tag<3> {}), +j.template get(), +void()) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + + from_json_array_impl(j, arr, priority_tag<3> {}); +} + +template < typename BasicJsonType, typename T, std::size_t... Idx > +std::array from_json_inplace_array_impl(BasicJsonType&& j, + identity_tag> /*unused*/, index_sequence /*unused*/) +{ + return { { std::forward(j).at(Idx).template get()... } }; +} + +template < typename BasicJsonType, typename T, std::size_t N > +auto from_json(BasicJsonType&& j, identity_tag> tag) +-> decltype(from_json_inplace_array_impl(std::forward(j), tag, make_index_sequence {})) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + + return from_json_inplace_array_impl(std::forward(j), tag, make_index_sequence {}); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::binary_t& bin) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_binary())) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(j.type_name()), j)); + } + + bin = *j.template get_ptr(); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, ConstructibleObjectType& obj) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_object())) + { + JSON_THROW(type_error::create(302, "type must be object, but is " + std::string(j.type_name()), j)); + } + + ConstructibleObjectType ret; + const auto* inner_object = j.template get_ptr(); + using value_type = typename ConstructibleObjectType::value_type; + std::transform( + inner_object->begin(), inner_object->end(), + std::inserter(ret, ret.begin()), + [](typename BasicJsonType::object_t::value_type const & p) + { + return value_type(p.first, p.second.template get()); + }); + obj = std::move(ret); +} + +// overload for arithmetic types, not chosen for basic_json template arguments +// (BooleanType, etc..); note: Is it really necessary to provide explicit +// overloads for boolean_t etc. in case of a custom BooleanType which is not +// an arithmetic type? 
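// ---------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of the upstream nlohmann/json
// header. Seen from the caller's side, it exercises the from_json array and
// object overloads defined above through the public get<T>() interface.
// Guarded with #if 0 so it has no effect on compilation of this header.
#if 0
#include <nlohmann/json.hpp>
#include <array>
#include <map>
#include <string>
#include <vector>

int main()
{
    const nlohmann::json arr = nlohmann::json::parse(R"([1, 2, 3])");

    // containers with reserve()/insert() go through from_json_array_impl
    std::vector<int> v = arr.get<std::vector<int>>();

    // std::array is filled element-wise via at(i).get<T>()
    std::array<int, 3> a = arr.get<std::array<int, 3>>();

    // JSON objects convert to map-like types whose key is a string type
    const nlohmann::json obj = nlohmann::json::parse(R"({"x": 1, "y": 2})");
    std::map<std::string, int> m = obj.get<std::map<std::string, int>>();

    return (v.size() == 3 && a[2] == 3 && m.at("y") == 2) ? 0 : 1;
}
#endif
// ---------------------------------------------------------------------------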
+template < typename BasicJsonType, typename ArithmeticType, + enable_if_t < + std::is_arithmetic::value&& + !std::is_same::value&& + !std::is_same::value&& + !std::is_same::value&& + !std::is_same::value, + int > = 0 > +void from_json(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::boolean: + { + val = static_cast(*j.template get_ptr()); + break; + } + + case value_t::null: + case value_t::object: + case value_t::array: + case value_t::string: + case value_t::binary: + case value_t::discarded: + default: + JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()), j)); + } +} + +template +std::tuple from_json_tuple_impl_base(BasicJsonType&& j, index_sequence /*unused*/) +{ + return std::make_tuple(std::forward(j).at(Idx).template get()...); +} + +template < typename BasicJsonType, class A1, class A2 > +std::pair from_json_tuple_impl(BasicJsonType&& j, identity_tag> /*unused*/, priority_tag<0> /*unused*/) +{ + return {std::forward(j).at(0).template get(), + std::forward(j).at(1).template get()}; +} + +template +void from_json_tuple_impl(BasicJsonType&& j, std::pair& p, priority_tag<1> /*unused*/) +{ + p = from_json_tuple_impl(std::forward(j), identity_tag> {}, priority_tag<0> {}); +} + +template +std::tuple from_json_tuple_impl(BasicJsonType&& j, identity_tag> /*unused*/, priority_tag<2> /*unused*/) +{ + return from_json_tuple_impl_base(std::forward(j), index_sequence_for {}); +} + +template +void from_json_tuple_impl(BasicJsonType&& j, std::tuple& t, priority_tag<3> /*unused*/) +{ + t = from_json_tuple_impl_base(std::forward(j), index_sequence_for {}); +} + +template +auto from_json(BasicJsonType&& j, TupleRelated&& t) +-> decltype(from_json_tuple_impl(std::forward(j), std::forward(t), priority_tag<3> {})) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + + return from_json_tuple_impl(std::forward(j), std::forward(t), priority_tag<3> {}); +} + +template < typename BasicJsonType, typename Key, typename Value, typename Compare, typename Allocator, + typename = enable_if_t < !std::is_constructible < + typename BasicJsonType::string_t, Key >::value >> +void from_json(const BasicJsonType& j, std::map& m) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + m.clear(); + for (const auto& p : j) + { + if (JSON_HEDLEY_UNLIKELY(!p.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()), j)); + } + m.emplace(p.at(0).template get(), p.at(1).template get()); + } +} + +template < typename BasicJsonType, typename Key, typename Value, typename Hash, typename KeyEqual, typename Allocator, + typename = enable_if_t < !std::is_constructible < + typename BasicJsonType::string_t, Key >::value >> +void from_json(const BasicJsonType& j, std::unordered_map& m) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()), j)); + } + m.clear(); + for (const auto& p : j) + { + if 
(JSON_HEDLEY_UNLIKELY(!p.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()), j)); + } + m.emplace(p.at(0).template get(), p.at(1).template get()); + } +} + +#ifdef JSON_HAS_CPP_17 +template +void from_json(const BasicJsonType& j, std::filesystem::path& p) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()), j)); + } + p = *j.template get_ptr(); +} +#endif + +struct from_json_fn +{ + template + auto operator()(const BasicJsonType& j, T&& val) const + noexcept(noexcept(from_json(j, std::forward(val)))) + -> decltype(from_json(j, std::forward(val))) + { + return from_json(j, std::forward(val)); + } +}; +} // namespace detail + +/// namespace to hold default `from_json` function +/// to see why this is required: +/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html +namespace // NOLINT(cert-dcl59-cpp,fuchsia-header-anon-namespaces,google-build-namespaces) +{ +constexpr const auto& from_json = detail::static_const::value; // NOLINT(misc-definitions-in-headers) +} // namespace +} // namespace nlohmann + +// #include + + +#include // copy +#include // begin, end +#include // string +#include // tuple, get +#include // is_same, is_constructible, is_floating_point, is_enum, underlying_type +#include // move, forward, declval, pair +#include // valarray +#include // vector + +// #include + +// #include + + +#include // size_t +#include // input_iterator_tag +#include // string, to_string +#include // tuple_size, get, tuple_element +#include // move + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +void int_to_string( string_type& target, std::size_t value ) +{ + // For ADL + using std::to_string; + target = to_string(value); +} +template class iteration_proxy_value +{ + public: + using difference_type = std::ptrdiff_t; + using value_type = iteration_proxy_value; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::input_iterator_tag; + using string_type = typename std::remove_cv< typename std::remove_reference().key() ) >::type >::type; + + private: + /// the iterator + IteratorType anchor; + /// an index for arrays (used to create key names) + std::size_t array_index = 0; + /// last stringified array index + mutable std::size_t array_index_last = 0; + /// a string representation of the array index + mutable string_type array_index_str = "0"; + /// an empty string (to return a reference for primitive values) + const string_type empty_str{}; + + public: + explicit iteration_proxy_value(IteratorType it) noexcept + : anchor(std::move(it)) + {} + + /// dereference operator (needed for range-based for) + iteration_proxy_value& operator*() + { + return *this; + } + + /// increment operator (needed for range-based for) + iteration_proxy_value& operator++() + { + ++anchor; + ++array_index; + + return *this; + } + + /// equality operator (needed for InputIterator) + bool operator==(const iteration_proxy_value& o) const + { + return anchor == o.anchor; + } + + /// inequality operator (needed for range-based for) + bool operator!=(const iteration_proxy_value& o) const + { + return anchor != o.anchor; + } + + /// return key of the iterator + const string_type& key() const + { + JSON_ASSERT(anchor.m_object != nullptr); + + switch (anchor.m_object->type()) + { + // use integer array index as key + case value_t::array: + { + if (array_index != 
array_index_last) + { + int_to_string( array_index_str, array_index ); + array_index_last = array_index; + } + return array_index_str; + } + + // use key from the object + case value_t::object: + return anchor.key(); + + // use an empty key for all primitive types + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + return empty_str; + } + } + + /// return value of the iterator + typename IteratorType::reference value() const + { + return anchor.value(); + } +}; + +/// proxy class for the items() function +template class iteration_proxy +{ + private: + /// the container to iterate + typename IteratorType::reference container; + + public: + /// construct iteration proxy from a container + explicit iteration_proxy(typename IteratorType::reference cont) noexcept + : container(cont) {} + + /// return iterator begin (needed for range-based for) + iteration_proxy_value begin() noexcept + { + return iteration_proxy_value(container.begin()); + } + + /// return iterator end (needed for range-based for) + iteration_proxy_value end() noexcept + { + return iteration_proxy_value(container.end()); + } +}; +// Structured Bindings Support +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +template = 0> +auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.key()) +{ + return i.key(); +} +// Structured Bindings Support +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +template = 0> +auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.value()) +{ + return i.value(); +} +} // namespace detail +} // namespace nlohmann + +// The Addition to the STD Namespace is required to add +// Structured Bindings Support to the iteration_proxy_value class +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +namespace std +{ +#if defined(__clang__) + // Fix: https://github.com/nlohmann/json/issues/1401 + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wmismatched-tags" +#endif +template +class tuple_size<::nlohmann::detail::iteration_proxy_value> + : public std::integral_constant {}; + +template +class tuple_element> +{ + public: + using type = decltype( + get(std::declval < + ::nlohmann::detail::iteration_proxy_value> ())); +}; +#if defined(__clang__) + #pragma clang diagnostic pop +#endif +} // namespace std + +// #include + +// #include + +// #include + + +#ifdef JSON_HAS_CPP_17 + #include +#endif + +namespace nlohmann +{ +namespace detail +{ +////////////////// +// constructors // +////////////////// + +/* + * Note all external_constructor<>::construct functions need to call + * j.m_value.destroy(j.m_type) to avoid a memory leak in case j contains an + * allocated value (e.g., a string). See bug issue + * https://github.com/nlohmann/json/issues/2865 for more information. 
+ */ + +template struct external_constructor; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::boolean_t b) noexcept + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::boolean; + j.m_value = b; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::string_t& s) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::string; + j.m_value = s; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::string_t&& s) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::string; + j.m_value = std::move(s); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleStringType, + enable_if_t < !std::is_same::value, + int > = 0 > + static void construct(BasicJsonType& j, const CompatibleStringType& str) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::string; + j.m_value.string = j.template create(str); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::binary_t& b) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::binary; + j.m_value = typename BasicJsonType::binary_t(b); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::binary_t&& b) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::binary; + j.m_value = typename BasicJsonType::binary_t(std::move(b)); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_float_t val) noexcept + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::number_float; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_unsigned_t val) noexcept + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::number_unsigned; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_integer_t val) noexcept + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::number_integer; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::array_t& arr) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::array; + j.m_value = arr; + j.set_parents(); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::array_t&& arr) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::array; + j.m_value = std::move(arr); + j.set_parents(); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleArrayType, + enable_if_t < !std::is_same::value, + int > = 0 > + static void construct(BasicJsonType& j, const CompatibleArrayType& arr) + { + using std::begin; + using std::end; + + j.m_value.destroy(j.m_type); + j.m_type = value_t::array; + j.m_value.array = j.template create(begin(arr), end(arr)); + j.set_parents(); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, const std::vector& arr) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::array; + j.m_value = 
value_t::array; + j.m_value.array->reserve(arr.size()); + for (const bool x : arr) + { + j.m_value.array->push_back(x); + j.set_parent(j.m_value.array->back()); + } + j.assert_invariant(); + } + + template::value, int> = 0> + static void construct(BasicJsonType& j, const std::valarray& arr) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::array; + j.m_value = value_t::array; + j.m_value.array->resize(arr.size()); + if (arr.size() > 0) + { + std::copy(std::begin(arr), std::end(arr), j.m_value.array->begin()); + } + j.set_parents(); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::object_t& obj) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::object; + j.m_value = obj; + j.set_parents(); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::object_t&& obj) + { + j.m_value.destroy(j.m_type); + j.m_type = value_t::object; + j.m_value = std::move(obj); + j.set_parents(); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleObjectType, + enable_if_t < !std::is_same::value, int > = 0 > + static void construct(BasicJsonType& j, const CompatibleObjectType& obj) + { + using std::begin; + using std::end; + + j.m_value.destroy(j.m_type); + j.m_type = value_t::object; + j.m_value.object = j.template create(begin(obj), end(obj)); + j.set_parents(); + j.assert_invariant(); + } +}; + +///////////// +// to_json // +///////////// + +template::value, int> = 0> +void to_json(BasicJsonType& j, T b) noexcept +{ + external_constructor::construct(j, b); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, const CompatibleString& s) +{ + external_constructor::construct(j, s); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::string_t&& s) +{ + external_constructor::construct(j, std::move(s)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, FloatType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, CompatibleNumberUnsignedType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, CompatibleNumberIntegerType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, EnumType e) noexcept +{ + using underlying_type = typename std::underlying_type::type; + external_constructor::construct(j, static_cast(e)); +} + +template +void to_json(BasicJsonType& j, const std::vector& e) +{ + external_constructor::construct(j, e); +} + +template < typename BasicJsonType, typename CompatibleArrayType, + enable_if_t < is_compatible_array_type::value&& + !is_compatible_object_type::value&& + !is_compatible_string_type::value&& + !std::is_same::value&& + !is_basic_json::value, + int > = 0 > +void to_json(BasicJsonType& j, const CompatibleArrayType& arr) +{ + external_constructor::construct(j, arr); +} + +template +void to_json(BasicJsonType& j, const typename BasicJsonType::binary_t& bin) +{ + external_constructor::construct(j, bin); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, const std::valarray& arr) +{ + external_constructor::construct(j, std::move(arr)); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::array_t&& arr) +{ + external_constructor::construct(j, std::move(arr)); +} + 
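// ---------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of the upstream nlohmann/json
// header. It shows how the to_json overloads above are reached implicitly when
// a json value is assigned from ordinary C++ types; the enum `color` is a
// hypothetical example type. Guarded with #if 0 so it has no effect on
// compilation of this header.
#if 0
#include <nlohmann/json.hpp>
#include <map>
#include <string>
#include <vector>

enum class color { red, green, blue };

int main()
{
    nlohmann::json j;
    j["flag"]   = true;                                   // boolean overload
    j["name"]   = std::string("example");                 // string overload
    j["pi"]     = 3.141;                                  // floating-point overload
    j["values"] = std::vector<int>{1, 2, 3};              // compatible array type
    j["table"]  = std::map<std::string, int>{{"a", 1}};   // compatible object type
    j["hue"]    = color::green;                           // enum -> underlying integer

    return j.size() == 6 ? 0 : 1;
}
#endif
// ---------------------------------------------------------------------------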
+template < typename BasicJsonType, typename CompatibleObjectType, + enable_if_t < is_compatible_object_type::value&& !is_basic_json::value, int > = 0 > +void to_json(BasicJsonType& j, const CompatibleObjectType& obj) +{ + external_constructor::construct(j, obj); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::object_t&& obj) +{ + external_constructor::construct(j, std::move(obj)); +} + +template < + typename BasicJsonType, typename T, std::size_t N, + enable_if_t < !std::is_constructible::value, // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) + int > = 0 > +void to_json(BasicJsonType& j, const T(&arr)[N]) // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) +{ + external_constructor::construct(j, arr); +} + +template < typename BasicJsonType, typename T1, typename T2, enable_if_t < std::is_constructible::value&& std::is_constructible::value, int > = 0 > +void to_json(BasicJsonType& j, const std::pair& p) +{ + j = { p.first, p.second }; +} + +// for https://github.com/nlohmann/json/pull/1134 +template>::value, int> = 0> +void to_json(BasicJsonType& j, const T& b) +{ + j = { {b.key(), b.value()} }; +} + +template +void to_json_tuple_impl(BasicJsonType& j, const Tuple& t, index_sequence /*unused*/) +{ + j = { std::get(t)... }; +} + +template::value, int > = 0> +void to_json(BasicJsonType& j, const T& t) +{ + to_json_tuple_impl(j, t, make_index_sequence::value> {}); +} + +#ifdef JSON_HAS_CPP_17 +template +void to_json(BasicJsonType& j, const std::filesystem::path& p) +{ + j = p.string(); +} +#endif + +struct to_json_fn +{ + template + auto operator()(BasicJsonType& j, T&& val) const noexcept(noexcept(to_json(j, std::forward(val)))) + -> decltype(to_json(j, std::forward(val)), void()) + { + return to_json(j, std::forward(val)); + } +}; +} // namespace detail + +/// namespace to hold default `to_json` function +/// to see why this is required: +/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html +namespace // NOLINT(cert-dcl59-cpp,fuchsia-header-anon-namespaces,google-build-namespaces) +{ +constexpr const auto& to_json = detail::static_const::value; // NOLINT(misc-definitions-in-headers) +} // namespace +} // namespace nlohmann + +// #include + +// #include + + +namespace nlohmann +{ + +template +struct adl_serializer +{ + /*! + @brief convert a JSON value to any value type + + This function is usually called by the `get()` function of the + @ref basic_json class (either explicit or via conversion operators). + + @note This function is chosen for default-constructible value types. + + @param[in] j JSON value to read from + @param[in,out] val value to write to + */ + template + static auto from_json(BasicJsonType && j, TargetType& val) noexcept( + noexcept(::nlohmann::from_json(std::forward(j), val))) + -> decltype(::nlohmann::from_json(std::forward(j), val), void()) + { + ::nlohmann::from_json(std::forward(j), val); + } + + /*! + @brief convert a JSON value to any value type + + This function is usually called by the `get()` function of the + @ref basic_json class (either explicit or via conversion operators). + + @note This function is chosen for value types which are not default-constructible. 
+ + @param[in] j JSON value to read from + + @return copy of the JSON value, converted to @a ValueType + */ + template + static auto from_json(BasicJsonType && j) noexcept( + noexcept(::nlohmann::from_json(std::forward(j), detail::identity_tag {}))) + -> decltype(::nlohmann::from_json(std::forward(j), detail::identity_tag {})) + { + return ::nlohmann::from_json(std::forward(j), detail::identity_tag {}); + } + + /*! + @brief convert any value type to a JSON value + + This function is usually called by the constructors of the @ref basic_json + class. + + @param[in,out] j JSON value to write to + @param[in] val value to read from + */ + template + static auto to_json(BasicJsonType& j, TargetType && val) noexcept( + noexcept(::nlohmann::to_json(j, std::forward(val)))) + -> decltype(::nlohmann::to_json(j, std::forward(val)), void()) + { + ::nlohmann::to_json(j, std::forward(val)); + } +}; +} // namespace nlohmann + +// #include + + +#include // uint8_t, uint64_t +#include // tie +#include // move + +namespace nlohmann +{ + +/*! +@brief an internal type for a backed binary type + +This type extends the template parameter @a BinaryType provided to `basic_json` +with a subtype used by BSON and MessagePack. This type exists so that the user +does not have to specify a type themselves with a specific naming scheme in +order to override the binary type. + +@tparam BinaryType container to store bytes (`std::vector` by + default) + +@since version 3.8.0; changed type of subtypes to std::uint64_t in 3.10.0. +*/ +template +class byte_container_with_subtype : public BinaryType +{ + public: + /// the type of the underlying container + using container_type = BinaryType; + /// the type of the subtype + using subtype_type = std::uint64_t; + + byte_container_with_subtype() noexcept(noexcept(container_type())) + : container_type() + {} + + byte_container_with_subtype(const container_type& b) noexcept(noexcept(container_type(b))) + : container_type(b) + {} + + byte_container_with_subtype(container_type&& b) noexcept(noexcept(container_type(std::move(b)))) + : container_type(std::move(b)) + {} + + byte_container_with_subtype(const container_type& b, subtype_type subtype_) noexcept(noexcept(container_type(b))) + : container_type(b) + , m_subtype(subtype_) + , m_has_subtype(true) + {} + + byte_container_with_subtype(container_type&& b, subtype_type subtype_) noexcept(noexcept(container_type(std::move(b)))) + : container_type(std::move(b)) + , m_subtype(subtype_) + , m_has_subtype(true) + {} + + bool operator==(const byte_container_with_subtype& rhs) const + { + return std::tie(static_cast(*this), m_subtype, m_has_subtype) == + std::tie(static_cast(rhs), rhs.m_subtype, rhs.m_has_subtype); + } + + bool operator!=(const byte_container_with_subtype& rhs) const + { + return !(rhs == *this); + } + + /*! + @brief sets the binary subtype + + Sets the binary subtype of the value, also flags a binary JSON value as + having a subtype, which has implications for serialization. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa see @ref subtype() -- return the binary subtype + @sa see @ref clear_subtype() -- clears the binary subtype + @sa see @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0 + */ + void set_subtype(subtype_type subtype_) noexcept + { + m_subtype = subtype_; + m_has_subtype = true; + } + + /*! 
+ @brief return the binary subtype + + Returns the numerical subtype of the value if it has a subtype. If it does + not have a subtype, this function will return subtype_type(-1) as a sentinel + value. + + @return the numerical subtype of the binary value + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa see @ref set_subtype() -- sets the binary subtype + @sa see @ref clear_subtype() -- clears the binary subtype + @sa see @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0; fixed return value to properly return + subtype_type(-1) as documented in version 3.10.0 + */ + constexpr subtype_type subtype() const noexcept + { + return m_has_subtype ? m_subtype : subtype_type(-1); + } + + /*! + @brief return whether the value has a subtype + + @return whether the value has a subtype + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa see @ref subtype() -- return the binary subtype + @sa see @ref set_subtype() -- sets the binary subtype + @sa see @ref clear_subtype() -- clears the binary subtype + + @since version 3.8.0 + */ + constexpr bool has_subtype() const noexcept + { + return m_has_subtype; + } + + /*! + @brief clears the binary subtype + + Clears the binary subtype and flags the value as not having a subtype, which + has implications for serialization; for instance MessagePack will prefer the + bin family over the ext family. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa see @ref subtype() -- return the binary subtype + @sa see @ref set_subtype() -- sets the binary subtype + @sa see @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0 + */ + void clear_subtype() noexcept + { + m_subtype = 0; + m_has_subtype = false; + } + + private: + subtype_type m_subtype = 0; + bool m_has_subtype = false; +}; + +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + + +#include // uint8_t +#include // size_t +#include // hash + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +// boost::hash_combine +inline std::size_t combine(std::size_t seed, std::size_t h) noexcept +{ + seed ^= h + 0x9e3779b9 + (seed << 6U) + (seed >> 2U); + return seed; +} + +/*! +@brief hash a JSON value + +The hash function tries to rely on std::hash where possible. Furthermore, the +type of the JSON value is taken into account to have different hash values for +null, 0, 0U, and false, etc. 
+ +@tparam BasicJsonType basic_json specialization +@param j JSON value to hash +@return hash value of j +*/ +template +std::size_t hash(const BasicJsonType& j) +{ + using string_t = typename BasicJsonType::string_t; + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + + const auto type = static_cast(j.type()); + switch (j.type()) + { + case BasicJsonType::value_t::null: + case BasicJsonType::value_t::discarded: + { + return combine(type, 0); + } + + case BasicJsonType::value_t::object: + { + auto seed = combine(type, j.size()); + for (const auto& element : j.items()) + { + const auto h = std::hash {}(element.key()); + seed = combine(seed, h); + seed = combine(seed, hash(element.value())); + } + return seed; + } + + case BasicJsonType::value_t::array: + { + auto seed = combine(type, j.size()); + for (const auto& element : j) + { + seed = combine(seed, hash(element)); + } + return seed; + } + + case BasicJsonType::value_t::string: + { + const auto h = std::hash {}(j.template get_ref()); + return combine(type, h); + } + + case BasicJsonType::value_t::boolean: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case BasicJsonType::value_t::number_integer: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case BasicJsonType::value_t::number_unsigned: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case BasicJsonType::value_t::number_float: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case BasicJsonType::value_t::binary: + { + auto seed = combine(type, j.get_binary().size()); + const auto h = std::hash {}(j.get_binary().has_subtype()); + seed = combine(seed, h); + seed = combine(seed, static_cast(j.get_binary().subtype())); + for (const auto byte : j.get_binary()) + { + seed = combine(seed, std::hash {}(byte)); + } + return seed; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + return 0; // LCOV_EXCL_LINE + } +} + +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // generate_n +#include // array +#include // ldexp +#include // size_t +#include // uint8_t, uint16_t, uint32_t, uint64_t +#include // snprintf +#include // memcpy +#include // back_inserter +#include // numeric_limits +#include // char_traits, string +#include // make_pair, move +#include // vector + +// #include + +// #include + + +#include // array +#include // size_t +#include // strlen +#include // begin, end, iterator_traits, random_access_iterator_tag, distance, next +#include // shared_ptr, make_shared, addressof +#include // accumulate +#include // string, char_traits +#include // enable_if, is_base_of, is_pointer, is_integral, remove_pointer +#include // pair, declval + +#ifndef JSON_NO_IO + #include // FILE * + #include // istream +#endif // JSON_NO_IO + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/// the supported input formats +enum class input_format_t { json, cbor, msgpack, ubjson, bson }; + +//////////////////// +// input adapters // +//////////////////// + +#ifndef JSON_NO_IO +/*! +Input adapter for stdio file access. This adapter read only 1 byte and do not use any + buffer. This adapter is a very low level adapter. 
+*/ +class file_input_adapter +{ + public: + using char_type = char; + + JSON_HEDLEY_NON_NULL(2) + explicit file_input_adapter(std::FILE* f) noexcept + : m_file(f) + {} + + // make class move-only + file_input_adapter(const file_input_adapter&) = delete; + file_input_adapter(file_input_adapter&&) noexcept = default; + file_input_adapter& operator=(const file_input_adapter&) = delete; + file_input_adapter& operator=(file_input_adapter&&) = delete; + ~file_input_adapter() = default; + + std::char_traits::int_type get_character() noexcept + { + return std::fgetc(m_file); + } + + private: + /// the file pointer to read from + std::FILE* m_file; +}; + + +/*! +Input adapter for a (caching) istream. Ignores a UFT Byte Order Mark at +beginning of input. Does not support changing the underlying std::streambuf +in mid-input. Maintains underlying std::istream and std::streambuf to support +subsequent use of standard std::istream operations to process any input +characters following those used in parsing the JSON input. Clears the +std::istream flags; any input errors (e.g., EOF) will be detected by the first +subsequent call for input from the std::istream. +*/ +class input_stream_adapter +{ + public: + using char_type = char; + + ~input_stream_adapter() + { + // clear stream flags; we use underlying streambuf I/O, do not + // maintain ifstream flags, except eof + if (is != nullptr) + { + is->clear(is->rdstate() & std::ios::eofbit); + } + } + + explicit input_stream_adapter(std::istream& i) + : is(&i), sb(i.rdbuf()) + {} + + // delete because of pointer members + input_stream_adapter(const input_stream_adapter&) = delete; + input_stream_adapter& operator=(input_stream_adapter&) = delete; + input_stream_adapter& operator=(input_stream_adapter&&) = delete; + + input_stream_adapter(input_stream_adapter&& rhs) noexcept + : is(rhs.is), sb(rhs.sb) + { + rhs.is = nullptr; + rhs.sb = nullptr; + } + + // std::istream/std::streambuf use std::char_traits::to_int_type, to + // ensure that std::char_traits::eof() and the character 0xFF do not + // end up as the same value, eg. 0xFFFFFFFF. + std::char_traits::int_type get_character() + { + auto res = sb->sbumpc(); + // set eof manually, as we don't use the istream interface. + if (JSON_HEDLEY_UNLIKELY(res == std::char_traits::eof())) + { + is->clear(is->rdstate() | std::ios::eofbit); + } + return res; + } + + private: + /// the associated input stream + std::istream* is = nullptr; + std::streambuf* sb = nullptr; +}; +#endif // JSON_NO_IO + +// General-purpose iterator-based adapter. It might not be as fast as +// theoretically possible for some containers, but it is extremely versatile. 
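// ---------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of the upstream nlohmann/json
// header. It shows the user-facing parse() entry points that are routed
// through the input adapters in this section (std::FILE*, std::istream,
// strings and iterator ranges). The file name "data.json" is hypothetical and
// parse() may throw on malformed input; the snippet is guarded with #if 0 so
// it has no effect on compilation of this header.
#if 0
#include <nlohmann/json.hpp>
#include <cstdio>
#include <fstream>
#include <string>

int main()
{
    using nlohmann::json;

    // contiguous characters (string literal) -> iterator_input_adapter
    json j1 = json::parse(R"({"answer": 42})");

    // std::istream -> input_stream_adapter
    std::ifstream is("data.json");
    json j2 = is.is_open() ? json::parse(is) : json();

    // std::FILE* -> file_input_adapter (reads one byte at a time, unbuffered)
    if (std::FILE* f = std::fopen("data.json", "r"))
    {
        json j3 = json::parse(f);
        std::fclose(f);
    }

    // explicit iterator range
    const std::string text = R"([1, 2, 3])";
    json j4 = json::parse(text.begin(), text.end());

    return j1["answer"] == 42 ? 0 : 1;
}
#endif
// ---------------------------------------------------------------------------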
+template +class iterator_input_adapter +{ + public: + using char_type = typename std::iterator_traits::value_type; + + iterator_input_adapter(IteratorType first, IteratorType last) + : current(std::move(first)), end(std::move(last)) + {} + + typename std::char_traits::int_type get_character() + { + if (JSON_HEDLEY_LIKELY(current != end)) + { + auto result = std::char_traits::to_int_type(*current); + std::advance(current, 1); + return result; + } + + return std::char_traits::eof(); + } + + private: + IteratorType current; + IteratorType end; + + template + friend struct wide_string_input_helper; + + bool empty() const + { + return current == end; + } +}; + + +template +struct wide_string_input_helper; + +template +struct wide_string_input_helper +{ + // UTF-32 + static void fill_buffer(BaseInputAdapter& input, + std::array::int_type, 4>& utf8_bytes, + size_t& utf8_bytes_index, + size_t& utf8_bytes_filled) + { + utf8_bytes_index = 0; + + if (JSON_HEDLEY_UNLIKELY(input.empty())) + { + utf8_bytes[0] = std::char_traits::eof(); + utf8_bytes_filled = 1; + } + else + { + // get the current character + const auto wc = input.get_character(); + + // UTF-32 to UTF-8 encoding + if (wc < 0x80) + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + else if (wc <= 0x7FF) + { + utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u) & 0x1Fu)); + utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 2; + } + else if (wc <= 0xFFFF) + { + utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u) & 0x0Fu)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 3; + } + else if (wc <= 0x10FFFF) + { + utf8_bytes[0] = static_cast::int_type>(0xF0u | ((static_cast(wc) >> 18u) & 0x07u)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 12u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[3] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 4; + } + else + { + // unknown character + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + } + } +}; + +template +struct wide_string_input_helper +{ + // UTF-16 + static void fill_buffer(BaseInputAdapter& input, + std::array::int_type, 4>& utf8_bytes, + size_t& utf8_bytes_index, + size_t& utf8_bytes_filled) + { + utf8_bytes_index = 0; + + if (JSON_HEDLEY_UNLIKELY(input.empty())) + { + utf8_bytes[0] = std::char_traits::eof(); + utf8_bytes_filled = 1; + } + else + { + // get the current character + const auto wc = input.get_character(); + + // UTF-16 to UTF-8 encoding + if (wc < 0x80) + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + else if (wc <= 0x7FF) + { + utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u))); + utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 2; + } + else if (0xD800 > wc || wc >= 0xE000) + { + utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u))); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 3; + } + else + { + if (JSON_HEDLEY_UNLIKELY(!input.empty())) + { + const auto wc2 = static_cast(input.get_character()); + const auto charcode = 
0x10000u + (((static_cast(wc) & 0x3FFu) << 10u) | (wc2 & 0x3FFu)); + utf8_bytes[0] = static_cast::int_type>(0xF0u | (charcode >> 18u)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((charcode >> 12u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | ((charcode >> 6u) & 0x3Fu)); + utf8_bytes[3] = static_cast::int_type>(0x80u | (charcode & 0x3Fu)); + utf8_bytes_filled = 4; + } + else + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + } + } + } +}; + +// Wraps another input apdater to convert wide character types into individual bytes. +template +class wide_string_input_adapter +{ + public: + using char_type = char; + + wide_string_input_adapter(BaseInputAdapter base) + : base_adapter(base) {} + + typename std::char_traits::int_type get_character() noexcept + { + // check if buffer needs to be filled + if (utf8_bytes_index == utf8_bytes_filled) + { + fill_buffer(); + + JSON_ASSERT(utf8_bytes_filled > 0); + JSON_ASSERT(utf8_bytes_index == 0); + } + + // use buffer + JSON_ASSERT(utf8_bytes_filled > 0); + JSON_ASSERT(utf8_bytes_index < utf8_bytes_filled); + return utf8_bytes[utf8_bytes_index++]; + } + + private: + BaseInputAdapter base_adapter; + + template + void fill_buffer() + { + wide_string_input_helper::fill_buffer(base_adapter, utf8_bytes, utf8_bytes_index, utf8_bytes_filled); + } + + /// a buffer for UTF-8 bytes + std::array::int_type, 4> utf8_bytes = {{0, 0, 0, 0}}; + + /// index to the utf8_codes array for the next valid byte + std::size_t utf8_bytes_index = 0; + /// number of valid bytes in the utf8_codes array + std::size_t utf8_bytes_filled = 0; +}; + + +template +struct iterator_input_adapter_factory +{ + using iterator_type = IteratorType; + using char_type = typename std::iterator_traits::value_type; + using adapter_type = iterator_input_adapter; + + static adapter_type create(IteratorType first, IteratorType last) + { + return adapter_type(std::move(first), std::move(last)); + } +}; + +template +struct is_iterator_of_multibyte +{ + using value_type = typename std::iterator_traits::value_type; + enum + { + value = sizeof(value_type) > 1 + }; +}; + +template +struct iterator_input_adapter_factory::value>> +{ + using iterator_type = IteratorType; + using char_type = typename std::iterator_traits::value_type; + using base_adapter_type = iterator_input_adapter; + using adapter_type = wide_string_input_adapter; + + static adapter_type create(IteratorType first, IteratorType last) + { + return adapter_type(base_adapter_type(std::move(first), std::move(last))); + } +}; + +// General purpose iterator-based input +template +typename iterator_input_adapter_factory::adapter_type input_adapter(IteratorType first, IteratorType last) +{ + using factory_type = iterator_input_adapter_factory; + return factory_type::create(first, last); +} + +// Convenience shorthand from container to iterator +// Enables ADL on begin(container) and end(container) +// Encloses the using declarations in namespace for not to leak them to outside scope + +namespace container_input_adapter_factory_impl +{ + +using std::begin; +using std::end; + +template +struct container_input_adapter_factory {}; + +template +struct container_input_adapter_factory< ContainerType, + void_t()), end(std::declval()))>> + { + using adapter_type = decltype(input_adapter(begin(std::declval()), end(std::declval()))); + + static adapter_type create(const ContainerType& container) +{ + return input_adapter(begin(container), end(container)); +} + }; + +} // namespace 
container_input_adapter_factory_impl + +template +typename container_input_adapter_factory_impl::container_input_adapter_factory::adapter_type input_adapter(const ContainerType& container) +{ + return container_input_adapter_factory_impl::container_input_adapter_factory::create(container); +} + +#ifndef JSON_NO_IO +// Special cases with fast paths +inline file_input_adapter input_adapter(std::FILE* file) +{ + return file_input_adapter(file); +} + +inline input_stream_adapter input_adapter(std::istream& stream) +{ + return input_stream_adapter(stream); +} + +inline input_stream_adapter input_adapter(std::istream&& stream) +{ + return input_stream_adapter(stream); +} +#endif // JSON_NO_IO + +using contiguous_bytes_input_adapter = decltype(input_adapter(std::declval(), std::declval())); + +// Null-delimited strings, and the like. +template < typename CharT, + typename std::enable_if < + std::is_pointer::value&& + !std::is_array::value&& + std::is_integral::type>::value&& + sizeof(typename std::remove_pointer::type) == 1, + int >::type = 0 > +contiguous_bytes_input_adapter input_adapter(CharT b) +{ + auto length = std::strlen(reinterpret_cast(b)); + const auto* ptr = reinterpret_cast(b); + return input_adapter(ptr, ptr + length); +} + +template +auto input_adapter(T (&array)[N]) -> decltype(input_adapter(array, array + N)) // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) +{ + return input_adapter(array, array + N); +} + +// This class only handles inputs of input_buffer_adapter type. +// It's required so that expressions like {ptr, len} can be implicitely casted +// to the correct adapter. +class span_input_adapter +{ + public: + template < typename CharT, + typename std::enable_if < + std::is_pointer::value&& + std::is_integral::type>::value&& + sizeof(typename std::remove_pointer::type) == 1, + int >::type = 0 > + span_input_adapter(CharT b, std::size_t l) + : ia(reinterpret_cast(b), reinterpret_cast(b) + l) {} + + template::iterator_category, std::random_access_iterator_tag>::value, + int>::type = 0> + span_input_adapter(IteratorType first, IteratorType last) + : ia(input_adapter(first, last)) {} + + contiguous_bytes_input_adapter&& get() + { + return std::move(ia); // NOLINT(hicpp-move-const-arg,performance-move-const-arg) + } + + private: + contiguous_bytes_input_adapter ia; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include +#include // string +#include // move +#include // vector + +// #include + +// #include + + +namespace nlohmann +{ + +/*! +@brief SAX interface + +This class describes the SAX interface used by @ref nlohmann::json::sax_parse. +Each function is called in different situations while the input is parsed. The +boolean return value informs the parser whether to continue processing the +input. +*/ +template +struct json_sax +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + /*! + @brief a null value was read + @return whether parsing should proceed + */ + virtual bool null() = 0; + + /*! + @brief a boolean value was read + @param[in] val boolean value + @return whether parsing should proceed + */ + virtual bool boolean(bool val) = 0; + + /*! 
+ @brief an integer number was read + @param[in] val integer value + @return whether parsing should proceed + */ + virtual bool number_integer(number_integer_t val) = 0; + + /*! + @brief an unsigned integer number was read + @param[in] val unsigned integer value + @return whether parsing should proceed + */ + virtual bool number_unsigned(number_unsigned_t val) = 0; + + /*! + @brief an floating-point number was read + @param[in] val floating-point value + @param[in] s raw token value + @return whether parsing should proceed + */ + virtual bool number_float(number_float_t val, const string_t& s) = 0; + + /*! + @brief a string was read + @param[in] val string value + @return whether parsing should proceed + @note It is safe to move the passed string. + */ + virtual bool string(string_t& val) = 0; + + /*! + @brief a binary string was read + @param[in] val binary value + @return whether parsing should proceed + @note It is safe to move the passed binary. + */ + virtual bool binary(binary_t& val) = 0; + + /*! + @brief the beginning of an object was read + @param[in] elements number of object elements or -1 if unknown + @return whether parsing should proceed + @note binary formats may report the number of elements + */ + virtual bool start_object(std::size_t elements) = 0; + + /*! + @brief an object key was read + @param[in] val object key + @return whether parsing should proceed + @note It is safe to move the passed string. + */ + virtual bool key(string_t& val) = 0; + + /*! + @brief the end of an object was read + @return whether parsing should proceed + */ + virtual bool end_object() = 0; + + /*! + @brief the beginning of an array was read + @param[in] elements number of array elements or -1 if unknown + @return whether parsing should proceed + @note binary formats may report the number of elements + */ + virtual bool start_array(std::size_t elements) = 0; + + /*! + @brief the end of an array was read + @return whether parsing should proceed + */ + virtual bool end_array() = 0; + + /*! + @brief a parse error occurred + @param[in] position the position in the input where the error occurs + @param[in] last_token the last read token + @param[in] ex an exception object describing the error + @return whether parsing should proceed (must return false) + */ + virtual bool parse_error(std::size_t position, + const std::string& last_token, + const detail::exception& ex) = 0; + + json_sax() = default; + json_sax(const json_sax&) = default; + json_sax(json_sax&&) noexcept = default; + json_sax& operator=(const json_sax&) = default; + json_sax& operator=(json_sax&&) noexcept = default; + virtual ~json_sax() = default; +}; + + +namespace detail +{ +/*! +@brief SAX implementation to create a JSON value from SAX events + +This class implements the @ref json_sax interface and processes the SAX events +to create a JSON value which makes it basically a DOM parser. The structure or +hierarchy of the JSON value is managed by the stack `ref_stack` which contains +a pointer to the respective array or object for each recursion depth. + +After successful parsing, the value that is passed by reference to the +constructor contains the parsed value. 
+ +@tparam BasicJsonType the JSON type +*/ +template +class json_sax_dom_parser +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + /*! + @param[in,out] r reference to a JSON value that is manipulated while + parsing + @param[in] allow_exceptions_ whether parse errors yield exceptions + */ + explicit json_sax_dom_parser(BasicJsonType& r, const bool allow_exceptions_ = true) + : root(r), allow_exceptions(allow_exceptions_) + {} + + // make class move-only + json_sax_dom_parser(const json_sax_dom_parser&) = delete; + json_sax_dom_parser(json_sax_dom_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + json_sax_dom_parser& operator=(const json_sax_dom_parser&) = delete; + json_sax_dom_parser& operator=(json_sax_dom_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + ~json_sax_dom_parser() = default; + + bool null() + { + handle_value(nullptr); + return true; + } + + bool boolean(bool val) + { + handle_value(val); + return true; + } + + bool number_integer(number_integer_t val) + { + handle_value(val); + return true; + } + + bool number_unsigned(number_unsigned_t val) + { + handle_value(val); + return true; + } + + bool number_float(number_float_t val, const string_t& /*unused*/) + { + handle_value(val); + return true; + } + + bool string(string_t& val) + { + handle_value(val); + return true; + } + + bool binary(binary_t& val) + { + handle_value(std::move(val)); + return true; + } + + bool start_object(std::size_t len) + { + ref_stack.push_back(handle_value(BasicJsonType::value_t::object)); + + if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive object size: " + std::to_string(len), *ref_stack.back())); + } + + return true; + } + + bool key(string_t& val) + { + // add null at given key and store the reference for later + object_element = &(ref_stack.back()->m_value.object->operator[](val)); + return true; + } + + bool end_object() + { + ref_stack.back()->set_parents(); + ref_stack.pop_back(); + return true; + } + + bool start_array(std::size_t len) + { + ref_stack.push_back(handle_value(BasicJsonType::value_t::array)); + + if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive array size: " + std::to_string(len), *ref_stack.back())); + } + + return true; + } + + bool end_array() + { + ref_stack.back()->set_parents(); + ref_stack.pop_back(); + return true; + } + + template + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, + const Exception& ex) + { + errored = true; + static_cast(ex); + if (allow_exceptions) + { + JSON_THROW(ex); + } + return false; + } + + constexpr bool is_errored() const + { + return errored; + } + + private: + /*! + @invariant If the ref stack is empty, then the passed value will be the new + root. 
+ @invariant If the ref stack contains a value, then it is an array or an + object to which we can add elements + */ + template + JSON_HEDLEY_RETURNS_NON_NULL + BasicJsonType* handle_value(Value&& v) + { + if (ref_stack.empty()) + { + root = BasicJsonType(std::forward(v)); + return &root; + } + + JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); + + if (ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->emplace_back(std::forward(v)); + return &(ref_stack.back()->m_value.array->back()); + } + + JSON_ASSERT(ref_stack.back()->is_object()); + JSON_ASSERT(object_element); + *object_element = BasicJsonType(std::forward(v)); + return object_element; + } + + /// the parsed JSON value + BasicJsonType& root; + /// stack to model hierarchy of values + std::vector ref_stack {}; + /// helper to hold the reference for the next object element + BasicJsonType* object_element = nullptr; + /// whether a syntax error occurred + bool errored = false; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; +}; + +template +class json_sax_dom_callback_parser +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using parser_callback_t = typename BasicJsonType::parser_callback_t; + using parse_event_t = typename BasicJsonType::parse_event_t; + + json_sax_dom_callback_parser(BasicJsonType& r, + const parser_callback_t cb, + const bool allow_exceptions_ = true) + : root(r), callback(cb), allow_exceptions(allow_exceptions_) + { + keep_stack.push_back(true); + } + + // make class move-only + json_sax_dom_callback_parser(const json_sax_dom_callback_parser&) = delete; + json_sax_dom_callback_parser(json_sax_dom_callback_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + json_sax_dom_callback_parser& operator=(const json_sax_dom_callback_parser&) = delete; + json_sax_dom_callback_parser& operator=(json_sax_dom_callback_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + ~json_sax_dom_callback_parser() = default; + + bool null() + { + handle_value(nullptr); + return true; + } + + bool boolean(bool val) + { + handle_value(val); + return true; + } + + bool number_integer(number_integer_t val) + { + handle_value(val); + return true; + } + + bool number_unsigned(number_unsigned_t val) + { + handle_value(val); + return true; + } + + bool number_float(number_float_t val, const string_t& /*unused*/) + { + handle_value(val); + return true; + } + + bool string(string_t& val) + { + handle_value(val); + return true; + } + + bool binary(binary_t& val) + { + handle_value(std::move(val)); + return true; + } + + bool start_object(std::size_t len) + { + // check callback for object start + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::object_start, discarded); + keep_stack.push_back(keep); + + auto val = handle_value(BasicJsonType::value_t::object, true); + ref_stack.push_back(val.second); + + // check object limit + if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive object size: " + std::to_string(len), *ref_stack.back())); + } + + return true; + } + + bool key(string_t& val) 
+ { + BasicJsonType k = BasicJsonType(val); + + // check callback for key + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::key, k); + key_keep_stack.push_back(keep); + + // add discarded value at given key and store the reference for later + if (keep && ref_stack.back()) + { + object_element = &(ref_stack.back()->m_value.object->operator[](val) = discarded); + } + + return true; + } + + bool end_object() + { + if (ref_stack.back()) + { + if (!callback(static_cast(ref_stack.size()) - 1, parse_event_t::object_end, *ref_stack.back())) + { + // discard object + *ref_stack.back() = discarded; + } + else + { + ref_stack.back()->set_parents(); + } + } + + JSON_ASSERT(!ref_stack.empty()); + JSON_ASSERT(!keep_stack.empty()); + ref_stack.pop_back(); + keep_stack.pop_back(); + + if (!ref_stack.empty() && ref_stack.back() && ref_stack.back()->is_structured()) + { + // remove discarded value + for (auto it = ref_stack.back()->begin(); it != ref_stack.back()->end(); ++it) + { + if (it->is_discarded()) + { + ref_stack.back()->erase(it); + break; + } + } + } + + return true; + } + + bool start_array(std::size_t len) + { + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::array_start, discarded); + keep_stack.push_back(keep); + + auto val = handle_value(BasicJsonType::value_t::array, true); + ref_stack.push_back(val.second); + + // check array limit + if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive array size: " + std::to_string(len), *ref_stack.back())); + } + + return true; + } + + bool end_array() + { + bool keep = true; + + if (ref_stack.back()) + { + keep = callback(static_cast(ref_stack.size()) - 1, parse_event_t::array_end, *ref_stack.back()); + if (keep) + { + ref_stack.back()->set_parents(); + } + else + { + // discard array + *ref_stack.back() = discarded; + } + } + + JSON_ASSERT(!ref_stack.empty()); + JSON_ASSERT(!keep_stack.empty()); + ref_stack.pop_back(); + keep_stack.pop_back(); + + // remove discarded value + if (!keep && !ref_stack.empty() && ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->pop_back(); + } + + return true; + } + + template + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, + const Exception& ex) + { + errored = true; + static_cast(ex); + if (allow_exceptions) + { + JSON_THROW(ex); + } + return false; + } + + constexpr bool is_errored() const + { + return errored; + } + + private: + /*! + @param[in] v value to add to the JSON value we build during parsing + @param[in] skip_callback whether we should skip calling the callback + function; this is required after start_array() and + start_object() SAX events, because otherwise we would call the + callback function with an empty array or object, respectively. + + @invariant If the ref stack is empty, then the passed value will be the new + root. 
+ @invariant If the ref stack contains a value, then it is an array or an + object to which we can add elements + + @return pair of boolean (whether value should be kept) and pointer (to the + passed value in the ref_stack hierarchy; nullptr if not kept) + */ + template + std::pair handle_value(Value&& v, const bool skip_callback = false) + { + JSON_ASSERT(!keep_stack.empty()); + + // do not handle this value if we know it would be added to a discarded + // container + if (!keep_stack.back()) + { + return {false, nullptr}; + } + + // create value + auto value = BasicJsonType(std::forward(v)); + + // check callback + const bool keep = skip_callback || callback(static_cast(ref_stack.size()), parse_event_t::value, value); + + // do not handle this value if we just learnt it shall be discarded + if (!keep) + { + return {false, nullptr}; + } + + if (ref_stack.empty()) + { + root = std::move(value); + return {true, &root}; + } + + // skip this value if we already decided to skip the parent + // (https://github.com/nlohmann/json/issues/971#issuecomment-413678360) + if (!ref_stack.back()) + { + return {false, nullptr}; + } + + // we now only expect arrays and objects + JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); + + // array + if (ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->emplace_back(std::move(value)); + return {true, &(ref_stack.back()->m_value.array->back())}; + } + + // object + JSON_ASSERT(ref_stack.back()->is_object()); + // check if we should store an element for the current key + JSON_ASSERT(!key_keep_stack.empty()); + const bool store_element = key_keep_stack.back(); + key_keep_stack.pop_back(); + + if (!store_element) + { + return {false, nullptr}; + } + + JSON_ASSERT(object_element); + *object_element = std::move(value); + return {true, object_element}; + } + + /// the parsed JSON value + BasicJsonType& root; + /// stack to model hierarchy of values + std::vector ref_stack {}; + /// stack to manage which values to keep + std::vector keep_stack {}; + /// stack to manage which object keys to keep + std::vector key_keep_stack {}; + /// helper to hold the reference for the next object element + BasicJsonType* object_element = nullptr; + /// whether a syntax error occurred + bool errored = false; + /// callback function + const parser_callback_t callback = nullptr; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; + /// a discarded value for the callback + BasicJsonType discarded = BasicJsonType::value_t::discarded; +}; + +template +class json_sax_acceptor +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + bool null() + { + return true; + } + + bool boolean(bool /*unused*/) + { + return true; + } + + bool number_integer(number_integer_t /*unused*/) + { + return true; + } + + bool number_unsigned(number_unsigned_t /*unused*/) + { + return true; + } + + bool number_float(number_float_t /*unused*/, const string_t& /*unused*/) + { + return true; + } + + bool string(string_t& /*unused*/) + { + return true; + } + + bool binary(binary_t& /*unused*/) + { + return true; + } + + bool start_object(std::size_t /*unused*/ = std::size_t(-1)) + { + return true; + } + + bool key(string_t& /*unused*/) + { + return true; + } + + bool 
end_object() + { + return true; + } + + bool start_array(std::size_t /*unused*/ = std::size_t(-1)) + { + return true; + } + + bool end_array() + { + return true; + } + + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, const detail::exception& /*unused*/) + { + return false; + } +}; +} // namespace detail + +} // namespace nlohmann + +// #include + + +#include // array +#include // localeconv +#include // size_t +#include // snprintf +#include // strtof, strtod, strtold, strtoll, strtoull +#include // initializer_list +#include // char_traits, string +#include // move +#include // vector + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/////////// +// lexer // +/////////// + +template +class lexer_base +{ + public: + /// token types for the parser + enum class token_type + { + uninitialized, ///< indicating the scanner is uninitialized + literal_true, ///< the `true` literal + literal_false, ///< the `false` literal + literal_null, ///< the `null` literal + value_string, ///< a string -- use get_string() for actual value + value_unsigned, ///< an unsigned integer -- use get_number_unsigned() for actual value + value_integer, ///< a signed integer -- use get_number_integer() for actual value + value_float, ///< an floating point number -- use get_number_float() for actual value + begin_array, ///< the character for array begin `[` + begin_object, ///< the character for object begin `{` + end_array, ///< the character for array end `]` + end_object, ///< the character for object end `}` + name_separator, ///< the name separator `:` + value_separator, ///< the value separator `,` + parse_error, ///< indicating a parse error + end_of_input, ///< indicating the end of the input buffer + literal_or_value ///< a literal or the begin of a value (only for diagnostics) + }; + + /// return name of values of type token_type (only used for errors) + JSON_HEDLEY_RETURNS_NON_NULL + JSON_HEDLEY_CONST + static const char* token_type_name(const token_type t) noexcept + { + switch (t) + { + case token_type::uninitialized: + return ""; + case token_type::literal_true: + return "true literal"; + case token_type::literal_false: + return "false literal"; + case token_type::literal_null: + return "null literal"; + case token_type::value_string: + return "string literal"; + case token_type::value_unsigned: + case token_type::value_integer: + case token_type::value_float: + return "number literal"; + case token_type::begin_array: + return "'['"; + case token_type::begin_object: + return "'{'"; + case token_type::end_array: + return "']'"; + case token_type::end_object: + return "'}'"; + case token_type::name_separator: + return "':'"; + case token_type::value_separator: + return "','"; + case token_type::parse_error: + return ""; + case token_type::end_of_input: + return "end of input"; + case token_type::literal_or_value: + return "'[', '{', or a literal"; + // LCOV_EXCL_START + default: // catch non-enum values + return "unknown token"; + // LCOV_EXCL_STOP + } + } +}; +/*! +@brief lexical analysis + +This class organizes the lexical analysis during JSON deserialization. 
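+
+Example (an illustrative sketch): for the input `{"a":1}` successive scan()
+calls yield begin_object, value_string, name_separator, value_unsigned,
+end_object, end_of_input. The names used in diagnostics come from
+lexer_base::token_type_name:
+
+@code{.cpp}
+#include <cassert>
+#include <cstring>
+#include <nlohmann/json.hpp>
+
+int main()
+{
+    using lexer_t = nlohmann::detail::lexer_base<nlohmann::json>;
+    assert(std::strcmp(lexer_t::token_type_name(lexer_t::token_type::begin_object), "'{'") == 0);
+    assert(std::strcmp(lexer_t::token_type_name(lexer_t::token_type::end_of_input), "end of input") == 0);
+}
+@endcode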
+*/ +template +class lexer : public lexer_base +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using char_type = typename InputAdapterType::char_type; + using char_int_type = typename std::char_traits::int_type; + + public: + using token_type = typename lexer_base::token_type; + + explicit lexer(InputAdapterType&& adapter, bool ignore_comments_ = false) noexcept + : ia(std::move(adapter)) + , ignore_comments(ignore_comments_) + , decimal_point_char(static_cast(get_decimal_point())) + {} + + // delete because of pointer members + lexer(const lexer&) = delete; + lexer(lexer&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + lexer& operator=(lexer&) = delete; + lexer& operator=(lexer&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + ~lexer() = default; + + private: + ///////////////////// + // locales + ///////////////////// + + /// return the locale-dependent decimal point + JSON_HEDLEY_PURE + static char get_decimal_point() noexcept + { + const auto* loc = localeconv(); + JSON_ASSERT(loc != nullptr); + return (loc->decimal_point == nullptr) ? '.' : *(loc->decimal_point); + } + + ///////////////////// + // scan functions + ///////////////////// + + /*! + @brief get codepoint from 4 hex characters following `\u` + + For input "\u c1 c2 c3 c4" the codepoint is: + (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4 + = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0) + + Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f' + must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The + conversion is done by subtracting the offset (0x30, 0x37, and 0x57) + between the ASCII value of the character and the desired integer value. + + @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. EOF or + non-hex character) + */ + int get_codepoint() + { + // this function only makes sense after reading `\u` + JSON_ASSERT(current == 'u'); + int codepoint = 0; + + const auto factors = { 12u, 8u, 4u, 0u }; + for (const auto factor : factors) + { + get(); + + if (current >= '0' && current <= '9') + { + codepoint += static_cast((static_cast(current) - 0x30u) << factor); + } + else if (current >= 'A' && current <= 'F') + { + codepoint += static_cast((static_cast(current) - 0x37u) << factor); + } + else if (current >= 'a' && current <= 'f') + { + codepoint += static_cast((static_cast(current) - 0x57u) << factor); + } + else + { + return -1; + } + } + + JSON_ASSERT(0x0000 <= codepoint && codepoint <= 0xFFFF); + return codepoint; + } + + /*! + @brief check if the next byte(s) are inside a given range + + Adds the current byte and, for each passed range, reads a new byte and + checks if it is inside the range. If a violation was detected, set up an + error message and return false. Otherwise, return true. + + @param[in] ranges list of integers; interpreted as list of pairs of + inclusive lower and upper bound, respectively + + @pre The passed list @a ranges must have 2, 4, or 6 elements; that is, + 1, 2, or 3 pairs. This precondition is enforced by an assertion. 
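+
+    Example (an illustrative sketch of the observable behaviour): a valid
+    surrogate pair in a \u escape is combined into one code point, while an
+    ill-formed UTF-8 byte inside a string is rejected by these range checks:
+
+    @code{.cpp}
+    #include <cassert>
+    #include <string>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        // "\uD83D\uDE00" combines to U+1F600, stored as UTF-8 (F0 9F 98 80)
+        auto j = nlohmann::json::parse(R"("\uD83D\uDE00")");
+        assert(j.get<std::string>() == "\xF0\x9F\x98\x80");
+
+        // a stray 0xFF byte is not valid UTF-8 and makes accept() return false
+        assert(!nlohmann::json::accept("\"\xFF\""));
+    }
+    @endcode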
+ + @return true if and only if no range violation was detected + */ + bool next_byte_in_range(std::initializer_list ranges) + { + JSON_ASSERT(ranges.size() == 2 || ranges.size() == 4 || ranges.size() == 6); + add(current); + + for (auto range = ranges.begin(); range != ranges.end(); ++range) + { + get(); + if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) + { + add(current); + } + else + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return false; + } + } + + return true; + } + + /*! + @brief scan a string literal + + This function scans a string according to Sect. 7 of RFC 8259. While + scanning, bytes are escaped and copied into buffer token_buffer. Then the + function returns successfully, token_buffer is *not* null-terminated (as it + may contain \0 bytes), and token_buffer.size() is the number of bytes in the + string. + + @return token_type::value_string if string could be successfully scanned, + token_type::parse_error otherwise + + @note In case of errors, variable error_message contains a textual + description. + */ + token_type scan_string() + { + // reset token_buffer (ignore opening quote) + reset(); + + // we entered the function by reading an open quote + JSON_ASSERT(current == '\"'); + + while (true) + { + // get next character + switch (get()) + { + // end of file while parsing string + case std::char_traits::eof(): + { + error_message = "invalid string: missing closing quote"; + return token_type::parse_error; + } + + // closing quote + case '\"': + { + return token_type::value_string; + } + + // escapes + case '\\': + { + switch (get()) + { + // quotation mark + case '\"': + add('\"'); + break; + // reverse solidus + case '\\': + add('\\'); + break; + // solidus + case '/': + add('/'); + break; + // backspace + case 'b': + add('\b'); + break; + // form feed + case 'f': + add('\f'); + break; + // line feed + case 'n': + add('\n'); + break; + // carriage return + case 'r': + add('\r'); + break; + // tab + case 't': + add('\t'); + break; + + // unicode escapes + case 'u': + { + const int codepoint1 = get_codepoint(); + int codepoint = codepoint1; // start with codepoint1 + + if (JSON_HEDLEY_UNLIKELY(codepoint1 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if code point is a high surrogate + if (0xD800 <= codepoint1 && codepoint1 <= 0xDBFF) + { + // expect next \uxxxx entry + if (JSON_HEDLEY_LIKELY(get() == '\\' && get() == 'u')) + { + const int codepoint2 = get_codepoint(); + + if (JSON_HEDLEY_UNLIKELY(codepoint2 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if codepoint2 is a low surrogate + if (JSON_HEDLEY_LIKELY(0xDC00 <= codepoint2 && codepoint2 <= 0xDFFF)) + { + // overwrite codepoint + codepoint = static_cast( + // high surrogate occupies the most significant 22 bits + (static_cast(codepoint1) << 10u) + // low surrogate occupies the least significant 15 bits + + static_cast(codepoint2) + // there is still the 0xD800, 0xDC00 and 0x10000 noise + // in the result so we have to subtract with: + // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00 + - 0x35FDC00u); + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } 
+ else + { + if (JSON_HEDLEY_UNLIKELY(0xDC00 <= codepoint1 && codepoint1 <= 0xDFFF)) + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF"; + return token_type::parse_error; + } + } + + // result of the above calculation yields a proper codepoint + JSON_ASSERT(0x00 <= codepoint && codepoint <= 0x10FFFF); + + // translate codepoint into bytes + if (codepoint < 0x80) + { + // 1-byte characters: 0xxxxxxx (ASCII) + add(static_cast(codepoint)); + } + else if (codepoint <= 0x7FF) + { + // 2-byte characters: 110xxxxx 10xxxxxx + add(static_cast(0xC0u | (static_cast(codepoint) >> 6u))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else if (codepoint <= 0xFFFF) + { + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + add(static_cast(0xE0u | (static_cast(codepoint) >> 12u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else + { + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + add(static_cast(0xF0u | (static_cast(codepoint) >> 18u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 12u) & 0x3Fu))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + + break; + } + + // other characters after escape + default: + error_message = "invalid string: forbidden character after backslash"; + return token_type::parse_error; + } + + break; + } + + // invalid control characters + case 0x00: + { + error_message = "invalid string: control character U+0000 (NUL) must be escaped to \\u0000"; + return token_type::parse_error; + } + + case 0x01: + { + error_message = "invalid string: control character U+0001 (SOH) must be escaped to \\u0001"; + return token_type::parse_error; + } + + case 0x02: + { + error_message = "invalid string: control character U+0002 (STX) must be escaped to \\u0002"; + return token_type::parse_error; + } + + case 0x03: + { + error_message = "invalid string: control character U+0003 (ETX) must be escaped to \\u0003"; + return token_type::parse_error; + } + + case 0x04: + { + error_message = "invalid string: control character U+0004 (EOT) must be escaped to \\u0004"; + return token_type::parse_error; + } + + case 0x05: + { + error_message = "invalid string: control character U+0005 (ENQ) must be escaped to \\u0005"; + return token_type::parse_error; + } + + case 0x06: + { + error_message = "invalid string: control character U+0006 (ACK) must be escaped to \\u0006"; + return token_type::parse_error; + } + + case 0x07: + { + error_message = "invalid string: control character U+0007 (BEL) must be escaped to \\u0007"; + return token_type::parse_error; + } + + case 0x08: + { + error_message = "invalid string: control character U+0008 (BS) must be escaped to \\u0008 or \\b"; + return token_type::parse_error; + } + + case 0x09: + { + error_message = "invalid string: control character U+0009 (HT) must be escaped to \\u0009 or \\t"; + return token_type::parse_error; + } + + case 0x0A: + { + error_message = "invalid string: control character U+000A (LF) must be escaped to \\u000A or \\n"; + return token_type::parse_error; + } + + case 0x0B: + { + error_message = "invalid string: control character U+000B (VT) must be escaped to \\u000B"; + return token_type::parse_error; + } + + case 0x0C: + { + error_message = "invalid string: control character U+000C (FF) must be escaped to \\u000C or \\f"; + return token_type::parse_error; + } + + case 0x0D: + { + 
error_message = "invalid string: control character U+000D (CR) must be escaped to \\u000D or \\r"; + return token_type::parse_error; + } + + case 0x0E: + { + error_message = "invalid string: control character U+000E (SO) must be escaped to \\u000E"; + return token_type::parse_error; + } + + case 0x0F: + { + error_message = "invalid string: control character U+000F (SI) must be escaped to \\u000F"; + return token_type::parse_error; + } + + case 0x10: + { + error_message = "invalid string: control character U+0010 (DLE) must be escaped to \\u0010"; + return token_type::parse_error; + } + + case 0x11: + { + error_message = "invalid string: control character U+0011 (DC1) must be escaped to \\u0011"; + return token_type::parse_error; + } + + case 0x12: + { + error_message = "invalid string: control character U+0012 (DC2) must be escaped to \\u0012"; + return token_type::parse_error; + } + + case 0x13: + { + error_message = "invalid string: control character U+0013 (DC3) must be escaped to \\u0013"; + return token_type::parse_error; + } + + case 0x14: + { + error_message = "invalid string: control character U+0014 (DC4) must be escaped to \\u0014"; + return token_type::parse_error; + } + + case 0x15: + { + error_message = "invalid string: control character U+0015 (NAK) must be escaped to \\u0015"; + return token_type::parse_error; + } + + case 0x16: + { + error_message = "invalid string: control character U+0016 (SYN) must be escaped to \\u0016"; + return token_type::parse_error; + } + + case 0x17: + { + error_message = "invalid string: control character U+0017 (ETB) must be escaped to \\u0017"; + return token_type::parse_error; + } + + case 0x18: + { + error_message = "invalid string: control character U+0018 (CAN) must be escaped to \\u0018"; + return token_type::parse_error; + } + + case 0x19: + { + error_message = "invalid string: control character U+0019 (EM) must be escaped to \\u0019"; + return token_type::parse_error; + } + + case 0x1A: + { + error_message = "invalid string: control character U+001A (SUB) must be escaped to \\u001A"; + return token_type::parse_error; + } + + case 0x1B: + { + error_message = "invalid string: control character U+001B (ESC) must be escaped to \\u001B"; + return token_type::parse_error; + } + + case 0x1C: + { + error_message = "invalid string: control character U+001C (FS) must be escaped to \\u001C"; + return token_type::parse_error; + } + + case 0x1D: + { + error_message = "invalid string: control character U+001D (GS) must be escaped to \\u001D"; + return token_type::parse_error; + } + + case 0x1E: + { + error_message = "invalid string: control character U+001E (RS) must be escaped to \\u001E"; + return token_type::parse_error; + } + + case 0x1F: + { + error_message = "invalid string: control character U+001F (US) must be escaped to \\u001F"; + return token_type::parse_error; + } + + // U+0020..U+007F (except U+0022 (quote) and U+005C (backspace)) + case 0x20: + case 0x21: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3A: + case 0x3B: + case 0x3C: + case 0x3D: + case 0x3E: + case 0x3F: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 
0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5A: + case 0x5B: + case 0x5D: + case 0x5E: + case 0x5F: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7A: + case 0x7B: + case 0x7C: + case 0x7D: + case 0x7E: + case 0x7F: + { + add(current); + break; + } + + // U+0080..U+07FF: bytes C2..DF 80..BF + case 0xC2: + case 0xC3: + case 0xC4: + case 0xC5: + case 0xC6: + case 0xC7: + case 0xC8: + case 0xC9: + case 0xCA: + case 0xCB: + case 0xCC: + case 0xCD: + case 0xCE: + case 0xCF: + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD5: + case 0xD6: + case 0xD7: + case 0xD8: + case 0xD9: + case 0xDA: + case 0xDB: + case 0xDC: + case 0xDD: + case 0xDE: + case 0xDF: + { + if (JSON_HEDLEY_UNLIKELY(!next_byte_in_range({0x80, 0xBF}))) + { + return token_type::parse_error; + } + break; + } + + // U+0800..U+0FFF: bytes E0 A0..BF 80..BF + case 0xE0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF + // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xEE: + case 0xEF: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+D000..U+D7FF: bytes ED 80..9F 80..BF + case 0xED: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x9F, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+10000..U+3FFFF F0 90..BF 80..BF 80..BF + case 0xF0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF + case 0xF1: + case 0xF2: + case 0xF3: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF + case 0xF4: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // remaining bytes (80..C1 and F5..FF) are ill-formed + default: + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return token_type::parse_error; + } + } + } + } + + /*! 
+ * @brief scan a comment + * @return whether comment could be scanned successfully + */ + bool scan_comment() + { + switch (get()) + { + // single-line comments skip input until a newline or EOF is read + case '/': + { + while (true) + { + switch (get()) + { + case '\n': + case '\r': + case std::char_traits::eof(): + case '\0': + return true; + + default: + break; + } + } + } + + // multi-line comments skip input until */ is read + case '*': + { + while (true) + { + switch (get()) + { + case std::char_traits::eof(): + case '\0': + { + error_message = "invalid comment; missing closing '*/'"; + return false; + } + + case '*': + { + switch (get()) + { + case '/': + return true; + + default: + { + unget(); + continue; + } + } + } + + default: + continue; + } + } + } + + // unexpected character after reading '/' + default: + { + error_message = "invalid comment; expecting '/' or '*' after '/'"; + return false; + } + } + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(float& f, const char* str, char** endptr) noexcept + { + f = std::strtof(str, endptr); + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(double& f, const char* str, char** endptr) noexcept + { + f = std::strtod(str, endptr); + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(long double& f, const char* str, char** endptr) noexcept + { + f = std::strtold(str, endptr); + } + + /*! + @brief scan a number literal + + This function scans a string according to Sect. 6 of RFC 8259. + + The function is realized with a deterministic finite state machine derived + from the grammar described in RFC 8259. Starting in state "init", the + input is read and used to determined the next state. Only state "done" + accepts the number. State "error" is a trap state to model errors. In the + table below, "anything" means any character but the ones listed before. + + state | 0 | 1-9 | e E | + | - | . | anything + ---------|----------|----------|----------|---------|---------|----------|----------- + init | zero | any1 | [error] | [error] | minus | [error] | [error] + minus | zero | any1 | [error] | [error] | [error] | [error] | [error] + zero | done | done | exponent | done | done | decimal1 | done + any1 | any1 | any1 | exponent | done | done | decimal1 | done + decimal1 | decimal2 | decimal2 | [error] | [error] | [error] | [error] | [error] + decimal2 | decimal2 | decimal2 | exponent | done | done | done | done + exponent | any2 | any2 | [error] | sign | sign | [error] | [error] + sign | any2 | any2 | [error] | [error] | [error] | [error] | [error] + any2 | any2 | any2 | done | done | done | done | done + + The state machine is realized with one label per state (prefixed with + "scan_number_") and `goto` statements between them. The state machine + contains cycles, but any cycle can be left when EOF is read. Therefore, + the function is guaranteed to terminate. + + During scanning, the read bytes are stored in token_buffer. This string is + then converted to a signed integer, an unsigned integer, or a + floating-point number. + + @return token_type::value_unsigned, token_type::value_integer, or + token_type::value_float if number could be successfully scanned, + token_type::parse_error otherwise + + @note The scanner is independent of the current locale. Internally, the + locale's decimal point is used instead of `.` to work with the + locale-dependent converters. 
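+
+    Example (an illustrative sketch of the resulting value types as seen
+    through the public API):
+
+    @code{.cpp}
+    #include <cassert>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        assert(nlohmann::json::parse("42").is_number_unsigned());  // value_unsigned
+        assert(nlohmann::json::parse("-42").is_number_integer());  // value_integer
+        assert(nlohmann::json::parse("42.0").is_number_float());   // value_float
+        assert(nlohmann::json::parse("4e2").is_number_float());    // value_float
+
+        // integers that overflow the 64-bit types fall back to value_float
+        assert(nlohmann::json::parse("18446744073709551616").is_number_float()); // 2^64
+    }
+    @endcode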
+ */ + token_type scan_number() // lgtm [cpp/use-of-goto] + { + // reset token_buffer to store the number's bytes + reset(); + + // the type of the parsed number; initially set to unsigned; will be + // changed if minus sign, decimal point or exponent is read + token_type number_type = token_type::value_unsigned; + + // state (init): we just found out we need to scan a number + switch (current) + { + case '-': + { + add(current); + goto scan_number_minus; + } + + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + // all other characters are rejected outside scan_number() + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + +scan_number_minus: + // state: we just parsed a leading minus sign + number_type = token_type::value_integer; + switch (get()) + { + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + default: + { + error_message = "invalid number; expected digit after '-'"; + return token_type::parse_error; + } + } + +scan_number_zero: + // state: we just parse a zero (maybe with a leading minus sign) + switch (get()) + { + case '.': + { + add(decimal_point_char); + goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_any1: + // state: we just parsed a number 0-9 (maybe with a leading minus sign) + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + case '.': + { + add(decimal_point_char); + goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_decimal1: + // state: we just parsed a decimal point + number_type = token_type::value_float; + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + default: + { + error_message = "invalid number; expected digit after '.'"; + return token_type::parse_error; + } + } + +scan_number_decimal2: + // we just parsed at least one number after a decimal point + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_exponent: + // we just parsed an exponent + number_type = token_type::value_float; + switch (get()) + { + case '+': + case '-': + { + add(current); + goto scan_number_sign; + } + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = + "invalid number; expected '+', '-', or digit after exponent"; + return token_type::parse_error; + } + } + +scan_number_sign: + // we just parsed an exponent sign + switch (get()) + 
{ + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = "invalid number; expected digit after exponent sign"; + return token_type::parse_error; + } + } + +scan_number_any2: + // we just parsed a number after the exponent or exponent sign + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + goto scan_number_done; + } + +scan_number_done: + // unget the character after the number (we only read it to know that + // we are done scanning a number) + unget(); + + char* endptr = nullptr; // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + errno = 0; + + // try to parse integers first and fall back to floats + if (number_type == token_type::value_unsigned) + { + const auto x = std::strtoull(token_buffer.data(), &endptr, 10); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + if (errno == 0) + { + value_unsigned = static_cast(x); + if (value_unsigned == x) + { + return token_type::value_unsigned; + } + } + } + else if (number_type == token_type::value_integer) + { + const auto x = std::strtoll(token_buffer.data(), &endptr, 10); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + if (errno == 0) + { + value_integer = static_cast(x); + if (value_integer == x) + { + return token_type::value_integer; + } + } + } + + // this code is reached if we parse a floating-point number or if an + // integer conversion above failed + strtof(value_float, token_buffer.data(), &endptr); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + return token_type::value_float; + } + + /*! + @param[in] literal_text the literal text to expect + @param[in] length the length of the passed literal text + @param[in] return_type the token type to return on success + */ + JSON_HEDLEY_NON_NULL(2) + token_type scan_literal(const char_type* literal_text, const std::size_t length, + token_type return_type) + { + JSON_ASSERT(std::char_traits::to_char_type(current) == literal_text[0]); + for (std::size_t i = 1; i < length; ++i) + { + if (JSON_HEDLEY_UNLIKELY(std::char_traits::to_char_type(get()) != literal_text[i])) + { + error_message = "invalid literal"; + return token_type::parse_error; + } + } + return return_type; + } + + ///////////////////// + // input management + ///////////////////// + + /// reset token_buffer; current character is beginning of token + void reset() noexcept + { + token_buffer.clear(); + token_string.clear(); + token_string.push_back(std::char_traits::to_char_type(current)); + } + + /* + @brief get next character from the input + + This function provides the interface to the used input adapter. It does + not throw in case the input reached EOF, but returns a + `std::char_traits::eof()` in that case. Stores the scanned characters + for use in error messages. 
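+
+    Example (an illustrative sketch): the position and token string tracked
+    here surface in parse_error exceptions:
+
+    @code{.cpp}
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        try
+        {
+            nlohmann::json::parse("[1,2,");  // truncated input
+        }
+        catch (const nlohmann::json::parse_error& e)
+        {
+            // e.byte is the byte position of the error; e.what() contains the
+            // error id (parse_error.101) and a description of the failure
+            std::cout << e.what() << '\n';
+        }
+    }
+    @endcode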
+ + @return character read from the input + */ + char_int_type get() + { + ++position.chars_read_total; + ++position.chars_read_current_line; + + if (next_unget) + { + // just reset the next_unget variable and work with current + next_unget = false; + } + else + { + current = ia.get_character(); + } + + if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + { + token_string.push_back(std::char_traits::to_char_type(current)); + } + + if (current == '\n') + { + ++position.lines_read; + position.chars_read_current_line = 0; + } + + return current; + } + + /*! + @brief unget current character (read it again on next get) + + We implement unget by setting variable next_unget to true. The input is not + changed - we just simulate ungetting by modifying chars_read_total, + chars_read_current_line, and token_string. The next call to get() will + behave as if the unget character is read again. + */ + void unget() + { + next_unget = true; + + --position.chars_read_total; + + // in case we "unget" a newline, we have to also decrement the lines_read + if (position.chars_read_current_line == 0) + { + if (position.lines_read > 0) + { + --position.lines_read; + } + } + else + { + --position.chars_read_current_line; + } + + if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + { + JSON_ASSERT(!token_string.empty()); + token_string.pop_back(); + } + } + + /// add a character to token_buffer + void add(char_int_type c) + { + token_buffer.push_back(static_cast(c)); + } + + public: + ///////////////////// + // value getters + ///////////////////// + + /// return integer value + constexpr number_integer_t get_number_integer() const noexcept + { + return value_integer; + } + + /// return unsigned integer value + constexpr number_unsigned_t get_number_unsigned() const noexcept + { + return value_unsigned; + } + + /// return floating-point value + constexpr number_float_t get_number_float() const noexcept + { + return value_float; + } + + /// return current string value (implicitly resets the token; useful only once) + string_t& get_string() + { + return token_buffer; + } + + ///////////////////// + // diagnostics + ///////////////////// + + /// return position of last read token + constexpr position_t get_position() const noexcept + { + return position; + } + + /// return the last read token (for errors only). Will never contain EOF + /// (an arbitrary value that is not a valid char value, often -1), because + /// 255 may legitimately occur. May contain NUL, which should be escaped. + std::string get_token_string() const + { + // escape control characters + std::string result; + for (const auto c : token_string) + { + if (static_cast(c) <= '\x1F') + { + // escape control characters + std::array cs{{}}; + (std::snprintf)(cs.data(), cs.size(), "", static_cast(c)); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + result += cs.data(); + } + else + { + // add character as is + result.push_back(static_cast(c)); + } + } + + return result; + } + + /// return syntax error message + JSON_HEDLEY_RETURNS_NON_NULL + constexpr const char* get_error_message() const noexcept + { + return error_message; + } + + ///////////////////// + // actual scanner + ///////////////////// + + /*! 
+ @brief skip the UTF-8 byte order mark + @return true iff there is no BOM or the correct BOM has been skipped + */ + bool skip_bom() + { + if (get() == 0xEF) + { + // check if we completely parse the BOM + return get() == 0xBB && get() == 0xBF; + } + + // the first character is not the beginning of the BOM; unget it to + // process is later + unget(); + return true; + } + + void skip_whitespace() + { + do + { + get(); + } + while (current == ' ' || current == '\t' || current == '\n' || current == '\r'); + } + + token_type scan() + { + // initially, skip the BOM + if (position.chars_read_total == 0 && !skip_bom()) + { + error_message = "invalid BOM; must be 0xEF 0xBB 0xBF if given"; + return token_type::parse_error; + } + + // read next character and ignore whitespace + skip_whitespace(); + + // ignore comments + while (ignore_comments && current == '/') + { + if (!scan_comment()) + { + return token_type::parse_error; + } + + // skip following whitespace + skip_whitespace(); + } + + switch (current) + { + // structural characters + case '[': + return token_type::begin_array; + case ']': + return token_type::end_array; + case '{': + return token_type::begin_object; + case '}': + return token_type::end_object; + case ':': + return token_type::name_separator; + case ',': + return token_type::value_separator; + + // literals + case 't': + { + std::array true_literal = {{char_type('t'), char_type('r'), char_type('u'), char_type('e')}}; + return scan_literal(true_literal.data(), true_literal.size(), token_type::literal_true); + } + case 'f': + { + std::array false_literal = {{char_type('f'), char_type('a'), char_type('l'), char_type('s'), char_type('e')}}; + return scan_literal(false_literal.data(), false_literal.size(), token_type::literal_false); + } + case 'n': + { + std::array null_literal = {{char_type('n'), char_type('u'), char_type('l'), char_type('l')}}; + return scan_literal(null_literal.data(), null_literal.size(), token_type::literal_null); + } + + // string + case '\"': + return scan_string(); + + // number + case '-': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return scan_number(); + + // end of input (the null byte is needed when parsing from + // string literals) + case '\0': + case std::char_traits::eof(): + return token_type::end_of_input; + + // error + default: + error_message = "invalid literal"; + return token_type::parse_error; + } + } + + private: + /// input adapter + InputAdapterType ia; + + /// whether comments should be ignored (true) or signaled as errors (false) + const bool ignore_comments = false; + + /// the current character + char_int_type current = std::char_traits::eof(); + + /// whether the next get() call should just return current + bool next_unget = false; + + /// the start position of the current token + position_t position {}; + + /// raw input token string (for error messages) + std::vector token_string {}; + + /// buffer for variable-length tokens (numbers, strings) + string_t token_buffer {}; + + /// a description of occurred lexer errors + const char* error_message = ""; + + // number values + number_integer_t value_integer = 0; + number_unsigned_t value_unsigned = 0; + number_float_t value_float = 0; + + /// the decimal point + const char_int_type decimal_point_char = '.'; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // size_t +#include // declval +#include // string + +// #include + +// #include + + +namespace nlohmann 
+{ +namespace detail +{ +template +using null_function_t = decltype(std::declval().null()); + +template +using boolean_function_t = + decltype(std::declval().boolean(std::declval())); + +template +using number_integer_function_t = + decltype(std::declval().number_integer(std::declval())); + +template +using number_unsigned_function_t = + decltype(std::declval().number_unsigned(std::declval())); + +template +using number_float_function_t = decltype(std::declval().number_float( + std::declval(), std::declval())); + +template +using string_function_t = + decltype(std::declval().string(std::declval())); + +template +using binary_function_t = + decltype(std::declval().binary(std::declval())); + +template +using start_object_function_t = + decltype(std::declval().start_object(std::declval())); + +template +using key_function_t = + decltype(std::declval().key(std::declval())); + +template +using end_object_function_t = decltype(std::declval().end_object()); + +template +using start_array_function_t = + decltype(std::declval().start_array(std::declval())); + +template +using end_array_function_t = decltype(std::declval().end_array()); + +template +using parse_error_function_t = decltype(std::declval().parse_error( + std::declval(), std::declval(), + std::declval())); + +template +struct is_sax +{ + private: + static_assert(is_basic_json::value, + "BasicJsonType must be of type basic_json<...>"); + + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using exception_t = typename BasicJsonType::exception; + + public: + static constexpr bool value = + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value; +}; + +template +struct is_sax_static_asserts +{ + private: + static_assert(is_basic_json::value, + "BasicJsonType must be of type basic_json<...>"); + + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using exception_t = typename BasicJsonType::exception; + + public: + static_assert(is_detected_exact::value, + "Missing/invalid function: bool null()"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool boolean(bool)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool boolean(bool)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool number_integer(number_integer_t)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool number_unsigned(number_unsigned_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool number_float(number_float_t, const string_t&)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool string(string_t&)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool 
binary(binary_t&)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool start_object(std::size_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool key(string_t&)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool end_object()"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool start_array(std::size_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool end_array()"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool parse_error(std::size_t, const " + "std::string&, const exception&)"); +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +/// how to treat CBOR tags +enum class cbor_tag_handler_t +{ + error, ///< throw a parse_error exception in case of a tag + ignore, ///< ignore tags + store ///< store tags as binary type +}; + +/*! +@brief determine system byte order + +@return true if and only if system's byte order is little endian + +@note from https://stackoverflow.com/a/1001328/266378 +*/ +static inline bool little_endianess(int num = 1) noexcept +{ + return *reinterpret_cast(&num) == 1; +} + + +/////////////////// +// binary reader // +/////////////////// + +/*! +@brief deserialization of CBOR, MessagePack, and UBJSON values +*/ +template> +class binary_reader +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using json_sax_t = SAX; + using char_type = typename InputAdapterType::char_type; + using char_int_type = typename std::char_traits::int_type; + + public: + /*! + @brief create a binary reader + + @param[in] adapter input adapter to read from + */ + explicit binary_reader(InputAdapterType&& adapter) noexcept : ia(std::move(adapter)) + { + (void)detail::is_sax_static_asserts {}; + } + + // make class move-only + binary_reader(const binary_reader&) = delete; + binary_reader(binary_reader&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + binary_reader& operator=(const binary_reader&) = delete; + binary_reader& operator=(binary_reader&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + ~binary_reader() = default; + + /*! 
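+    Example (an illustrative sketch): this function backs the public
+    from_cbor/from_msgpack/from_ubjson/from_bson helpers, e.g.:
+
+    @code{.cpp}
+    #include <cassert>
+    #include <cstdint>
+    #include <vector>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        const nlohmann::json original = {{"compact", true}, {"schema", 0}};
+
+        const std::vector<std::uint8_t> bytes = nlohmann::json::to_cbor(original);
+        assert(nlohmann::json::from_cbor(bytes) == original);
+
+        // CBOR tags raise parse_error.112 by default; they can also be ignored
+        const auto lenient = nlohmann::json::from_cbor(
+            bytes, /*strict*/ true, /*allow_exceptions*/ true,
+            nlohmann::json::cbor_tag_handler_t::ignore);
+        assert(lenient == original);
+    }
+    @endcode
+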
+ @param[in] format the binary format to parse + @param[in] sax_ a SAX event processor + @param[in] strict whether to expect the input to be consumed completed + @param[in] tag_handler how to treat CBOR tags + + @return whether parsing was successful + */ + JSON_HEDLEY_NON_NULL(3) + bool sax_parse(const input_format_t format, + json_sax_t* sax_, + const bool strict = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + sax = sax_; + bool result = false; + + switch (format) + { + case input_format_t::bson: + result = parse_bson_internal(); + break; + + case input_format_t::cbor: + result = parse_cbor_internal(true, tag_handler); + break; + + case input_format_t::msgpack: + result = parse_msgpack_internal(); + break; + + case input_format_t::ubjson: + result = parse_ubjson_internal(); + break; + + case input_format_t::json: // LCOV_EXCL_LINE + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + + // strict mode: next byte must be EOF + if (result && strict) + { + if (format == input_format_t::ubjson) + { + get_ignore_noop(); + } + else + { + get(); + } + + if (JSON_HEDLEY_UNLIKELY(current != std::char_traits::eof())) + { + return sax->parse_error(chars_read, get_token_string(), + parse_error::create(110, chars_read, exception_message(format, "expected end of input; last byte: 0x" + get_token_string(), "value"), BasicJsonType())); + } + } + + return result; + } + + private: + ////////// + // BSON // + ////////// + + /*! + @brief Reads in a BSON-object and passes it to the SAX-parser. + @return whether a valid BSON-value was passed to the SAX parser + */ + bool parse_bson_internal() + { + std::int32_t document_size{}; + get_number(input_format_t::bson, document_size); + + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/false))) + { + return false; + } + + return sax->end_object(); + } + + /*! + @brief Parses a C-style string from the BSON input. + @param[in,out] result A reference to the string variable where the read + string is to be stored. + @return `true` if the \x00-byte indicating the end of the string was + encountered before the EOF; false` indicates an unexpected EOF. + */ + bool get_bson_cstr(string_t& result) + { + auto out = std::back_inserter(result); + while (true) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "cstring"))) + { + return false; + } + if (current == 0x00) + { + return true; + } + *out++ = static_cast(current); + } + } + + /*! + @brief Parses a zero-terminated string of length @a len from the BSON + input. + @param[in] len The length (including the zero-byte at the end) of the + string to be read. + @param[in,out] result A reference to the string variable where the read + string is to be stored. + @tparam NumberType The type of the length @a len + @pre len >= 1 + @return `true` if the string was successfully parsed + */ + template + bool get_bson_string(const NumberType len, string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(len < 1)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "string length must be at least 1, is " + std::to_string(len), "string"), BasicJsonType())); + } + + return get_string(input_format_t::bson, len - static_cast(1), result) && get() != std::char_traits::eof(); + } + + /*! 
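+    Example (an illustrative sketch of the BSON support built on these
+    readers; note that BSON requires an object at the top level):
+
+    @code{.cpp}
+    #include <cassert>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        const nlohmann::json doc = {{"name", "nlohmann"}, {"year", 2013}};
+
+        const auto bytes = nlohmann::json::to_bson(doc);
+        assert(nlohmann::json::from_bson(bytes) == doc);
+    }
+    @endcode
+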
+ @brief Parses a byte array input of length @a len from the BSON input. + @param[in] len The length of the byte array to be read. + @param[in,out] result A reference to the binary variable where the read + array is to be stored. + @tparam NumberType The type of the length @a len + @pre len >= 0 + @return `true` if the byte array was successfully parsed + */ + template + bool get_bson_binary(const NumberType len, binary_t& result) + { + if (JSON_HEDLEY_UNLIKELY(len < 0)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "byte array length cannot be negative, is " + std::to_string(len), "binary"), BasicJsonType())); + } + + // All BSON binary values have a subtype + std::uint8_t subtype{}; + get_number(input_format_t::bson, subtype); + result.set_subtype(subtype); + + return get_binary(input_format_t::bson, len, result); + } + + /*! + @brief Read a BSON document element of the given @a element_type. + @param[in] element_type The BSON element type, c.f. http://bsonspec.org/spec.html + @param[in] element_type_parse_position The position in the input stream, + where the `element_type` was read. + @warning Not all BSON element types are supported yet. An unsupported + @a element_type will give rise to a parse_error.114: + Unsupported BSON record type 0x... + @return whether a valid BSON-object/array was passed to the SAX parser + */ + bool parse_bson_element_internal(const char_int_type element_type, + const std::size_t element_type_parse_position) + { + switch (element_type) + { + case 0x01: // double + { + double number{}; + return get_number(input_format_t::bson, number) && sax->number_float(static_cast(number), ""); + } + + case 0x02: // string + { + std::int32_t len{}; + string_t value; + return get_number(input_format_t::bson, len) && get_bson_string(len, value) && sax->string(value); + } + + case 0x03: // object + { + return parse_bson_internal(); + } + + case 0x04: // array + { + return parse_bson_array(); + } + + case 0x05: // binary + { + std::int32_t len{}; + binary_t value; + return get_number(input_format_t::bson, len) && get_bson_binary(len, value) && sax->binary(value); + } + + case 0x08: // boolean + { + return sax->boolean(get() != 0); + } + + case 0x0A: // null + { + return sax->null(); + } + + case 0x10: // int32 + { + std::int32_t value{}; + return get_number(input_format_t::bson, value) && sax->number_integer(value); + } + + case 0x12: // int64 + { + std::int64_t value{}; + return get_number(input_format_t::bson, value) && sax->number_integer(value); + } + + default: // anything else not supported (yet) + { + std::array cr{{}}; + (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(element_type)); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + return sax->parse_error(element_type_parse_position, std::string(cr.data()), parse_error::create(114, element_type_parse_position, "Unsupported BSON record type 0x" + std::string(cr.data()), BasicJsonType())); + } + } + } + + /*! + @brief Read a BSON element list (as specified in the BSON-spec) + + The same binary layout is used for objects and arrays, hence it must be + indicated with the argument @a is_array which one is expected + (true --> array, false --> object). + + @param[in] is_array Determines if the element list being read is to be + treated as an object (@a is_array == false), or as an + array (@a is_array == true). 
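+
+    @note BSON encodes arrays as documents whose keys are the decimal strings
+          "0", "1", ...; for example, the array `[true, false]` is stored as
+          the element list `{"0": true, "1": false}`. When @a is_array == true
+          these keys are read but not forwarded to the SAX consumer.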
+ @return whether a valid BSON-object/array was passed to the SAX parser + */ + bool parse_bson_element_list(const bool is_array) + { + string_t key; + + while (auto element_type = get()) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "element list"))) + { + return false; + } + + const std::size_t element_type_parse_position = chars_read; + if (JSON_HEDLEY_UNLIKELY(!get_bson_cstr(key))) + { + return false; + } + + if (!is_array && !sax->key(key)) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_internal(element_type, element_type_parse_position))) + { + return false; + } + + // get_bson_cstr only appends + key.clear(); + } + + return true; + } + + /*! + @brief Reads an array from the BSON input and passes it to the SAX-parser. + @return whether a valid BSON-array was passed to the SAX parser + */ + bool parse_bson_array() + { + std::int32_t document_size{}; + get_number(input_format_t::bson, document_size); + + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/true))) + { + return false; + } + + return sax->end_array(); + } + + ////////// + // CBOR // + ////////// + + /*! + @param[in] get_char whether a new character should be retrieved from the + input (true) or whether the last read character should + be considered instead (false) + @param[in] tag_handler how CBOR tags should be treated + + @return whether a valid CBOR value was passed to the SAX parser + */ + bool parse_cbor_internal(const bool get_char, + const cbor_tag_handler_t tag_handler) + { + switch (get_char ? get() : current) + { + // EOF + case std::char_traits::eof(): + return unexpect_eof(input_format_t::cbor, "value"); + + // Integer 0x00..0x17 (0..23) + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0A: + case 0x0B: + case 0x0C: + case 0x0D: + case 0x0E: + case 0x0F: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + return sax->number_unsigned(static_cast(current)); + + case 0x18: // Unsigned integer (one-byte uint8_t follows) + { + std::uint8_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x19: // Unsigned integer (two-byte uint16_t follows) + { + std::uint16_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x1A: // Unsigned integer (four-byte uint32_t follows) + { + std::uint32_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x1B: // Unsigned integer (eight-byte uint64_t follows) + { + std::uint64_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + // Negative integer -1-0x00..-1-0x17 (-1..-24) + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + return sax->number_integer(static_cast(0x20 - 1 - current)); + + case 0x38: // Negative integer (one-byte uint8_t follows) + { + std::uint8_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x39: // Negative integer -1-n (two-byte 
uint16_t follows) + { + std::uint16_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x3A: // Negative integer -1-n (four-byte uint32_t follows) + { + std::uint32_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x3B: // Negative integer -1-n (eight-byte uint64_t follows) + { + std::uint64_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) + - static_cast(number)); + } + + // Binary data (0x00..0x17 bytes follow) + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: // Binary data (one-byte uint8_t for n follows) + case 0x59: // Binary data (two-byte uint16_t for n follow) + case 0x5A: // Binary data (four-byte uint32_t for n follow) + case 0x5B: // Binary data (eight-byte uint64_t for n follow) + case 0x5F: // Binary data (indefinite length) + { + binary_t b; + return get_cbor_binary(b) && sax->binary(b); + } + + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) + case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) + case 0x7F: // UTF-8 string (indefinite length) + { + string_t s; + return get_cbor_string(s) && sax->string(s); + } + + // array (0x00..0x17 data items follow) + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8A: + case 0x8B: + case 0x8C: + case 0x8D: + case 0x8E: + case 0x8F: + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + return get_cbor_array(static_cast(static_cast(current) & 0x1Fu), tag_handler); + + case 0x98: // array (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x99: // array (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x9A: // array (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x9B: // array (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(detail::conditional_static_cast(len), tag_handler); + } + + case 0x9F: // array (indefinite length) + return get_cbor_array(std::size_t(-1), tag_handler); + + // map (0x00..0x17 pairs of data items follow) + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + 
case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + return get_cbor_object(static_cast(static_cast(current) & 0x1Fu), tag_handler); + + case 0xB8: // map (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xB9: // map (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xBA: // map (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xBB: // map (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(detail::conditional_static_cast(len), tag_handler); + } + + case 0xBF: // map (indefinite length) + return get_cbor_object(std::size_t(-1), tag_handler); + + case 0xC6: // tagged item + case 0xC7: + case 0xC8: + case 0xC9: + case 0xCA: + case 0xCB: + case 0xCC: + case 0xCD: + case 0xCE: + case 0xCF: + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD8: // tagged item (1 bytes follow) + case 0xD9: // tagged item (2 bytes follow) + case 0xDA: // tagged item (4 bytes follow) + case 0xDB: // tagged item (8 bytes follow) + { + switch (tag_handler) + { + case cbor_tag_handler_t::error: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"), BasicJsonType())); + } + + case cbor_tag_handler_t::ignore: + { + // ignore binary subtype + switch (current) + { + case 0xD8: + { + std::uint8_t subtype_to_ignore{}; + get_number(input_format_t::cbor, subtype_to_ignore); + break; + } + case 0xD9: + { + std::uint16_t subtype_to_ignore{}; + get_number(input_format_t::cbor, subtype_to_ignore); + break; + } + case 0xDA: + { + std::uint32_t subtype_to_ignore{}; + get_number(input_format_t::cbor, subtype_to_ignore); + break; + } + case 0xDB: + { + std::uint64_t subtype_to_ignore{}; + get_number(input_format_t::cbor, subtype_to_ignore); + break; + } + default: + break; + } + return parse_cbor_internal(true, tag_handler); + } + + case cbor_tag_handler_t::store: + { + binary_t b; + // use binary subtype and store in binary container + switch (current) + { + case 0xD8: + { + std::uint8_t subtype{}; + get_number(input_format_t::cbor, subtype); + b.set_subtype(detail::conditional_static_cast(subtype)); + break; + } + case 0xD9: + { + std::uint16_t subtype{}; + get_number(input_format_t::cbor, subtype); + b.set_subtype(detail::conditional_static_cast(subtype)); + break; + } + case 0xDA: + { + std::uint32_t subtype{}; + get_number(input_format_t::cbor, subtype); + b.set_subtype(detail::conditional_static_cast(subtype)); + break; + } + case 0xDB: + { + std::uint64_t subtype{}; + get_number(input_format_t::cbor, subtype); + b.set_subtype(detail::conditional_static_cast(subtype)); + break; + } + default: + return parse_cbor_internal(true, tag_handler); + } + get(); + return get_cbor_binary(b) && sax->binary(b); + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + return false; // LCOV_EXCL_LINE + } + } + + case 0xF4: // false + return sax->boolean(false); + + case 0xF5: 
// true + return sax->boolean(true); + + case 0xF6: // null + return sax->null(); + + case 0xF9: // Half-Precision Float (two-byte IEEE 754) + { + const auto byte1_raw = get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) + { + return false; + } + const auto byte2_raw = get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) + { + return false; + } + + const auto byte1 = static_cast(byte1_raw); + const auto byte2 = static_cast(byte2_raw); + + // code from RFC 7049, Appendix D, Figure 3: + // As half-precision floating-point numbers were only added + // to IEEE 754 in 2008, today's programming platforms often + // still only have limited support for them. It is very + // easy to include at least decoding support for them even + // without such support. An example of a small decoder for + // half-precision floating-point numbers in the C language + // is shown in Fig. 3. + const auto half = static_cast((byte1 << 8u) + byte2); + const double val = [&half] + { + const int exp = (half >> 10u) & 0x1Fu; + const unsigned int mant = half & 0x3FFu; + JSON_ASSERT(0 <= exp&& exp <= 32); + JSON_ASSERT(mant <= 1024); + switch (exp) + { + case 0: + return std::ldexp(mant, -24); + case 31: + return (mant == 0) + ? std::numeric_limits::infinity() + : std::numeric_limits::quiet_NaN(); + default: + return std::ldexp(mant + 1024, exp - 25); + } + }(); + return sax->number_float((half & 0x8000u) != 0 + ? static_cast(-val) + : static_cast(val), ""); + } + + case 0xFA: // Single-Precision Float (four-byte IEEE 754) + { + float number{}; + return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); + } + + case 0xFB: // Double-Precision Float (eight-byte IEEE 754) + { + double number{}; + return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); + } + + default: // anything else (0xFF is handled inside the other types) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"), BasicJsonType())); + } + } + } + + /*! + @brief reads a CBOR string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. + Additionally, CBOR's strings with indefinite lengths are supported. 
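+
+    As a rough illustration using the public from_cbor wrapper (not part of
+    this class), an indefinite-length string is a 0x7F head followed by
+    definite-length chunks and terminated by the 0xFF "break" byte:
+
+        std::vector<std::uint8_t> bytes = {0x7F, 0x62, 'a', 'b', 0x61, 'c', 0xFF};
+        json j = json::from_cbor(bytes);  // j == "abc"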
+ + @param[out] result created string + + @return whether string creation completed + */ + bool get_cbor_string(string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "string"))) + { + return false; + } + + switch (current) + { + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + { + return get_string(input_format_t::cbor, static_cast(current) & 0x1Fu, result); + } + + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7F: // UTF-8 string (indefinite length) + { + while (get() != 0xFF) + { + string_t chunk; + if (!get_cbor_string(chunk)) + { + return false; + } + result.append(chunk); + } + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x60-0x7B) or indefinite string type (0x7F); last byte: 0x" + last_token, "string"), BasicJsonType())); + } + } + } + + /*! + @brief reads a CBOR byte array + + This function first reads starting bytes to determine the expected + byte array length and then copies this number of bytes into the byte array. + Additionally, CBOR's byte arrays with indefinite lengths are supported. 
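+
+    For illustration via the public from_cbor wrapper, an indefinite-length
+    byte string (0x5F) is assembled from definite-length chunks until the
+    0xFF "break" byte:
+
+        std::vector<std::uint8_t> bytes = {0x5F, 0x42, 0x01, 0x02, 0x41, 0x03, 0xFF};
+        json j = json::from_cbor(bytes);  // binary value {0x01, 0x02, 0x03}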
+ + @param[out] result created byte array + + @return whether byte array creation completed + */ + bool get_cbor_binary(binary_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "binary"))) + { + return false; + } + + switch (current) + { + // Binary data (0x00..0x17 bytes follow) + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + { + return get_binary(input_format_t::cbor, static_cast(current) & 0x1Fu, result); + } + + case 0x58: // Binary data (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x59: // Binary data (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5A: // Binary data (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5B: // Binary data (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5F: // Binary data (indefinite length) + { + while (get() != 0xFF) + { + binary_t chunk; + if (!get_cbor_binary(chunk)) + { + return false; + } + result.insert(result.end(), chunk.begin(), chunk.end()); + } + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x40-0x5B) or indefinite binary array type (0x5F); last byte: 0x" + last_token, "binary"), BasicJsonType())); + } + } + } + + /*! + @param[in] len the length of the array or std::size_t(-1) for an + array of indefinite size + @param[in] tag_handler how CBOR tags should be treated + @return whether array creation completed + */ + bool get_cbor_array(const std::size_t len, + const cbor_tag_handler_t tag_handler) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) + { + return false; + } + + if (len != std::size_t(-1)) + { + for (std::size_t i = 0; i < len; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + } + } + else + { + while (get() != 0xFF) + { + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(false, tag_handler))) + { + return false; + } + } + } + + return sax->end_array(); + } + + /*! 
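+    A CBOR map is a sequence of key/value pairs whose keys are read with
+    get_cbor_string(). For instance, decoded through the public wrapper
+    (illustrative only), the bytes below yield {"a": 1}:
+
+        std::vector<std::uint8_t> bytes = {0xA1, 0x61, 'a', 0x01};
+        json j = json::from_cbor(bytes);  // j == {{"a", 1}}
+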
+ @param[in] len the length of the object or std::size_t(-1) for an + object of indefinite size + @param[in] tag_handler how CBOR tags should be treated + @return whether object creation completed + */ + bool get_cbor_object(const std::size_t len, + const cbor_tag_handler_t tag_handler) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) + { + return false; + } + + if (len != 0) + { + string_t key; + if (len != std::size_t(-1)) + { + for (std::size_t i = 0; i < len; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + key.clear(); + } + } + else + { + while (get() != 0xFF) + { + if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + key.clear(); + } + } + } + + return sax->end_object(); + } + + ///////////// + // MsgPack // + ///////////// + + /*! + @return whether a valid MessagePack value was passed to the SAX parser + */ + bool parse_msgpack_internal() + { + switch (get()) + { + // EOF + case std::char_traits::eof(): + return unexpect_eof(input_format_t::msgpack, "value"); + + // positive fixint + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0A: + case 0x0B: + case 0x0C: + case 0x0D: + case 0x0E: + case 0x0F: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + case 0x18: + case 0x19: + case 0x1A: + case 0x1B: + case 0x1C: + case 0x1D: + case 0x1E: + case 0x1F: + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3A: + case 0x3B: + case 0x3C: + case 0x3D: + case 0x3E: + case 0x3F: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5A: + case 0x5B: + case 0x5C: + case 0x5D: + case 0x5E: + case 0x5F: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7A: + case 0x7B: + case 0x7C: + case 0x7D: + case 0x7E: + case 0x7F: + return sax->number_unsigned(static_cast(current)); + + // fixmap + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8A: + case 0x8B: + case 0x8C: + case 0x8D: + case 0x8E: + case 0x8F: + return get_msgpack_object(static_cast(static_cast(current) & 0x0Fu)); + + // fixarray + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + case 0x98: + case 0x99: + case 0x9A: + case 0x9B: + case 0x9C: + case 0x9D: + case 0x9E: + case 0x9F: + return 
get_msgpack_array(static_cast(static_cast(current) & 0x0Fu)); + + // fixstr + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + case 0xB8: + case 0xB9: + case 0xBA: + case 0xBB: + case 0xBC: + case 0xBD: + case 0xBE: + case 0xBF: + case 0xD9: // str 8 + case 0xDA: // str 16 + case 0xDB: // str 32 + { + string_t s; + return get_msgpack_string(s) && sax->string(s); + } + + case 0xC0: // nil + return sax->null(); + + case 0xC2: // false + return sax->boolean(false); + + case 0xC3: // true + return sax->boolean(true); + + case 0xC4: // bin 8 + case 0xC5: // bin 16 + case 0xC6: // bin 32 + case 0xC7: // ext 8 + case 0xC8: // ext 16 + case 0xC9: // ext 32 + case 0xD4: // fixext 1 + case 0xD5: // fixext 2 + case 0xD6: // fixext 4 + case 0xD7: // fixext 8 + case 0xD8: // fixext 16 + { + binary_t b; + return get_msgpack_binary(b) && sax->binary(b); + } + + case 0xCA: // float 32 + { + float number{}; + return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); + } + + case 0xCB: // float 64 + { + double number{}; + return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); + } + + case 0xCC: // uint 8 + { + std::uint8_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCD: // uint 16 + { + std::uint16_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCE: // uint 32 + { + std::uint32_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCF: // uint 64 + { + std::uint64_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xD0: // int 8 + { + std::int8_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD1: // int 16 + { + std::int16_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD2: // int 32 + { + std::int32_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD3: // int 64 + { + std::int64_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xDC: // array 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); + } + + case 0xDD: // array 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); + } + + case 0xDE: // map 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); + } + + case 0xDF: // map 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); + } + + // negative fixint + case 0xE0: + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xED: + case 0xEE: + case 0xEF: + case 0xF0: + case 0xF1: + case 0xF2: + case 0xF3: + case 0xF4: + case 0xF5: + case 0xF6: + case 0xF7: + case 0xF8: + case 0xF9: + case 0xFA: + case 0xFB: 
+ case 0xFC: + case 0xFD: + case 0xFE: + case 0xFF: + return sax->number_integer(static_cast(current)); + + default: // anything else + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::msgpack, "invalid byte: 0x" + last_token, "value"), BasicJsonType())); + } + } + } + + /*! + @brief reads a MessagePack string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. + + @param[out] result created string + + @return whether string creation completed + */ + bool get_msgpack_string(string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::msgpack, "string"))) + { + return false; + } + + switch (current) + { + // fixstr + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + case 0xB8: + case 0xB9: + case 0xBA: + case 0xBB: + case 0xBC: + case 0xBD: + case 0xBE: + case 0xBF: + { + return get_string(input_format_t::msgpack, static_cast(current) & 0x1Fu, result); + } + + case 0xD9: // str 8 + { + std::uint8_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + case 0xDA: // str 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + case 0xDB: // str 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::msgpack, "expected length specification (0xA0-0xBF, 0xD9-0xDB); last byte: 0x" + last_token, "string"), BasicJsonType())); + } + } + } + + /*! + @brief reads a MessagePack byte array + + This function first reads starting bytes to determine the expected + byte array length and then copies this number of bytes into a byte array. 
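+
+    MessagePack carries raw binary data in the bin 8/16/32 families and typed
+    payloads in the ext/fixext families. A small sketch using the public
+    from_msgpack wrapper (illustrative only):
+
+        std::vector<std::uint8_t> bytes = {0xC4, 0x03, 0x01, 0x02, 0x03};  // bin 8, length 3
+        json j = json::from_msgpack(bytes);  // binary value {0x01, 0x02, 0x03}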
+ + @param[out] result created byte array + + @return whether byte array creation completed + */ + bool get_msgpack_binary(binary_t& result) + { + // helper function to set the subtype + auto assign_and_return_true = [&result](std::int8_t subtype) + { + result.set_subtype(static_cast(subtype)); + return true; + }; + + switch (current) + { + case 0xC4: // bin 8 + { + std::uint8_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC5: // bin 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC6: // bin 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC7: // ext 8 + { + std::uint8_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xC8: // ext 16 + { + std::uint16_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xC9: // ext 32 + { + std::uint32_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xD4: // fixext 1 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 1, result) && + assign_and_return_true(subtype); + } + + case 0xD5: // fixext 2 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 2, result) && + assign_and_return_true(subtype); + } + + case 0xD6: // fixext 4 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 4, result) && + assign_and_return_true(subtype); + } + + case 0xD7: // fixext 8 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 8, result) && + assign_and_return_true(subtype); + } + + case 0xD8: // fixext 16 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 16, result) && + assign_and_return_true(subtype); + } + + default: // LCOV_EXCL_LINE + return false; // LCOV_EXCL_LINE + } + } + + /*! + @param[in] len the length of the array + @return whether array creation completed + */ + bool get_msgpack_array(const std::size_t len) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) + { + return false; + } + + for (std::size_t i = 0; i < len; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) + { + return false; + } + } + + return sax->end_array(); + } + + /*! 
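+    A MessagePack map interleaves keys (always strings in this reader) and
+    values. Illustrative decoding through the public wrapper:
+
+        std::vector<std::uint8_t> bytes = {0x81, 0xA1, 'a', 0x01};  // fixmap(1), fixstr "a", 1
+        json j = json::from_msgpack(bytes);  // j == {{"a", 1}}
+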
+ @param[in] len the length of the object + @return whether object creation completed + */ + bool get_msgpack_object(const std::size_t len) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) + { + return false; + } + + string_t key; + for (std::size_t i = 0; i < len; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!get_msgpack_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) + { + return false; + } + key.clear(); + } + + return sax->end_object(); + } + + //////////// + // UBJSON // + //////////// + + /*! + @param[in] get_char whether a new character should be retrieved from the + input (true, default) or whether the last read + character should be considered instead + + @return whether a valid UBJSON value was passed to the SAX parser + */ + bool parse_ubjson_internal(const bool get_char = true) + { + return get_ubjson_value(get_char ? get_ignore_noop() : current); + } + + /*! + @brief reads a UBJSON string + + This function is either called after reading the 'S' byte explicitly + indicating a string, or in case of an object key where the 'S' byte can be + left out. + + @param[out] result created string + @param[in] get_char whether a new character should be retrieved from the + input (true, default) or whether the last read + character should be considered instead + + @return whether string creation completed + */ + bool get_ubjson_string(string_t& result, const bool get_char = true) + { + if (get_char) + { + get(); // TODO(niels): may we ignore N here? + } + + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) + { + return false; + } + + switch (current) + { + case 'U': + { + std::uint8_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'i': + { + std::int8_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'I': + { + std::int16_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'l': + { + std::int32_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'L': + { + std::int64_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + default: + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L); last byte: 0x" + last_token, "string"), BasicJsonType())); + } + } + + /*! 
+ @param[out] result determined size + @return whether size determination completed + */ + bool get_ubjson_size_value(std::size_t& result) + { + switch (get_ignore_noop()) + { + case 'U': + { + std::uint8_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'i': + { + std::int8_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); // NOLINT(bugprone-signed-char-misuse,cert-str34-c): number is not a char + return true; + } + + case 'I': + { + std::int16_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'l': + { + std::int32_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'L': + { + std::int64_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L) after '#'; last byte: 0x" + last_token, "size"), BasicJsonType())); + } + } + } + + /*! + @brief determine the type and size for a container + + In the optimized UBJSON format, a type and a size can be provided to allow + for a more compact representation. + + @param[out] result pair of the size and the type + + @return whether pair creation completed + */ + bool get_ubjson_size_type(std::pair& result) + { + result.first = string_t::npos; // size + result.second = 0; // type + + get_ignore_noop(); + + if (current == '$') + { + result.second = get(); // must not ignore 'N', because 'N' maybe the type + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "type"))) + { + return false; + } + + get_ignore_noop(); + if (JSON_HEDLEY_UNLIKELY(current != '#')) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) + { + return false; + } + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "expected '#' after type information; last byte: 0x" + last_token, "size"), BasicJsonType())); + } + + return get_ubjson_size_value(result.first); + } + + if (current == '#') + { + return get_ubjson_size_value(result.first); + } + + return true; + } + + /*! 
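+    UBJSON values are dispatched on a one-byte type prefix. In optimized
+    containers the prefix is announced once via '$', so this function may also
+    be called with a stored prefix instead of a freshly read byte. Illustrative
+    decoding through the public from_ubjson wrapper:
+
+        std::vector<std::uint8_t> bytes = {'[', '$', 'i', '#', 'U', 0x02, 0x01, 0x02};
+        json j = json::from_ubjson(bytes);  // j == [1, 2]
+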
+ @param prefix the previously read or set type prefix + @return whether value creation completed + */ + bool get_ubjson_value(const char_int_type prefix) + { + switch (prefix) + { + case std::char_traits::eof(): // EOF + return unexpect_eof(input_format_t::ubjson, "value"); + + case 'T': // true + return sax->boolean(true); + case 'F': // false + return sax->boolean(false); + + case 'Z': // null + return sax->null(); + + case 'U': + { + std::uint8_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_unsigned(number); + } + + case 'i': + { + std::int8_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'I': + { + std::int16_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'l': + { + std::int32_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'L': + { + std::int64_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'd': + { + float number{}; + return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); + } + + case 'D': + { + double number{}; + return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); + } + + case 'H': + { + return get_ubjson_high_precision_number(); + } + + case 'C': // char + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "char"))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(current > 127)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "byte after 'C' must be in range 0x00..0x7F; last byte: 0x" + last_token, "char"), BasicJsonType())); + } + string_t s(1, static_cast(current)); + return sax->string(s); + } + + case 'S': // string + { + string_t s; + return get_ubjson_string(s) && sax->string(s); + } + + case '[': // array + return get_ubjson_array(); + + case '{': // object + return get_ubjson_object(); + + default: // anything else + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "invalid byte: 0x" + last_token, "value"), BasicJsonType())); + } + } + } + + /*! + @return whether array creation completed + */ + bool get_ubjson_array() + { + std::pair size_and_type; + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) + { + return false; + } + + if (size_and_type.first != string_t::npos) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(size_and_type.first))) + { + return false; + } + + if (size_and_type.second != 0) + { + if (size_and_type.second != 'N') + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) + { + return false; + } + } + } + } + else + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + } + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + while (current != ']') + { + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal(false))) + { + return false; + } + get_ignore_noop(); + } + } + + return sax->end_array(); + } + + /*! 
+ @return whether object creation completed + */ + bool get_ubjson_object() + { + std::pair size_and_type; + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) + { + return false; + } + + string_t key; + if (size_and_type.first != string_t::npos) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(size_and_type.first))) + { + return false; + } + + if (size_and_type.second != 0) + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) + { + return false; + } + key.clear(); + } + } + else + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + key.clear(); + } + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + while (current != '}') + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key, false) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + get_ignore_noop(); + key.clear(); + } + } + + return sax->end_object(); + } + + // Note, no reader for UBJSON binary types is implemented because they do + // not exist + + bool get_ubjson_high_precision_number() + { + // get size of following number string + std::size_t size{}; + auto res = get_ubjson_size_value(size); + if (JSON_HEDLEY_UNLIKELY(!res)) + { + return res; + } + + // get number string + std::vector number_vector; + for (std::size_t i = 0; i < size; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "number"))) + { + return false; + } + number_vector.push_back(static_cast(current)); + } + + // parse number string + using ia_type = decltype(detail::input_adapter(number_vector)); + auto number_lexer = detail::lexer(detail::input_adapter(number_vector), false); + const auto result_number = number_lexer.scan(); + const auto number_string = number_lexer.get_token_string(); + const auto result_remainder = number_lexer.scan(); + + using token_type = typename detail::lexer_base::token_type; + + if (JSON_HEDLEY_UNLIKELY(result_remainder != token_type::end_of_input)) + { + return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"), BasicJsonType())); + } + + switch (result_number) + { + case token_type::value_integer: + return sax->number_integer(number_lexer.get_number_integer()); + case token_type::value_unsigned: + return sax->number_unsigned(number_lexer.get_number_unsigned()); + case token_type::value_float: + return sax->number_float(number_lexer.get_number_float(), std::move(number_string)); + case token_type::uninitialized: + case token_type::literal_true: + case token_type::literal_false: + case token_type::literal_null: + case token_type::value_string: + case token_type::begin_array: + case token_type::begin_object: + case token_type::end_array: + case token_type::end_object: + case token_type::name_separator: + case token_type::value_separator: + case token_type::parse_error: + case token_type::end_of_input: + case token_type::literal_or_value: + default: + return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, 
exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"), BasicJsonType())); + } + } + + /////////////////////// + // Utility functions // + /////////////////////// + + /*! + @brief get next character from the input + + This function provides the interface to the used input adapter. It does + not throw in case the input reached EOF, but returns a -'ve valued + `std::char_traits::eof()` in that case. + + @return character read from the input + */ + char_int_type get() + { + ++chars_read; + return current = ia.get_character(); + } + + /*! + @return character read from the input after ignoring all 'N' entries + */ + char_int_type get_ignore_noop() + { + do + { + get(); + } + while (current == 'N'); + + return current; + } + + /* + @brief read a number from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[out] result number of type @a NumberType + + @return whether conversion completed + + @note This function needs to respect the system's endianess, because + bytes in CBOR, MessagePack, and UBJSON are stored in network order + (big endian) and therefore need reordering on little endian systems. + */ + template + bool get_number(const input_format_t format, NumberType& result) + { + // step 1: read input into array with system's byte order + std::array vec{}; + for (std::size_t i = 0; i < sizeof(NumberType); ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "number"))) + { + return false; + } + + // reverse byte order prior to conversion if necessary + if (is_little_endian != InputIsLittleEndian) + { + vec[sizeof(NumberType) - i - 1] = static_cast(current); + } + else + { + vec[i] = static_cast(current); // LCOV_EXCL_LINE + } + } + + // step 2: convert array into number of type T and return + std::memcpy(&result, vec.data(), sizeof(NumberType)); + return true; + } + + /*! + @brief create a string by reading characters from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[in] len number of characters to read + @param[out] result string created by reading @a len bytes + + @return whether string creation completed + + @note We can not reserve @a len bytes for the result, because @a len + may be too large. Usually, @ref unexpect_eof() detects the end of + the input before we run out of string memory. + */ + template + bool get_string(const input_format_t format, + const NumberType len, + string_t& result) + { + bool success = true; + for (NumberType i = 0; i < len; i++) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "string"))) + { + success = false; + break; + } + result.push_back(static_cast(current)); + } + return success; + } + + /*! + @brief create a byte array by reading bytes from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[in] len number of bytes to read + @param[out] result byte array created by reading @a len bytes + + @return whether byte array creation completed + + @note We can not reserve @a len bytes for the result, because @a len + may be too large. Usually, @ref unexpect_eof() detects the end of + the input before we run out of memory. 
+ */ + template + bool get_binary(const input_format_t format, + const NumberType len, + binary_t& result) + { + bool success = true; + for (NumberType i = 0; i < len; i++) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "binary"))) + { + success = false; + break; + } + result.push_back(static_cast(current)); + } + return success; + } + + /*! + @param[in] format the current format (for diagnostics) + @param[in] context further context information (for diagnostics) + @return whether the last read character is not EOF + */ + JSON_HEDLEY_NON_NULL(3) + bool unexpect_eof(const input_format_t format, const char* context) const + { + if (JSON_HEDLEY_UNLIKELY(current == std::char_traits::eof())) + { + return sax->parse_error(chars_read, "", + parse_error::create(110, chars_read, exception_message(format, "unexpected end of input", context), BasicJsonType())); + } + return true; + } + + /*! + @return a string representation of the last read byte + */ + std::string get_token_string() const + { + std::array cr{{}}; + (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(current)); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + return std::string{cr.data()}; + } + + /*! + @param[in] format the current format + @param[in] detail a detailed error message + @param[in] context further context information + @return a message string to use in the parse_error exceptions + */ + std::string exception_message(const input_format_t format, + const std::string& detail, + const std::string& context) const + { + std::string error_msg = "syntax error while parsing "; + + switch (format) + { + case input_format_t::cbor: + error_msg += "CBOR"; + break; + + case input_format_t::msgpack: + error_msg += "MessagePack"; + break; + + case input_format_t::ubjson: + error_msg += "UBJSON"; + break; + + case input_format_t::bson: + error_msg += "BSON"; + break; + + case input_format_t::json: // LCOV_EXCL_LINE + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + + return error_msg + " " + context + ": " + detail; + } + + private: + /// input adapter + InputAdapterType ia; + + /// the current character + char_int_type current = std::char_traits::eof(); + + /// the number of characters read + std::size_t chars_read = 0; + + /// whether we can assume little endianess + const bool is_little_endian = little_endianess(); + + /// the SAX parser + json_sax_t* sax = nullptr; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include // isfinite +#include // uint8_t +#include // function +#include // string +#include // move +#include // vector + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +//////////// +// parser // +//////////// + +enum class parse_event_t : std::uint8_t +{ + /// the parser read `{` and started to process a JSON object + object_start, + /// the parser read `}` and finished processing a JSON object + object_end, + /// the parser read `[` and started to process a JSON array + array_start, + /// the parser read `]` and finished processing a JSON array + array_end, + /// the parser read a key of a value in an object + key, + /// the parser finished reading a JSON value + value +}; + +template +using parser_callback_t = + std::function; + +/*! +@brief syntax analysis + +This class implements a recursive descent parser. 
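+It is driven through the SAX interface; the DOM-building overloads of the
+public basic_json::parse() are thin wrappers around it. A minimal usage sketch
+of that public entry point (illustrative only):
+
+    json j = json::parse(R"({"pi": 3.141, "happy": true})");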
+*/ +template +class parser +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using lexer_t = lexer; + using token_type = typename lexer_t::token_type; + + public: + /// a parser reading from an input adapter + explicit parser(InputAdapterType&& adapter, + const parser_callback_t cb = nullptr, + const bool allow_exceptions_ = true, + const bool skip_comments = false) + : callback(cb) + , m_lexer(std::move(adapter), skip_comments) + , allow_exceptions(allow_exceptions_) + { + // read first token + get_token(); + } + + /*! + @brief public parser interface + + @param[in] strict whether to expect the last token to be EOF + @param[in,out] result parsed JSON value + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + void parse(const bool strict, BasicJsonType& result) + { + if (callback) + { + json_sax_dom_callback_parser sdp(result, callback, allow_exceptions); + sax_parse_internal(&sdp); + + // in strict mode, input must be completely read + if (strict && (get_token() != token_type::end_of_input)) + { + sdp.parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_of_input, "value"), BasicJsonType())); + } + + // in case of an error, return discarded value + if (sdp.is_errored()) + { + result = value_t::discarded; + return; + } + + // set top-level value to null if it was discarded by the callback + // function + if (result.is_discarded()) + { + result = nullptr; + } + } + else + { + json_sax_dom_parser sdp(result, allow_exceptions); + sax_parse_internal(&sdp); + + // in strict mode, input must be completely read + if (strict && (get_token() != token_type::end_of_input)) + { + sdp.parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_of_input, "value"), BasicJsonType())); + } + + // in case of an error, return discarded value + if (sdp.is_errored()) + { + result = value_t::discarded; + return; + } + } + + result.assert_invariant(); + } + + /*! 
+ @brief public accept interface + + @param[in] strict whether to expect the last token to be EOF + @return whether the input is a proper JSON text + */ + bool accept(const bool strict = true) + { + json_sax_acceptor sax_acceptor; + return sax_parse(&sax_acceptor, strict); + } + + template + JSON_HEDLEY_NON_NULL(2) + bool sax_parse(SAX* sax, const bool strict = true) + { + (void)detail::is_sax_static_asserts {}; + const bool result = sax_parse_internal(sax); + + // strict mode: next byte must be EOF + if (result && strict && (get_token() != token_type::end_of_input)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_of_input, "value"), BasicJsonType())); + } + + return result; + } + + private: + template + JSON_HEDLEY_NON_NULL(2) + bool sax_parse_internal(SAX* sax) + { + // stack to remember the hierarchy of structured values we are parsing + // true = array; false = object + std::vector states; + // value to avoid a goto (see comment where set to true) + bool skip_to_state_evaluation = false; + + while (true) + { + if (!skip_to_state_evaluation) + { + // invariant: get_token() was called before each iteration + switch (last_token) + { + case token_type::begin_object: + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + // closing } -> we are done + if (get_token() == token_type::end_object) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) + { + return false; + } + break; + } + + // parse key + if (JSON_HEDLEY_UNLIKELY(last_token != token_type::value_string)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::value_string, "object key"), BasicJsonType())); + } + if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) + { + return false; + } + + // parse separator (:) + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::name_separator, "object separator"), BasicJsonType())); + } + + // remember we are now inside an object + states.push_back(false); + + // parse values + get_token(); + continue; + } + + case token_type::begin_array: + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + // closing ] -> we are done + if (get_token() == token_type::end_array) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) + { + return false; + } + break; + } + + // remember we are now inside an array + states.push_back(true); + + // parse values (no need to call get_token) + continue; + } + + case token_type::value_float: + { + const auto res = m_lexer.get_number_float(); + + if (JSON_HEDLEY_UNLIKELY(!std::isfinite(res))) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + out_of_range::create(406, "number overflow parsing '" + m_lexer.get_token_string() + "'", BasicJsonType())); + } + + if (JSON_HEDLEY_UNLIKELY(!sax->number_float(res, m_lexer.get_string()))) + { + return false; + } + + break; + } + + case token_type::literal_false: + { + if (JSON_HEDLEY_UNLIKELY(!sax->boolean(false))) + { + return false; + } + break; + } + + case token_type::literal_null: + { + if (JSON_HEDLEY_UNLIKELY(!sax->null())) + { + return false; + } + break; + } + + case token_type::literal_true: + { 
+ if (JSON_HEDLEY_UNLIKELY(!sax->boolean(true))) + { + return false; + } + break; + } + + case token_type::value_integer: + { + if (JSON_HEDLEY_UNLIKELY(!sax->number_integer(m_lexer.get_number_integer()))) + { + return false; + } + break; + } + + case token_type::value_string: + { + if (JSON_HEDLEY_UNLIKELY(!sax->string(m_lexer.get_string()))) + { + return false; + } + break; + } + + case token_type::value_unsigned: + { + if (JSON_HEDLEY_UNLIKELY(!sax->number_unsigned(m_lexer.get_number_unsigned()))) + { + return false; + } + break; + } + + case token_type::parse_error: + { + // using "uninitialized" to avoid "expected" message + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::uninitialized, "value"), BasicJsonType())); + } + + case token_type::uninitialized: + case token_type::end_array: + case token_type::end_object: + case token_type::name_separator: + case token_type::value_separator: + case token_type::end_of_input: + case token_type::literal_or_value: + default: // the last token was unexpected + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::literal_or_value, "value"), BasicJsonType())); + } + } + } + else + { + skip_to_state_evaluation = false; + } + + // we reached this line after we successfully parsed a value + if (states.empty()) + { + // empty stack: we reached the end of the hierarchy: done + return true; + } + + if (states.back()) // array + { + // comma -> next value + if (get_token() == token_type::value_separator) + { + // parse a new value + get_token(); + continue; + } + + // closing ] + if (JSON_HEDLEY_LIKELY(last_token == token_type::end_array)) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) + { + return false; + } + + // We are done with this array. Before we can parse a + // new value, we need to evaluate the new state first. + // By setting skip_to_state_evaluation to false, we + // are effectively jumping to the beginning of this if. + JSON_ASSERT(!states.empty()); + states.pop_back(); + skip_to_state_evaluation = true; + continue; + } + + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_array, "array"), BasicJsonType())); + } + + // states.back() is false -> object + + // comma -> next value + if (get_token() == token_type::value_separator) + { + // parse key + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::value_string)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::value_string, "object key"), BasicJsonType())); + } + + if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) + { + return false; + } + + // parse separator (:) + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::name_separator, "object separator"), BasicJsonType())); + } + + // parse values + get_token(); + continue; + } + + // closing } + if (JSON_HEDLEY_LIKELY(last_token == token_type::end_object)) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) + { + return false; + } + + // We are done with this object. 
Before we can parse a + // new value, we need to evaluate the new state first. + // By setting skip_to_state_evaluation to false, we + // are effectively jumping to the beginning of this if. + JSON_ASSERT(!states.empty()); + states.pop_back(); + skip_to_state_evaluation = true; + continue; + } + + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_object, "object"), BasicJsonType())); + } + } + + /// get next token from lexer + token_type get_token() + { + return last_token = m_lexer.scan(); + } + + std::string exception_message(const token_type expected, const std::string& context) + { + std::string error_msg = "syntax error "; + + if (!context.empty()) + { + error_msg += "while parsing " + context + " "; + } + + error_msg += "- "; + + if (last_token == token_type::parse_error) + { + error_msg += std::string(m_lexer.get_error_message()) + "; last read: '" + + m_lexer.get_token_string() + "'"; + } + else + { + error_msg += "unexpected " + std::string(lexer_t::token_type_name(last_token)); + } + + if (expected != token_type::uninitialized) + { + error_msg += "; expected " + std::string(lexer_t::token_type_name(expected)); + } + + return error_msg; + } + + private: + /// callback function + const parser_callback_t callback = nullptr; + /// the type of the last read token + token_type last_token = token_type::uninitialized; + /// the lexer + lexer_t m_lexer; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; +}; + +} // namespace detail +} // namespace nlohmann + +// #include + + +// #include + + +#include // ptrdiff_t +#include // numeric_limits + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/* +@brief an iterator for primitive JSON types + +This class models an iterator for primitive JSON types (boolean, number, +string). It's only purpose is to allow the iterator/const_iterator classes +to "iterate" over primitive values. Internally, the iterator is modeled by +a `difference_type` variable. Value begin_value (`0`) models the begin, +end_value (`1`) models past the end. 
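+
+A minimal usage sketch (assuming the common `nlohmann::json` alias for
+`basic_json<>`): a primitive value behaves like a range with exactly one
+element, so the distance between `begin()` and `end()` is always 1.
+
+@code {.cpp}
+nlohmann::json j = 42;                          // primitive (number)
+auto dist = std::distance(j.begin(), j.end());  // dist == 1
+@endcode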
+*/ +class primitive_iterator_t +{ + private: + using difference_type = std::ptrdiff_t; + static constexpr difference_type begin_value = 0; + static constexpr difference_type end_value = begin_value + 1; + + JSON_PRIVATE_UNLESS_TESTED: + /// iterator as signed integer type + difference_type m_it = (std::numeric_limits::min)(); + + public: + constexpr difference_type get_value() const noexcept + { + return m_it; + } + + /// set iterator to a defined beginning + void set_begin() noexcept + { + m_it = begin_value; + } + + /// set iterator to a defined past the end + void set_end() noexcept + { + m_it = end_value; + } + + /// return whether the iterator can be dereferenced + constexpr bool is_begin() const noexcept + { + return m_it == begin_value; + } + + /// return whether the iterator is at end + constexpr bool is_end() const noexcept + { + return m_it == end_value; + } + + friend constexpr bool operator==(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it == rhs.m_it; + } + + friend constexpr bool operator<(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it < rhs.m_it; + } + + primitive_iterator_t operator+(difference_type n) noexcept + { + auto result = *this; + result += n; + return result; + } + + friend constexpr difference_type operator-(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it - rhs.m_it; + } + + primitive_iterator_t& operator++() noexcept + { + ++m_it; + return *this; + } + + primitive_iterator_t const operator++(int) noexcept // NOLINT(readability-const-return-type) + { + auto result = *this; + ++m_it; + return result; + } + + primitive_iterator_t& operator--() noexcept + { + --m_it; + return *this; + } + + primitive_iterator_t const operator--(int) noexcept // NOLINT(readability-const-return-type) + { + auto result = *this; + --m_it; + return result; + } + + primitive_iterator_t& operator+=(difference_type n) noexcept + { + m_it += n; + return *this; + } + + primitive_iterator_t& operator-=(difference_type n) noexcept + { + m_it -= n; + return *this; + } +}; +} // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ +namespace detail +{ +/*! +@brief an iterator value + +@note This structure could easily be a union, but MSVC currently does not allow +unions members with complex constructors, see https://github.com/nlohmann/json/pull/105. +*/ +template struct internal_iterator +{ + /// iterator for JSON objects + typename BasicJsonType::object_t::iterator object_iterator {}; + /// iterator for JSON arrays + typename BasicJsonType::array_t::iterator array_iterator {}; + /// generic iterator for all other types + primitive_iterator_t primitive_iterator {}; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // iterator, random_access_iterator_tag, bidirectional_iterator_tag, advance, next +#include // conditional, is_const, remove_const + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +// forward declare, to be able to friend it later on +template class iteration_proxy; +template class iteration_proxy_value; + +/*! +@brief a template for a bidirectional iterator for the @ref basic_json class +This class implements a both iterators (iterator and const_iterator) for the +@ref basic_json class. +@note An iterator is called *initialized* when a pointer to a JSON value has + been set (e.g., by a constructor or a copy assignment). 
If the iterator is + default-constructed, it is *uninitialized* and most methods are undefined. + **The library uses assertions to detect calls on uninitialized iterators.** +@requirement The class satisfies the following concept requirements: +- +[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): + The iterator that can be moved can be moved in both directions (i.e. + incremented and decremented). +@since version 1.0.0, simplified in version 2.0.9, change to bidirectional + iterators in version 3.0.0 (see https://github.com/nlohmann/json/issues/593) +*/ +template +class iter_impl // NOLINT(cppcoreguidelines-special-member-functions,hicpp-special-member-functions) +{ + /// the iterator with BasicJsonType of different const-ness + using other_iter_impl = iter_impl::value, typename std::remove_const::type, const BasicJsonType>::type>; + /// allow basic_json to access private members + friend other_iter_impl; + friend BasicJsonType; + friend iteration_proxy; + friend iteration_proxy_value; + + using object_t = typename BasicJsonType::object_t; + using array_t = typename BasicJsonType::array_t; + // make sure BasicJsonType is basic_json or const basic_json + static_assert(is_basic_json::type>::value, + "iter_impl only accepts (const) basic_json"); + + public: + + /// The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17. + /// The C++ Standard has never required user-defined iterators to derive from std::iterator. + /// A user-defined iterator should provide publicly accessible typedefs named + /// iterator_category, value_type, difference_type, pointer, and reference. + /// Note that value_type is required to be non-const, even for constant iterators. + using iterator_category = std::bidirectional_iterator_tag; + + /// the type of the values when the iterator is dereferenced + using value_type = typename BasicJsonType::value_type; + /// a type to represent differences between iterators + using difference_type = typename BasicJsonType::difference_type; + /// defines a pointer to the type iterated over (value_type) + using pointer = typename std::conditional::value, + typename BasicJsonType::const_pointer, + typename BasicJsonType::pointer>::type; + /// defines a reference to the type iterated over (value_type) + using reference = + typename std::conditional::value, + typename BasicJsonType::const_reference, + typename BasicJsonType::reference>::type; + + iter_impl() = default; + ~iter_impl() = default; + iter_impl(iter_impl&&) noexcept = default; + iter_impl& operator=(iter_impl&&) noexcept = default; + + /*! + @brief constructor for a given JSON instance + @param[in] object pointer to a JSON object for this iterator + @pre object != nullptr + @post The iterator is initialized; i.e. `m_object != nullptr`. + */ + explicit iter_impl(pointer object) noexcept : m_object(object) + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = typename object_t::iterator(); + break; + } + + case value_t::array: + { + m_it.array_iterator = typename array_t::iterator(); + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + m_it.primitive_iterator = primitive_iterator_t(); + break; + } + } + } + + /*! 
+ @note The conventional copy constructor and copy assignment are implicitly + defined. Combined with the following converting constructor and + assignment, they support: (1) copy from iterator to iterator, (2) + copy from const iterator to const iterator, and (3) conversion from + iterator to const iterator. However conversion from const iterator + to iterator is not defined. + */ + + /*! + @brief const copy constructor + @param[in] other const iterator to copy from + @note This copy constructor had to be defined explicitly to circumvent a bug + occurring on msvc v19.0 compiler (VS 2015) debug build. For more + information refer to: https://github.com/nlohmann/json/issues/1608 + */ + iter_impl(const iter_impl& other) noexcept + : m_object(other.m_object), m_it(other.m_it) + {} + + /*! + @brief converting assignment + @param[in] other const iterator to copy from + @return const/non-const iterator + @note It is not checked whether @a other is initialized. + */ + iter_impl& operator=(const iter_impl& other) noexcept + { + if (&other != this) + { + m_object = other.m_object; + m_it = other.m_it; + } + return *this; + } + + /*! + @brief converting constructor + @param[in] other non-const iterator to copy from + @note It is not checked whether @a other is initialized. + */ + iter_impl(const iter_impl::type>& other) noexcept + : m_object(other.m_object), m_it(other.m_it) + {} + + /*! + @brief converting assignment + @param[in] other non-const iterator to copy from + @return const/non-const iterator + @note It is not checked whether @a other is initialized. + */ + iter_impl& operator=(const iter_impl::type>& other) noexcept // NOLINT(cert-oop54-cpp) + { + m_object = other.m_object; + m_it = other.m_it; + return *this; + } + + JSON_PRIVATE_UNLESS_TESTED: + /*! + @brief set the iterator to the first value + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + void set_begin() noexcept + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->begin(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->begin(); + break; + } + + case value_t::null: + { + // set to end so begin()==end() is true: null is empty + m_it.primitive_iterator.set_end(); + break; + } + + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + m_it.primitive_iterator.set_begin(); + break; + } + } + } + + /*! + @brief set the iterator past the last value + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + void set_end() noexcept + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->end(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->end(); + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + m_it.primitive_iterator.set_end(); + break; + } + } + } + + public: + /*! + @brief return a reference to the value pointed to by the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
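+
+    For illustration, a minimal sketch (assuming the common `nlohmann::json`
+    alias): for arrays and objects the reference is the pointed-to element,
+    for primitive values it is the value itself. Dereferencing an iterator of
+    a `null` value or a past-the-end primitive iterator throws
+    `invalid_iterator.214`.
+
+    @code {.cpp}
+    nlohmann::json arr = {1, 2, 3};
+    std::cout << *arr.begin();      // prints 1
+
+    nlohmann::json num = 3.14;
+    std::cout << *num.begin();      // prints 3.14 (the value itself)
+    @endcode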
+ */ + reference operator*() const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); + return m_it.object_iterator->second; + } + + case value_t::array: + { + JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); + return *m_it.array_iterator; + } + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value", *m_object)); + + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value", *m_object)); + } + } + } + + /*! + @brief dereference the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + pointer operator->() const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); + return &(m_it.object_iterator->second); + } + + case value_t::array: + { + JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); + return &*m_it.array_iterator; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) + { + return m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value", *m_object)); + } + } + } + + /*! + @brief post-increment (it++) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl const operator++(int) // NOLINT(readability-const-return-type) + { + auto result = *this; + ++(*this); + return result; + } + + /*! + @brief pre-increment (++it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator++() + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, 1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, 1); + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + ++m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief post-decrement (it--) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl const operator--(int) // NOLINT(readability-const-return-type) + { + auto result = *this; + --(*this); + return result; + } + + /*! + @brief pre-decrement (--it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
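+
+    For illustration, a minimal sketch: decrementing `end()` yields an
+    iterator to the last element of an array or object.
+
+    @code {.cpp}
+    nlohmann::json arr = {1, 2, 3};
+    auto it = arr.end();
+    --it;                           // now refers to 3
+    @endcode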
+ */ + iter_impl& operator--() + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, -1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, -1); + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + --m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief comparison: equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + template < typename IterImpl, detail::enable_if_t < (std::is_same::value || std::is_same::value), std::nullptr_t > = nullptr > + bool operator==(const IterImpl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers", *m_object)); + } + + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + return (m_it.object_iterator == other.m_it.object_iterator); + + case value_t::array: + return (m_it.array_iterator == other.m_it.array_iterator); + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + return (m_it.primitive_iterator == other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: not equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + template < typename IterImpl, detail::enable_if_t < (std::is_same::value || std::is_same::value), std::nullptr_t > = nullptr > + bool operator!=(const IterImpl& other) const + { + return !operator==(other); + } + + /*! + @brief comparison: smaller + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<(const iter_impl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers", *m_object)); + } + + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(213, "cannot compare order of object iterators", *m_object)); + + case value_t::array: + return (m_it.array_iterator < other.m_it.array_iterator); + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + return (m_it.primitive_iterator < other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: less than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<=(const iter_impl& other) const + { + return !other.operator < (*this); + } + + /*! + @brief comparison: greater than + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator>(const iter_impl& other) const + { + return !operator<=(other); + } + + /*! + @brief comparison: greater than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
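+
+    @note As with the other comparison operators above, comparing iterators
+          that belong to different JSON values throws `invalid_iterator.212`,
+          and ordering comparisons on object iterators throw
+          `invalid_iterator.213`.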
+ */ + bool operator>=(const iter_impl& other) const + { + return !operator<(other); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator+=(difference_type i) + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators", *m_object)); + + case value_t::array: + { + std::advance(m_it.array_iterator, i); + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + m_it.primitive_iterator += i; + break; + } + } + + return *this; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator-=(difference_type i) + { + return operator+=(-i); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator+(difference_type i) const + { + auto result = *this; + result += i; + return result; + } + + /*! + @brief addition of distance and iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + friend iter_impl operator+(difference_type i, const iter_impl& it) + { + auto result = it; + result += i; + return result; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator-(difference_type i) const + { + auto result = *this; + result -= i; + return result; + } + + /*! + @brief return difference + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + difference_type operator-(const iter_impl& other) const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators", *m_object)); + + case value_t::array: + return m_it.array_iterator - other.m_it.array_iterator; + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + return m_it.primitive_iterator - other.m_it.primitive_iterator; + } + } + + /*! + @brief access to successor + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference operator[](difference_type n) const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(208, "cannot use operator[] for object iterators", *m_object)); + + case value_t::array: + return *std::next(m_it.array_iterator, n); + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value", *m_object)); + + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.get_value() == -n)) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value", *m_object)); + } + } + } + + /*! + @brief return the key of an object iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. 
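+
+    For illustration, a minimal sketch: `key()` is only valid for object
+    iterators; for all other types it throws `invalid_iterator.207`.
+
+    @code {.cpp}
+    nlohmann::json obj = {{"pi", 3.141}};
+    auto it = obj.begin();
+    std::cout << it.key();          // pi
+    std::cout << it.value();        // 3.141
+    @endcode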
+ */ + const typename object_t::key_type& key() const + { + JSON_ASSERT(m_object != nullptr); + + if (JSON_HEDLEY_LIKELY(m_object->is_object())) + { + return m_it.object_iterator->first; + } + + JSON_THROW(invalid_iterator::create(207, "cannot use key() for non-object iterators", *m_object)); + } + + /*! + @brief return the value of an iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference value() const + { + return operator*(); + } + + JSON_PRIVATE_UNLESS_TESTED: + /// associated JSON instance + pointer m_object = nullptr; + /// the actual iterator of the associated instance + internal_iterator::type> m_it {}; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // ptrdiff_t +#include // reverse_iterator +#include // declval + +namespace nlohmann +{ +namespace detail +{ +////////////////////// +// reverse_iterator // +////////////////////// + +/*! +@brief a template for a reverse iterator class + +@tparam Base the base iterator type to reverse. Valid types are @ref +iterator (to create @ref reverse_iterator) and @ref const_iterator (to +create @ref const_reverse_iterator). + +@requirement The class satisfies the following concept requirements: +- +[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): + The iterator that can be moved can be moved in both directions (i.e. + incremented and decremented). +- [OutputIterator](https://en.cppreference.com/w/cpp/named_req/OutputIterator): + It is possible to write to the pointed-to element (only if @a Base is + @ref iterator). + +@since version 1.0.0 +*/ +template +class json_reverse_iterator : public std::reverse_iterator +{ + public: + using difference_type = std::ptrdiff_t; + /// shortcut to the reverse iterator adapter + using base_iterator = std::reverse_iterator; + /// the reference type for the pointed-to element + using reference = typename Base::reference; + + /// create reverse iterator from iterator + explicit json_reverse_iterator(const typename base_iterator::iterator_type& it) noexcept + : base_iterator(it) {} + + /// create reverse iterator from base class + explicit json_reverse_iterator(const base_iterator& it) noexcept : base_iterator(it) {} + + /// post-increment (it++) + json_reverse_iterator const operator++(int) // NOLINT(readability-const-return-type) + { + return static_cast(base_iterator::operator++(1)); + } + + /// pre-increment (++it) + json_reverse_iterator& operator++() + { + return static_cast(base_iterator::operator++()); + } + + /// post-decrement (it--) + json_reverse_iterator const operator--(int) // NOLINT(readability-const-return-type) + { + return static_cast(base_iterator::operator--(1)); + } + + /// pre-decrement (--it) + json_reverse_iterator& operator--() + { + return static_cast(base_iterator::operator--()); + } + + /// add to iterator + json_reverse_iterator& operator+=(difference_type i) + { + return static_cast(base_iterator::operator+=(i)); + } + + /// add to iterator + json_reverse_iterator operator+(difference_type i) const + { + return static_cast(base_iterator::operator+(i)); + } + + /// subtract from iterator + json_reverse_iterator operator-(difference_type i) const + { + return static_cast(base_iterator::operator-(i)); + } + + /// return difference + difference_type operator-(const json_reverse_iterator& other) const + { + return base_iterator(*this) - base_iterator(other); + } + + /// access to successor + reference operator[](difference_type n) const + { + return 
*(this->operator+(n)); + } + + /// return the key of an object iterator + auto key() const -> decltype(std::declval().key()) + { + auto it = --this->base(); + return it.key(); + } + + /// return the value of an iterator + reference value() const + { + auto it = --this->base(); + return it.operator * (); + } +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // all_of +#include // isdigit +#include // max +#include // accumulate +#include // string +#include // move +#include // vector + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +template +class json_pointer +{ + // allow basic_json to access private members + NLOHMANN_BASIC_JSON_TPL_DECLARATION + friend class basic_json; + + public: + /*! + @brief create JSON pointer + + Create a JSON pointer according to the syntax described in + [Section 3 of RFC6901](https://tools.ietf.org/html/rfc6901#section-3). + + @param[in] s string representing the JSON pointer; if omitted, the empty + string is assumed which references the whole JSON value + + @throw parse_error.107 if the given JSON pointer @a s is nonempty and does + not begin with a slash (`/`); see example below + + @throw parse_error.108 if a tilde (`~`) in the given JSON pointer @a s is + not followed by `0` (representing `~`) or `1` (representing `/`); see + example below + + @liveexample{The example shows the construction several valid JSON pointers + as well as the exceptional behavior.,json_pointer} + + @since version 2.0.0 + */ + explicit json_pointer(const std::string& s = "") + : reference_tokens(split(s)) + {} + + /*! + @brief return a string representation of the JSON pointer + + @invariant For each JSON pointer `ptr`, it holds: + @code {.cpp} + ptr == json_pointer(ptr.to_string()); + @endcode + + @return a string representation of the JSON pointer + + @liveexample{The example shows the result of `to_string`.,json_pointer__to_string} + + @since version 2.0.0 + */ + std::string to_string() const + { + return std::accumulate(reference_tokens.begin(), reference_tokens.end(), + std::string{}, + [](const std::string & a, const std::string & b) + { + return a + "/" + detail::escape(b); + }); + } + + /// @copydoc to_string() + operator std::string() const + { + return to_string(); + } + + /*! + @brief append another JSON pointer at the end of this JSON pointer + + @param[in] ptr JSON pointer to append + @return JSON pointer with @a ptr appended + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Linear in the length of @a ptr. + + @sa see @ref operator/=(std::string) to append a reference token + @sa see @ref operator/=(std::size_t) to append an array index + @sa see @ref operator/(const json_pointer&, const json_pointer&) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(const json_pointer& ptr) + { + reference_tokens.insert(reference_tokens.end(), + ptr.reference_tokens.begin(), + ptr.reference_tokens.end()); + return *this; + } + + /*! + @brief append an unescaped reference token at the end of this JSON pointer + + @param[in] token reference token to append + @return JSON pointer with @a token appended without escaping @a token + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Amortized constant. 
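+
+    For illustration, a minimal sketch: the token is stored unescaped and is
+    only escaped when the pointer is converted back to a string.
+
+    @code {.cpp}
+    nlohmann::json::json_pointer ptr;   // ""
+    ptr /= "foo";                       // "/foo"
+    ptr /= "a/b";                       // to_string() yields "/foo/a~1b"
+    @endcode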
+ + @sa see @ref operator/=(const json_pointer&) to append a JSON pointer + @sa see @ref operator/=(std::size_t) to append an array index + @sa see @ref operator/(const json_pointer&, std::size_t) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(std::string token) + { + push_back(std::move(token)); + return *this; + } + + /*! + @brief append an array index at the end of this JSON pointer + + @param[in] array_idx array index to append + @return JSON pointer with @a array_idx appended + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Amortized constant. + + @sa see @ref operator/=(const json_pointer&) to append a JSON pointer + @sa see @ref operator/=(std::string) to append a reference token + @sa see @ref operator/(const json_pointer&, std::string) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(std::size_t array_idx) + { + return *this /= std::to_string(array_idx); + } + + /*! + @brief create a new JSON pointer by appending the right JSON pointer at the end of the left JSON pointer + + @param[in] lhs JSON pointer + @param[in] rhs JSON pointer + @return a new JSON pointer with @a rhs appended to @a lhs + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a lhs and @a rhs. + + @sa see @ref operator/=(const json_pointer&) to append a JSON pointer + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& lhs, + const json_pointer& rhs) + { + return json_pointer(lhs) /= rhs; + } + + /*! + @brief create a new JSON pointer by appending the unescaped token at the end of the JSON pointer + + @param[in] ptr JSON pointer + @param[in] token reference token + @return a new JSON pointer with unescaped @a token appended to @a ptr + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a ptr. + + @sa see @ref operator/=(std::string) to append a reference token + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& ptr, std::string token) // NOLINT(performance-unnecessary-value-param) + { + return json_pointer(ptr) /= std::move(token); + } + + /*! + @brief create a new JSON pointer by appending the array-index-token at the end of the JSON pointer + + @param[in] ptr JSON pointer + @param[in] array_idx array index + @return a new JSON pointer with @a array_idx appended to @a ptr + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a ptr. + + @sa see @ref operator/=(std::size_t) to append an array index + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& ptr, std::size_t array_idx) + { + return json_pointer(ptr) /= array_idx; + } + + /*! + @brief returns the parent of this JSON pointer + + @return parent of this JSON pointer; in case this JSON pointer is the root, + the root itself is returned + + @complexity Linear in the length of the JSON pointer. + + @liveexample{The example shows the result of `parent_pointer` for different + JSON Pointers.,json_pointer__parent_pointer} + + @since version 3.6.0 + */ + json_pointer parent_pointer() const + { + if (empty()) + { + return *this; + } + + json_pointer res = *this; + res.pop_back(); + return res; + } + + /*! 
+ @brief remove last reference token + + @pre not `empty()` + + @liveexample{The example shows the usage of `pop_back`.,json_pointer__pop_back} + + @complexity Constant. + + @throw out_of_range.405 if JSON pointer has no parent + + @since version 3.6.0 + */ + void pop_back() + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent", BasicJsonType())); + } + + reference_tokens.pop_back(); + } + + /*! + @brief return last reference token + + @pre not `empty()` + @return last reference token + + @liveexample{The example shows the usage of `back`.,json_pointer__back} + + @complexity Constant. + + @throw out_of_range.405 if JSON pointer has no parent + + @since version 3.6.0 + */ + const std::string& back() const + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent", BasicJsonType())); + } + + return reference_tokens.back(); + } + + /*! + @brief append an unescaped token at the end of the reference pointer + + @param[in] token token to add + + @complexity Amortized constant. + + @liveexample{The example shows the result of `push_back` for different + JSON Pointers.,json_pointer__push_back} + + @since version 3.6.0 + */ + void push_back(const std::string& token) + { + reference_tokens.push_back(token); + } + + /// @copydoc push_back(const std::string&) + void push_back(std::string&& token) + { + reference_tokens.push_back(std::move(token)); + } + + /*! + @brief return whether pointer points to the root document + + @return true iff the JSON pointer points to the root document + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example shows the result of `empty` for different JSON + Pointers.,json_pointer__empty} + + @since version 3.6.0 + */ + bool empty() const noexcept + { + return reference_tokens.empty(); + } + + private: + /*! + @param[in] s reference token to be converted into an array index + + @return integer representation of @a s + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index begins not with a digit + @throw out_of_range.404 if string @a s could not be converted to an integer + @throw out_of_range.410 if an array index exceeds size_type + */ + static typename BasicJsonType::size_type array_index(const std::string& s) + { + using size_type = typename BasicJsonType::size_type; + + // error condition (cf. RFC 6901, Sect. 4) + if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && s[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, "array index '" + s + "' must not begin with '0'", BasicJsonType())); + } + + // error condition (cf. RFC 6901, Sect. 
4) + if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && !(s[0] >= '1' && s[0] <= '9'))) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + s + "' is not a number", BasicJsonType())); + } + + std::size_t processed_chars = 0; + unsigned long long res = 0; // NOLINT(runtime/int) + JSON_TRY + { + res = std::stoull(s, &processed_chars); + } + JSON_CATCH(std::out_of_range&) + { + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'", BasicJsonType())); + } + + // check if the string was completely read + if (JSON_HEDLEY_UNLIKELY(processed_chars != s.size())) + { + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'", BasicJsonType())); + } + + // only triggered on special platforms (like 32bit), see also + // https://github.com/nlohmann/json/pull/2203 + if (res >= static_cast((std::numeric_limits::max)())) // NOLINT(runtime/int) + { + JSON_THROW(detail::out_of_range::create(410, "array index " + s + " exceeds size_type", BasicJsonType())); // LCOV_EXCL_LINE + } + + return static_cast(res); + } + + JSON_PRIVATE_UNLESS_TESTED: + json_pointer top() const + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent", BasicJsonType())); + } + + json_pointer result = *this; + result.reference_tokens = {reference_tokens[0]}; + return result; + } + + private: + /*! + @brief create and return a reference to the pointed to value + + @complexity Linear in the number of reference tokens. + + @throw parse_error.109 if array index is not a number + @throw type_error.313 if value cannot be unflattened + */ + BasicJsonType& get_and_create(BasicJsonType& j) const + { + auto* result = &j; + + // in case no reference tokens exist, return a reference to the JSON value + // j which will be overwritten by a primitive value + for (const auto& reference_token : reference_tokens) + { + switch (result->type()) + { + case detail::value_t::null: + { + if (reference_token == "0") + { + // start a new array if reference token is 0 + result = &result->operator[](0); + } + else + { + // start a new object otherwise + result = &result->operator[](reference_token); + } + break; + } + + case detail::value_t::object: + { + // create an entry in the object + result = &result->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + // create an entry in the array + result = &result->operator[](array_index(reference_token)); + break; + } + + /* + The following code is only reached if there exists a reference + token _and_ the current value is primitive. In this case, we have + an error situation, because primitive values may only occur as + single value; that is, with an empty list of reference tokens. + */ + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + JSON_THROW(detail::type_error::create(313, "invalid value to unflatten", j)); + } + } + + return *result; + } + + /*! + @brief return a reference to the pointed to value + + @note This version does not throw if a value is not present, but tries to + create nested values instead. For instance, calling this function + with pointer `"/this/that"` on a null value is equivalent to calling + `operator[]("this").operator[]("that")` on that value, effectively + changing the null value to an object. 
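+
+    For illustration, a sketch of the behaviour described above, using the
+    `json_pointer` overload of `operator[]`, which resolves through this
+    function:
+
+    @code {.cpp}
+    nlohmann::json j;                                   // null
+    j[nlohmann::json::json_pointer("/this/that")] = 7;  // {"this": {"that": 7}}
+    @endcode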
+ + @param[in] ptr a JSON value + + @return reference to the JSON value pointed to by the JSON pointer + + @complexity Linear in the length of the JSON pointer. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + BasicJsonType& get_unchecked(BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + // convert null values to arrays or objects before continuing + if (ptr->is_null()) + { + // check if reference token is a number + const bool nums = + std::all_of(reference_token.begin(), reference_token.end(), + [](const unsigned char x) + { + return std::isdigit(x); + }); + + // change value to array for numbers or "-" or to object otherwise + *ptr = (nums || reference_token == "-") + ? detail::value_t::array + : detail::value_t::object; + } + + switch (ptr->type()) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (reference_token == "-") + { + // explicitly treat "-" as index beyond the end + ptr = &ptr->operator[](ptr->m_value.array->size()); + } + else + { + // convert array index to number; unchecked access + ptr = &ptr->operator[](array_index(reference_token)); + } + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'", *ptr)); + } + } + + return *ptr; + } + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + BasicJsonType& get_checked(BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range", *ptr)); + } + + // note: at performs range check + ptr = &ptr->at(array_index(reference_token)); + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'", *ptr)); + } + } + + return *ptr; + } + + /*! 
+ @brief return a const reference to the pointed to value + + @param[in] ptr a JSON value + + @return const reference to the JSON value pointed to by the JSON + pointer + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + const BasicJsonType& get_unchecked(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" cannot be used for const access + JSON_THROW(detail::out_of_range::create(402, "array index '-' (" + std::to_string(ptr->m_value.array->size()) + ") is out of range", *ptr)); + } + + // use unchecked array access + ptr = &ptr->operator[](array_index(reference_token)); + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'", *ptr)); + } + } + + return *ptr; + } + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + const BasicJsonType& get_checked(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range", *ptr)); + } + + // note: at performs range check + ptr = &ptr->at(array_index(reference_token)); + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'", *ptr)); + } + } + + return *ptr; + } + + /*! 
+ @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + */ + bool contains(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + if (!ptr->contains(reference_token)) + { + // we did not find the key in the object + return false; + } + + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + return false; + } + if (JSON_HEDLEY_UNLIKELY(reference_token.size() == 1 && !("0" <= reference_token && reference_token <= "9"))) + { + // invalid char + return false; + } + if (JSON_HEDLEY_UNLIKELY(reference_token.size() > 1)) + { + if (JSON_HEDLEY_UNLIKELY(!('1' <= reference_token[0] && reference_token[0] <= '9'))) + { + // first char should be between '1' and '9' + return false; + } + for (std::size_t i = 1; i < reference_token.size(); i++) + { + if (JSON_HEDLEY_UNLIKELY(!('0' <= reference_token[i] && reference_token[i] <= '9'))) + { + // other char should be between '0' and '9' + return false; + } + } + } + + const auto idx = array_index(reference_token); + if (idx >= ptr->size()) + { + // index out of range + return false; + } + + ptr = &ptr->operator[](idx); + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + { + // we do not expect primitive values if there is still a + // reference token to process + return false; + } + } + } + + // no reference token left means we found a primitive value + return true; + } + + /*! + @brief split the string input to reference tokens + + @note This function is only called by the json_pointer constructor. + All exceptions below are documented there. + + @throw parse_error.107 if the pointer is not empty or begins with '/' + @throw parse_error.108 if character '~' is not followed by '0' or '1' + */ + static std::vector split(const std::string& reference_string) + { + std::vector result; + + // special case: empty reference string -> no reference tokens + if (reference_string.empty()) + { + return result; + } + + // check if nonempty reference string begins with slash + if (JSON_HEDLEY_UNLIKELY(reference_string[0] != '/')) + { + JSON_THROW(detail::parse_error::create(107, 1, "JSON pointer must be empty or begin with '/' - was: '" + reference_string + "'", BasicJsonType())); + } + + // extract the reference tokens: + // - slash: position of the last read slash (or end of string) + // - start: position after the previous slash + for ( + // search for the first slash after the first character + std::size_t slash = reference_string.find_first_of('/', 1), + // set the beginning of the first reference token + start = 1; + // we can stop if start == 0 (if slash == std::string::npos) + start != 0; + // set the beginning of the next reference token + // (will eventually be 0 if slash == std::string::npos) + start = (slash == std::string::npos) ? 0 : slash + 1, + // find next slash + slash = reference_string.find_first_of('/', start)) + { + // use the text between the beginning of the reference token + // (start) and the last slash (slash). 
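+            // e.g. for "/foo/a~1b" the raw tokens are "foo" and "a~1b"; the
+            // escaped "~1" is turned back into "/" by unescape() below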
+ auto reference_token = reference_string.substr(start, slash - start); + + // check reference tokens are properly escaped + for (std::size_t pos = reference_token.find_first_of('~'); + pos != std::string::npos; + pos = reference_token.find_first_of('~', pos + 1)) + { + JSON_ASSERT(reference_token[pos] == '~'); + + // ~ must be followed by 0 or 1 + if (JSON_HEDLEY_UNLIKELY(pos == reference_token.size() - 1 || + (reference_token[pos + 1] != '0' && + reference_token[pos + 1] != '1'))) + { + JSON_THROW(detail::parse_error::create(108, 0, "escape character '~' must be followed with '0' or '1'", BasicJsonType())); + } + } + + // finally, store the reference token + detail::unescape(reference_token); + result.push_back(reference_token); + } + + return result; + } + + private: + /*! + @param[in] reference_string the reference string to the current value + @param[in] value the value to consider + @param[in,out] result the result object to insert values to + + @note Empty objects or arrays are flattened to `null`. + */ + static void flatten(const std::string& reference_string, + const BasicJsonType& value, + BasicJsonType& result) + { + switch (value.type()) + { + case detail::value_t::array: + { + if (value.m_value.array->empty()) + { + // flatten empty array as null + result[reference_string] = nullptr; + } + else + { + // iterate array and use index as reference string + for (std::size_t i = 0; i < value.m_value.array->size(); ++i) + { + flatten(reference_string + "/" + std::to_string(i), + value.m_value.array->operator[](i), result); + } + } + break; + } + + case detail::value_t::object: + { + if (value.m_value.object->empty()) + { + // flatten empty object as null + result[reference_string] = nullptr; + } + else + { + // iterate object and use keys as reference string + for (const auto& element : *value.m_value.object) + { + flatten(reference_string + "/" + detail::escape(element.first), element.second, result); + } + } + break; + } + + case detail::value_t::null: + case detail::value_t::string: + case detail::value_t::boolean: + case detail::value_t::number_integer: + case detail::value_t::number_unsigned: + case detail::value_t::number_float: + case detail::value_t::binary: + case detail::value_t::discarded: + default: + { + // add primitive value with its reference string + result[reference_string] = value; + break; + } + } + } + + /*! + @param[in] value flattened JSON + + @return unflattened JSON + + @throw parse_error.109 if array index is not a number + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + @throw type_error.313 if value cannot be unflattened + */ + static BasicJsonType + unflatten(const BasicJsonType& value) + { + if (JSON_HEDLEY_UNLIKELY(!value.is_object())) + { + JSON_THROW(detail::type_error::create(314, "only objects can be unflattened", value)); + } + + BasicJsonType result; + + // iterate the JSON object values + for (const auto& element : *value.m_value.object) + { + if (JSON_HEDLEY_UNLIKELY(!element.second.is_primitive())) + { + JSON_THROW(detail::type_error::create(315, "values in object must be primitive", element.second)); + } + + // assign value to reference pointed to by JSON pointer; Note that if + // the JSON pointer is "" (i.e., points to the whole value), function + // get_and_create returns a reference to result itself. An assignment + // will then create a primitive value. + json_pointer(element.first).get_and_create(result) = element.second; + } + + return result; + } + + /*! 
+ @brief compares two JSON pointers for equality + + @param[in] lhs JSON pointer to compare + @param[in] rhs JSON pointer to compare + @return whether @a lhs is equal to @a rhs + + @complexity Linear in the length of the JSON pointer + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + */ + friend bool operator==(json_pointer const& lhs, + json_pointer const& rhs) noexcept + { + return lhs.reference_tokens == rhs.reference_tokens; + } + + /*! + @brief compares two JSON pointers for inequality + + @param[in] lhs JSON pointer to compare + @param[in] rhs JSON pointer to compare + @return whether @a lhs is not equal @a rhs + + @complexity Linear in the length of the JSON pointer + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + */ + friend bool operator!=(json_pointer const& lhs, + json_pointer const& rhs) noexcept + { + return !(lhs == rhs); + } + + /// the reference tokens + std::vector reference_tokens; +}; +} // namespace nlohmann + +// #include + + +#include +#include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +class json_ref +{ + public: + using value_type = BasicJsonType; + + json_ref(value_type&& value) + : owned_value(std::move(value)) + {} + + json_ref(const value_type& value) + : value_ref(&value) + {} + + json_ref(std::initializer_list init) + : owned_value(init) + {} + + template < + class... Args, + enable_if_t::value, int> = 0 > + json_ref(Args && ... args) + : owned_value(std::forward(args)...) + {} + + // class should be movable only + json_ref(json_ref&&) noexcept = default; + json_ref(const json_ref&) = delete; + json_ref& operator=(const json_ref&) = delete; + json_ref& operator=(json_ref&&) = delete; + ~json_ref() = default; + + value_type moved_or_copied() const + { + if (value_ref == nullptr) + { + return std::move(owned_value); + } + return *value_ref; + } + + value_type const& operator*() const + { + return value_ref ? 
*value_ref : owned_value; + } + + value_type const* operator->() const + { + return &** this; + } + + private: + mutable value_type owned_value = nullptr; + value_type const* value_ref = nullptr; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + +// #include + + +#include // reverse +#include // array +#include // isnan, isinf +#include // uint8_t, uint16_t, uint32_t, uint64_t +#include // memcpy +#include // numeric_limits +#include // string +#include // move + +// #include + +// #include + +// #include + + +#include // copy +#include // size_t +#include // back_inserter +#include // shared_ptr, make_shared +#include // basic_string +#include // vector + +#ifndef JSON_NO_IO + #include // streamsize + #include // basic_ostream +#endif // JSON_NO_IO + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/// abstract output adapter interface +template struct output_adapter_protocol +{ + virtual void write_character(CharType c) = 0; + virtual void write_characters(const CharType* s, std::size_t length) = 0; + virtual ~output_adapter_protocol() = default; + + output_adapter_protocol() = default; + output_adapter_protocol(const output_adapter_protocol&) = default; + output_adapter_protocol(output_adapter_protocol&&) noexcept = default; + output_adapter_protocol& operator=(const output_adapter_protocol&) = default; + output_adapter_protocol& operator=(output_adapter_protocol&&) noexcept = default; +}; + +/// a type to simplify interfaces +template +using output_adapter_t = std::shared_ptr>; + +/// output adapter for byte vectors +template> +class output_vector_adapter : public output_adapter_protocol +{ + public: + explicit output_vector_adapter(std::vector& vec) noexcept + : v(vec) + {} + + void write_character(CharType c) override + { + v.push_back(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + std::copy(s, s + length, std::back_inserter(v)); + } + + private: + std::vector& v; +}; + +#ifndef JSON_NO_IO +/// output adapter for output streams +template +class output_stream_adapter : public output_adapter_protocol +{ + public: + explicit output_stream_adapter(std::basic_ostream& s) noexcept + : stream(s) + {} + + void write_character(CharType c) override + { + stream.put(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + stream.write(s, static_cast(length)); + } + + private: + std::basic_ostream& stream; +}; +#endif // JSON_NO_IO + +/// output adapter for basic_string +template> +class output_string_adapter : public output_adapter_protocol +{ + public: + explicit output_string_adapter(StringType& s) noexcept + : str(s) + {} + + void write_character(CharType c) override + { + str.push_back(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + str.append(s, length); + } + + private: + StringType& str; +}; + +template> +class output_adapter +{ + public: + template> + output_adapter(std::vector& vec) + : oa(std::make_shared>(vec)) {} + +#ifndef JSON_NO_IO + output_adapter(std::basic_ostream& s) + : oa(std::make_shared>(s)) {} +#endif // JSON_NO_IO + + output_adapter(StringType& s) + : oa(std::make_shared>(s)) {} + + operator output_adapter_t() + { + return oa; + } + + private: + output_adapter_t oa = nullptr; +}; +} // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ +namespace detail +{ +/////////////////// +// binary 
writer // +/////////////////// + +/*! +@brief serialization to CBOR and MessagePack values +*/ +template +class binary_writer +{ + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using number_float_t = typename BasicJsonType::number_float_t; + + public: + /*! + @brief create a binary writer + + @param[in] adapter output adapter to write to + */ + explicit binary_writer(output_adapter_t adapter) : oa(std::move(adapter)) + { + JSON_ASSERT(oa); + } + + /*! + @param[in] j JSON value to serialize + @pre j.type() == value_t::object + */ + void write_bson(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::object: + { + write_bson_object(*j.m_value.object); + break; + } + + case value_t::null: + case value_t::array: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + JSON_THROW(type_error::create(317, "to serialize to BSON, top-level type must be object, but is " + std::string(j.type_name()), j)); + } + } + } + + /*! + @param[in] j JSON value to serialize + */ + void write_cbor(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: + { + oa->write_character(to_char_type(0xF6)); + break; + } + + case value_t::boolean: + { + oa->write_character(j.m_value.boolean + ? to_char_type(0xF5) + : to_char_type(0xF4)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // CBOR does not differentiate between positive signed + // integers and unsigned integers. Therefore, we used the + // code from the value_t::number_unsigned case here. + if (j.m_value.number_integer <= 0x17) + { + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x18)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x19)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x1A)); + write_number(static_cast(j.m_value.number_integer)); + } + else + { + oa->write_character(to_char_type(0x1B)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + // The conversions below encode the sign in the first + // byte, and the value is converted to a positive number. 
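+                    // Worked example: for j == -500 the positive_number computed
+                    // below is 499 (0x01F3); it does not fit into one byte but fits
+                    // into a std::uint16_t, so the bytes 0x39 0x01 0xF3 are emitted.
+                    // For j == -10, positive_number is 9 and the single byte
+                    // 0x29 (0x20 + 9) suffices.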
+ const auto positive_number = -1 - j.m_value.number_integer; + if (j.m_value.number_integer >= -24) + { + write_number(static_cast(0x20 + positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x38)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x39)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x3A)); + write_number(static_cast(positive_number)); + } + else + { + oa->write_character(to_char_type(0x3B)); + write_number(static_cast(positive_number)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned <= 0x17) + { + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x18)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x19)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x1A)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else + { + oa->write_character(to_char_type(0x1B)); + write_number(static_cast(j.m_value.number_unsigned)); + } + break; + } + + case value_t::number_float: + { + if (std::isnan(j.m_value.number_float)) + { + // NaN is 0xf97e00 in CBOR + oa->write_character(to_char_type(0xF9)); + oa->write_character(to_char_type(0x7E)); + oa->write_character(to_char_type(0x00)); + } + else if (std::isinf(j.m_value.number_float)) + { + // Infinity is 0xf97c00, -Infinity is 0xf9fc00 + oa->write_character(to_char_type(0xf9)); + oa->write_character(j.m_value.number_float > 0 ? 
to_char_type(0x7C) : to_char_type(0xFC)); + oa->write_character(to_char_type(0x00)); + } + else + { + write_compact_float(j.m_value.number_float, detail::input_format_t::cbor); + } + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 0x17) + { + write_number(static_cast(0x60 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x78)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x79)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x7A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x7B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 0x17) + { + write_number(static_cast(0x80 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x98)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x99)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x9A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x9B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_cbor(el); + } + break; + } + + case value_t::binary: + { + if (j.m_value.binary->has_subtype()) + { + if (j.m_value.binary->subtype() <= (std::numeric_limits::max)()) + { + write_number(static_cast(0xd8)); + write_number(static_cast(j.m_value.binary->subtype())); + } + else if (j.m_value.binary->subtype() <= (std::numeric_limits::max)()) + { + write_number(static_cast(0xd9)); + write_number(static_cast(j.m_value.binary->subtype())); + } + else if (j.m_value.binary->subtype() <= (std::numeric_limits::max)()) + { + write_number(static_cast(0xda)); + write_number(static_cast(j.m_value.binary->subtype())); + } + else if (j.m_value.binary->subtype() <= (std::numeric_limits::max)()) + { + write_number(static_cast(0xdb)); + write_number(static_cast(j.m_value.binary->subtype())); + } + } + + // step 1: write control byte and the binary array size + const auto N = j.m_value.binary->size(); + if (N <= 0x17) + { + write_number(static_cast(0x40 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x58)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x59)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x5A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x5B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + N); + + break; 
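+                // Worked example: a 3-byte binary value with subtype 42 is written
+                // as the tag bytes 0xD8 0x2A followed by the byte-string header
+                // 0x43 (0x40 + 3) and the three payload bytes; without a subtype,
+                // only the header 0x43 and the payload are emitted.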
+ } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 0x17) + { + write_number(static_cast(0xA0 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xB8)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xB9)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xBA)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xBB)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_cbor(el.first); + write_cbor(el.second); + } + break; + } + + case value_t::discarded: + default: + break; + } + } + + /*! + @param[in] j JSON value to serialize + */ + void write_msgpack(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: // nil + { + oa->write_character(to_char_type(0xC0)); + break; + } + + case value_t::boolean: // true and false + { + oa->write_character(j.m_value.boolean + ? to_char_type(0xC3) + : to_char_type(0xC2)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // MessagePack does not differentiate between positive + // signed integers and unsigned integers. Therefore, we used + // the code from the value_t::number_unsigned case here. + if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(to_char_type(0xCC)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(to_char_type(0xCD)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(to_char_type(0xCE)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(to_char_type(0xCF)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + if (j.m_value.number_integer >= -32) + { + // negative fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 8 + oa->write_character(to_char_type(0xD0)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 16 + oa->write_character(to_char_type(0xD1)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 32 + oa->write_character(to_char_type(0xD2)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 64 + oa->write_character(to_char_type(0xD3)); + 
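+                        // the 0xD3 marker is followed by the value as an
+                        // 8-byte big-endian (network order) two's-complement payload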
write_number(static_cast(j.m_value.number_integer)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(to_char_type(0xCC)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(to_char_type(0xCD)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(to_char_type(0xCE)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(to_char_type(0xCF)); + write_number(static_cast(j.m_value.number_integer)); + } + break; + } + + case value_t::number_float: + { + write_compact_float(j.m_value.number_float, detail::input_format_t::msgpack); + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 31) + { + // fixstr + write_number(static_cast(0xA0 | N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 8 + oa->write_character(to_char_type(0xD9)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 16 + oa->write_character(to_char_type(0xDA)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 32 + oa->write_character(to_char_type(0xDB)); + write_number(static_cast(N)); + } + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 15) + { + // fixarray + write_number(static_cast(0x90 | N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // array 16 + oa->write_character(to_char_type(0xDC)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // array 32 + oa->write_character(to_char_type(0xDD)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_msgpack(el); + } + break; + } + + case value_t::binary: + { + // step 0: determine if the binary type has a set subtype to + // determine whether or not to use the ext or fixext types + const bool use_ext = j.m_value.binary->has_subtype(); + + // step 1: write control byte and the byte string length + const auto N = j.m_value.binary->size(); + if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type{}; + bool fixed = true; + if (use_ext) + { + switch (N) + { + case 1: + output_type = 0xD4; // fixext 1 + break; + case 2: + output_type = 0xD5; // fixext 2 + break; + case 4: + output_type = 0xD6; // fixext 4 + break; + case 8: + output_type = 0xD7; // fixext 8 + break; + case 16: + output_type = 0xD8; // fixext 16 + break; + default: + output_type = 0xC7; // ext 8 + fixed = false; + break; + } + + } + else + { + output_type = 0xC4; // bin 8 + fixed = false; + } + + oa->write_character(to_char_type(output_type)); + if (!fixed) + { + write_number(static_cast(N)); + } + } + else if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type = use_ext 
+ ? 0xC8 // ext 16 + : 0xC5; // bin 16 + + oa->write_character(to_char_type(output_type)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type = use_ext + ? 0xC9 // ext 32 + : 0xC6; // bin 32 + + oa->write_character(to_char_type(output_type)); + write_number(static_cast(N)); + } + + // step 1.5: if this is an ext type, write the subtype + if (use_ext) + { + write_number(static_cast(j.m_value.binary->subtype())); + } + + // step 2: write the byte string + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + N); + + break; + } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 15) + { + // fixmap + write_number(static_cast(0x80 | (N & 0xF))); + } + else if (N <= (std::numeric_limits::max)()) + { + // map 16 + oa->write_character(to_char_type(0xDE)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // map 32 + oa->write_character(to_char_type(0xDF)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_msgpack(el.first); + write_msgpack(el.second); + } + break; + } + + case value_t::discarded: + default: + break; + } + } + + /*! + @param[in] j JSON value to serialize + @param[in] use_count whether to use '#' prefixes (optimized format) + @param[in] use_type whether to use '$' prefixes (optimized format) + @param[in] add_prefix whether prefixes need to be used for this value + */ + void write_ubjson(const BasicJsonType& j, const bool use_count, + const bool use_type, const bool add_prefix = true) + { + switch (j.type()) + { + case value_t::null: + { + if (add_prefix) + { + oa->write_character(to_char_type('Z')); + } + break; + } + + case value_t::boolean: + { + if (add_prefix) + { + oa->write_character(j.m_value.boolean + ? 
to_char_type('T') + : to_char_type('F')); + } + break; + } + + case value_t::number_integer: + { + write_number_with_ubjson_prefix(j.m_value.number_integer, add_prefix); + break; + } + + case value_t::number_unsigned: + { + write_number_with_ubjson_prefix(j.m_value.number_unsigned, add_prefix); + break; + } + + case value_t::number_float: + { + write_number_with_ubjson_prefix(j.m_value.number_float, add_prefix); + break; + } + + case value_t::string: + { + if (add_prefix) + { + oa->write_character(to_char_type('S')); + } + write_number_with_ubjson_prefix(j.m_value.string->size(), true); + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + if (add_prefix) + { + oa->write_character(to_char_type('[')); + } + + bool prefix_required = true; + if (use_type && !j.m_value.array->empty()) + { + JSON_ASSERT(use_count); + const CharType first_prefix = ubjson_prefix(j.front()); + const bool same_prefix = std::all_of(j.begin() + 1, j.end(), + [this, first_prefix](const BasicJsonType & v) + { + return ubjson_prefix(v) == first_prefix; + }); + + if (same_prefix) + { + prefix_required = false; + oa->write_character(to_char_type('$')); + oa->write_character(first_prefix); + } + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.array->size(), true); + } + + for (const auto& el : *j.m_value.array) + { + write_ubjson(el, use_count, use_type, prefix_required); + } + + if (!use_count) + { + oa->write_character(to_char_type(']')); + } + + break; + } + + case value_t::binary: + { + if (add_prefix) + { + oa->write_character(to_char_type('[')); + } + + if (use_type && !j.m_value.binary->empty()) + { + JSON_ASSERT(use_count); + oa->write_character(to_char_type('$')); + oa->write_character('U'); + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.binary->size(), true); + } + + if (use_type) + { + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + j.m_value.binary->size()); + } + else + { + for (size_t i = 0; i < j.m_value.binary->size(); ++i) + { + oa->write_character(to_char_type('U')); + oa->write_character(j.m_value.binary->data()[i]); + } + } + + if (!use_count) + { + oa->write_character(to_char_type(']')); + } + + break; + } + + case value_t::object: + { + if (add_prefix) + { + oa->write_character(to_char_type('{')); + } + + bool prefix_required = true; + if (use_type && !j.m_value.object->empty()) + { + JSON_ASSERT(use_count); + const CharType first_prefix = ubjson_prefix(j.front()); + const bool same_prefix = std::all_of(j.begin(), j.end(), + [this, first_prefix](const BasicJsonType & v) + { + return ubjson_prefix(v) == first_prefix; + }); + + if (same_prefix) + { + prefix_required = false; + oa->write_character(to_char_type('$')); + oa->write_character(first_prefix); + } + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.object->size(), true); + } + + for (const auto& el : *j.m_value.object) + { + write_number_with_ubjson_prefix(el.first.size(), true); + oa->write_characters( + reinterpret_cast(el.first.c_str()), + el.first.size()); + write_ubjson(el.second, use_count, use_type, prefix_required); + } + + if (!use_count) + { + oa->write_character(to_char_type('}')); + } + + break; + } + + case value_t::discarded: + default: + break; + } + } + + private: + ////////// + // BSON // + ////////// + + /*! 
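+    @brief Calculates the size of a BSON element header, consisting of the
+           element type byte and the null-terminated key @a name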
+ @return The size of a BSON document entry header, including the id marker + and the entry name size (and its null-terminator). + */ + static std::size_t calc_bson_entry_header_size(const string_t& name, const BasicJsonType& j) + { + const auto it = name.find(static_cast(0)); + if (JSON_HEDLEY_UNLIKELY(it != BasicJsonType::string_t::npos)) + { + JSON_THROW(out_of_range::create(409, "BSON key cannot contain code point U+0000 (at byte " + std::to_string(it) + ")", j)); + static_cast(j); + } + + return /*id*/ 1ul + name.size() + /*zero-terminator*/1u; + } + + /*! + @brief Writes the given @a element_type and @a name to the output adapter + */ + void write_bson_entry_header(const string_t& name, + const std::uint8_t element_type) + { + oa->write_character(to_char_type(element_type)); // boolean + oa->write_characters( + reinterpret_cast(name.c_str()), + name.size() + 1u); + } + + /*! + @brief Writes a BSON element with key @a name and boolean value @a value + */ + void write_bson_boolean(const string_t& name, + const bool value) + { + write_bson_entry_header(name, 0x08); + oa->write_character(value ? to_char_type(0x01) : to_char_type(0x00)); + } + + /*! + @brief Writes a BSON element with key @a name and double value @a value + */ + void write_bson_double(const string_t& name, + const double value) + { + write_bson_entry_header(name, 0x01); + write_number(value); + } + + /*! + @return The size of the BSON-encoded string in @a value + */ + static std::size_t calc_bson_string_size(const string_t& value) + { + return sizeof(std::int32_t) + value.size() + 1ul; + } + + /*! + @brief Writes a BSON element with key @a name and string value @a value + */ + void write_bson_string(const string_t& name, + const string_t& value) + { + write_bson_entry_header(name, 0x02); + + write_number(static_cast(value.size() + 1ul)); + oa->write_characters( + reinterpret_cast(value.c_str()), + value.size() + 1); + } + + /*! + @brief Writes a BSON element with key @a name and null value + */ + void write_bson_null(const string_t& name) + { + write_bson_entry_header(name, 0x0A); + } + + /*! + @return The size of the BSON-encoded integer @a value + */ + static std::size_t calc_bson_integer_size(const std::int64_t value) + { + return (std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)() + ? sizeof(std::int32_t) + : sizeof(std::int64_t); + } + + /*! + @brief Writes a BSON element with key @a name and integer @a value + */ + void write_bson_integer(const string_t& name, + const std::int64_t value) + { + if ((std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)()) + { + write_bson_entry_header(name, 0x10); // int32 + write_number(static_cast(value)); + } + else + { + write_bson_entry_header(name, 0x12); // int64 + write_number(static_cast(value)); + } + } + + /*! + @return The size of the BSON-encoded unsigned integer in @a j + */ + static constexpr std::size_t calc_bson_unsigned_size(const std::uint64_t value) noexcept + { + return (value <= static_cast((std::numeric_limits::max)())) + ? sizeof(std::int32_t) + : sizeof(std::int64_t); + } + + /*! 
+ @brief Writes a BSON element with key @a name and unsigned @a value + */ + void write_bson_unsigned(const string_t& name, + const BasicJsonType& j) + { + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + write_bson_entry_header(name, 0x10 /* int32 */); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + write_bson_entry_header(name, 0x12 /* int64 */); + write_number(static_cast(j.m_value.number_unsigned)); + } + else + { + JSON_THROW(out_of_range::create(407, "integer number " + std::to_string(j.m_value.number_unsigned) + " cannot be represented by BSON as it does not fit int64", j)); + } + } + + /*! + @brief Writes a BSON element with key @a name and object @a value + */ + void write_bson_object_entry(const string_t& name, + const typename BasicJsonType::object_t& value) + { + write_bson_entry_header(name, 0x03); // object + write_bson_object(value); + } + + /*! + @return The size of the BSON-encoded array @a value + */ + static std::size_t calc_bson_array_size(const typename BasicJsonType::array_t& value) + { + std::size_t array_index = 0ul; + + const std::size_t embedded_document_size = std::accumulate(std::begin(value), std::end(value), std::size_t(0), [&array_index](std::size_t result, const typename BasicJsonType::array_t::value_type & el) + { + return result + calc_bson_element_size(std::to_string(array_index++), el); + }); + + return sizeof(std::int32_t) + embedded_document_size + 1ul; + } + + /*! + @return The size of the BSON-encoded binary array @a value + */ + static std::size_t calc_bson_binary_size(const typename BasicJsonType::binary_t& value) + { + return sizeof(std::int32_t) + value.size() + 1ul; + } + + /*! + @brief Writes a BSON element with key @a name and array @a value + */ + void write_bson_array(const string_t& name, + const typename BasicJsonType::array_t& value) + { + write_bson_entry_header(name, 0x04); // array + write_number(static_cast(calc_bson_array_size(value))); + + std::size_t array_index = 0ul; + + for (const auto& el : value) + { + write_bson_element(std::to_string(array_index++), el); + } + + oa->write_character(to_char_type(0x00)); + } + + /*! + @brief Writes a BSON element with key @a name and binary value @a value + */ + void write_bson_binary(const string_t& name, + const binary_t& value) + { + write_bson_entry_header(name, 0x05); + + write_number(static_cast(value.size())); + write_number(value.has_subtype() ? static_cast(value.subtype()) : std::uint8_t(0x00)); + + oa->write_characters(reinterpret_cast(value.data()), value.size()); + } + + /*! + @brief Calculates the size necessary to serialize the JSON value @a j with its @a name + @return The calculated size for the BSON document entry for @a j with the given @a name. 
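+    @note Worked example: the entry "a" -> true contributes
+          1 (element type) + 2 (key "a" incl. null-terminator) + 1 (boolean payload)
+          = 4 bytes to the enclosing document size.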
+ */ + static std::size_t calc_bson_element_size(const string_t& name, + const BasicJsonType& j) + { + const auto header_size = calc_bson_entry_header_size(name, j); + switch (j.type()) + { + case value_t::object: + return header_size + calc_bson_object_size(*j.m_value.object); + + case value_t::array: + return header_size + calc_bson_array_size(*j.m_value.array); + + case value_t::binary: + return header_size + calc_bson_binary_size(*j.m_value.binary); + + case value_t::boolean: + return header_size + 1ul; + + case value_t::number_float: + return header_size + 8ul; + + case value_t::number_integer: + return header_size + calc_bson_integer_size(j.m_value.number_integer); + + case value_t::number_unsigned: + return header_size + calc_bson_unsigned_size(j.m_value.number_unsigned); + + case value_t::string: + return header_size + calc_bson_string_size(*j.m_value.string); + + case value_t::null: + return header_size + 0ul; + + // LCOV_EXCL_START + case value_t::discarded: + default: + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) + return 0ul; + // LCOV_EXCL_STOP + } + } + + /*! + @brief Serializes the JSON value @a j to BSON and associates it with the + key @a name. + @param name The name to associate with the JSON entity @a j within the + current BSON document + */ + void write_bson_element(const string_t& name, + const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::object: + return write_bson_object_entry(name, *j.m_value.object); + + case value_t::array: + return write_bson_array(name, *j.m_value.array); + + case value_t::binary: + return write_bson_binary(name, *j.m_value.binary); + + case value_t::boolean: + return write_bson_boolean(name, j.m_value.boolean); + + case value_t::number_float: + return write_bson_double(name, j.m_value.number_float); + + case value_t::number_integer: + return write_bson_integer(name, j.m_value.number_integer); + + case value_t::number_unsigned: + return write_bson_unsigned(name, j); + + case value_t::string: + return write_bson_string(name, *j.m_value.string); + + case value_t::null: + return write_bson_null(name); + + // LCOV_EXCL_START + case value_t::discarded: + default: + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) + return; + // LCOV_EXCL_STOP + } + } + + /*! + @brief Calculates the size of the BSON serialization of the given + JSON-object @a j. + @param[in] value JSON value to serialize + @pre value.type() == value_t::object + */ + static std::size_t calc_bson_object_size(const typename BasicJsonType::object_t& value) + { + std::size_t document_size = std::accumulate(value.begin(), value.end(), std::size_t(0), + [](size_t result, const typename BasicJsonType::object_t::value_type & el) + { + return result += calc_bson_element_size(el.first, el.second); + }); + + return sizeof(std::int32_t) + document_size + 1ul; + } + + /*! 
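+    @brief Writes the given JSON object as a BSON document: an int32 size prefix
+           (BSON is little-endian), the serialized elements, and a trailing 0x00 byte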
+ @param[in] value JSON value to serialize + @pre value.type() == value_t::object + */ + void write_bson_object(const typename BasicJsonType::object_t& value) + { + write_number(static_cast(calc_bson_object_size(value))); + + for (const auto& el : value) + { + write_bson_element(el.first, el.second); + } + + oa->write_character(to_char_type(0x00)); + } + + ////////// + // CBOR // + ////////// + + static constexpr CharType get_cbor_float_prefix(float /*unused*/) + { + return to_char_type(0xFA); // Single-Precision Float + } + + static constexpr CharType get_cbor_float_prefix(double /*unused*/) + { + return to_char_type(0xFB); // Double-Precision Float + } + + ///////////// + // MsgPack // + ///////////// + + static constexpr CharType get_msgpack_float_prefix(float /*unused*/) + { + return to_char_type(0xCA); // float 32 + } + + static constexpr CharType get_msgpack_float_prefix(double /*unused*/) + { + return to_char_type(0xCB); // float 64 + } + + //////////// + // UBJSON // + //////////// + + // UBJSON: write number (floating point) + template::value, int>::type = 0> + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if (add_prefix) + { + oa->write_character(get_ubjson_float_prefix(n)); + } + write_number(n); + } + + // UBJSON: write number (unsigned integer) + template::value, int>::type = 0> + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('i')); // int8 + } + write_number(static_cast(n)); + } + else if (n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('U')); // uint8 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('I')); // int16 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('l')); // int32 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('L')); // int64 + } + write_number(static_cast(n)); + } + else + { + if (add_prefix) + { + oa->write_character(to_char_type('H')); // high-precision number + } + + const auto number = BasicJsonType(n).dump(); + write_number_with_ubjson_prefix(number.size(), true); + for (std::size_t i = 0; i < number.size(); ++i) + { + oa->write_character(to_char_type(static_cast(number[i]))); + } + } + } + + // UBJSON: write number (signed integer) + template < typename NumberType, typename std::enable_if < + std::is_signed::value&& + !std::is_floating_point::value, int >::type = 0 > + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('i')); // int8 + } + write_number(static_cast(n)); + } + else if (static_cast((std::numeric_limits::min)()) <= n && n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('U')); // uint8 + } + write_number(static_cast(n)); + } + else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('I')); // int16 + } + write_number(static_cast(n)); + } + else if 
((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('l')); // int32 + } + write_number(static_cast(n)); + } + else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('L')); // int64 + } + write_number(static_cast(n)); + } + // LCOV_EXCL_START + else + { + if (add_prefix) + { + oa->write_character(to_char_type('H')); // high-precision number + } + + const auto number = BasicJsonType(n).dump(); + write_number_with_ubjson_prefix(number.size(), true); + for (std::size_t i = 0; i < number.size(); ++i) + { + oa->write_character(to_char_type(static_cast(number[i]))); + } + } + // LCOV_EXCL_STOP + } + + /*! + @brief determine the type prefix of container values + */ + CharType ubjson_prefix(const BasicJsonType& j) const noexcept + { + switch (j.type()) + { + case value_t::null: + return 'Z'; + + case value_t::boolean: + return j.m_value.boolean ? 'T' : 'F'; + + case value_t::number_integer: + { + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'i'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'U'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'I'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'l'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'L'; + } + // anything else is treated as high-precision number + return 'H'; // LCOV_EXCL_LINE + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'i'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'U'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'I'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'l'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'L'; + } + // anything else is treated as high-precision number + return 'H'; // LCOV_EXCL_LINE + } + + case value_t::number_float: + return get_ubjson_float_prefix(j.m_value.number_float); + + case value_t::string: + return 'S'; + + case value_t::array: // fallthrough + case value_t::binary: + return '['; + + case value_t::object: + return '{'; + + case value_t::discarded: + default: // discarded values + return 'N'; + } + } + + static constexpr CharType get_ubjson_float_prefix(float /*unused*/) + { + return 'd'; // float 32 + } + + static constexpr CharType get_ubjson_float_prefix(double /*unused*/) + { + return 'D'; // float 64 + } + + /////////////////////// + // Utility functions // + /////////////////////// + + /* + @brief write a number to output input + @param[in] n number of type @a NumberType + @tparam NumberType the type of the number + @tparam OutputIsLittleEndian Set to true if output data is + required to be little endian + + @note This function needs to respect the system's endianess, because bytes + in CBOR, MessagePack, and UBJSON are stored in network order (big + endian) and therefore need 
reordering on little endian systems. + */ + template + void write_number(const NumberType n) + { + // step 1: write number to array of length NumberType + std::array vec{}; + std::memcpy(vec.data(), &n, sizeof(NumberType)); + + // step 2: write array to output (with possible reordering) + if (is_little_endian != OutputIsLittleEndian) + { + // reverse byte order prior to conversion if necessary + std::reverse(vec.begin(), vec.end()); + } + + oa->write_characters(vec.data(), sizeof(NumberType)); + } + + void write_compact_float(const number_float_t n, detail::input_format_t format) + { +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + if (static_cast(n) >= static_cast(std::numeric_limits::lowest()) && + static_cast(n) <= static_cast((std::numeric_limits::max)()) && + static_cast(static_cast(n)) == static_cast(n)) + { + oa->write_character(format == detail::input_format_t::cbor + ? get_cbor_float_prefix(static_cast(n)) + : get_msgpack_float_prefix(static_cast(n))); + write_number(static_cast(n)); + } + else + { + oa->write_character(format == detail::input_format_t::cbor + ? get_cbor_float_prefix(n) + : get_msgpack_float_prefix(n)); + write_number(n); + } +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + } + + public: + // The following to_char_type functions are implement the conversion + // between uint8_t and CharType. In case CharType is not unsigned, + // such a conversion is required to allow values greater than 128. + // See for a discussion. + template < typename C = CharType, + enable_if_t < std::is_signed::value && std::is_signed::value > * = nullptr > + static constexpr CharType to_char_type(std::uint8_t x) noexcept + { + return *reinterpret_cast(&x); + } + + template < typename C = CharType, + enable_if_t < std::is_signed::value && std::is_unsigned::value > * = nullptr > + static CharType to_char_type(std::uint8_t x) noexcept + { + static_assert(sizeof(std::uint8_t) == sizeof(CharType), "size of CharType must be equal to std::uint8_t"); + static_assert(std::is_trivial::value, "CharType must be trivial"); + CharType result; + std::memcpy(&result, &x, sizeof(x)); + return result; + } + + template::value>* = nullptr> + static constexpr CharType to_char_type(std::uint8_t x) noexcept + { + return x; + } + + template < typename InputCharType, typename C = CharType, + enable_if_t < + std::is_signed::value && + std::is_signed::value && + std::is_same::type>::value + > * = nullptr > + static constexpr CharType to_char_type(InputCharType x) noexcept + { + return x; + } + + private: + /// whether we can assume little endianess + const bool is_little_endian = little_endianess(); + + /// the output + output_adapter_t oa = nullptr; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // reverse, remove, fill, find, none_of +#include // array +#include // localeconv, lconv +#include // labs, isfinite, isnan, signbit +#include // size_t, ptrdiff_t +#include // uint8_t +#include // snprintf +#include // numeric_limits +#include // string, char_traits +#include // setfill, setw +#include // stringstream +#include // is_same +#include // move + +// #include + + +#include // array +#include // signbit, isfinite +#include // intN_t, uintN_t +#include // memcpy, memmove +#include // numeric_limits +#include // conditional + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +/*! +@brief implements the Grisu2 algorithm for binary to decimal floating-point +conversion. 
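+(The algorithm produces a short decimal digit string and a decimal exponent
+using only 64-bit integer arithmetic on diyfp values.)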
+ +This implementation is a slightly modified version of the reference +implementation which may be obtained from +http://florian.loitsch.com/publications (bench.tar.gz). + +The code is distributed under the MIT license, Copyright (c) 2009 Florian Loitsch. + +For a detailed description of the algorithm see: + +[1] Loitsch, "Printing Floating-Point Numbers Quickly and Accurately with + Integers", Proceedings of the ACM SIGPLAN 2010 Conference on Programming + Language Design and Implementation, PLDI 2010 +[2] Burger, Dybvig, "Printing Floating-Point Numbers Quickly and Accurately", + Proceedings of the ACM SIGPLAN 1996 Conference on Programming Language + Design and Implementation, PLDI 1996 +*/ +namespace dtoa_impl +{ + +template +Target reinterpret_bits(const Source source) +{ + static_assert(sizeof(Target) == sizeof(Source), "size mismatch"); + + Target target; + std::memcpy(&target, &source, sizeof(Source)); + return target; +} + +struct diyfp // f * 2^e +{ + static constexpr int kPrecision = 64; // = q + + std::uint64_t f = 0; + int e = 0; + + constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {} + + /*! + @brief returns x - y + @pre x.e == y.e and x.f >= y.f + */ + static diyfp sub(const diyfp& x, const diyfp& y) noexcept + { + JSON_ASSERT(x.e == y.e); + JSON_ASSERT(x.f >= y.f); + + return {x.f - y.f, x.e}; + } + + /*! + @brief returns x * y + @note The result is rounded. (Only the upper q bits are returned.) + */ + static diyfp mul(const diyfp& x, const diyfp& y) noexcept + { + static_assert(kPrecision == 64, "internal error"); + + // Computes: + // f = round((x.f * y.f) / 2^q) + // e = x.e + y.e + q + + // Emulate the 64-bit * 64-bit multiplication: + // + // p = u * v + // = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi) + // = (u_lo v_lo ) + 2^32 ((u_lo v_hi ) + (u_hi v_lo )) + 2^64 (u_hi v_hi ) + // = (p0 ) + 2^32 ((p1 ) + (p2 )) + 2^64 (p3 ) + // = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3 ) + // = (p0_lo ) + 2^32 (p0_hi + p1_lo + p2_lo ) + 2^64 (p1_hi + p2_hi + p3) + // = (p0_lo ) + 2^32 (Q ) + 2^64 (H ) + // = (p0_lo ) + 2^32 (Q_lo + 2^32 Q_hi ) + 2^64 (H ) + // + // (Since Q might be larger than 2^32 - 1) + // + // = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H) + // + // (Q_hi + H does not overflow a 64-bit int) + // + // = p_lo + 2^64 p_hi + + const std::uint64_t u_lo = x.f & 0xFFFFFFFFu; + const std::uint64_t u_hi = x.f >> 32u; + const std::uint64_t v_lo = y.f & 0xFFFFFFFFu; + const std::uint64_t v_hi = y.f >> 32u; + + const std::uint64_t p0 = u_lo * v_lo; + const std::uint64_t p1 = u_lo * v_hi; + const std::uint64_t p2 = u_hi * v_lo; + const std::uint64_t p3 = u_hi * v_hi; + + const std::uint64_t p0_hi = p0 >> 32u; + const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu; + const std::uint64_t p1_hi = p1 >> 32u; + const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu; + const std::uint64_t p2_hi = p2 >> 32u; + + std::uint64_t Q = p0_hi + p1_lo + p2_lo; + + // The full product might now be computed as + // + // p_hi = p3 + p2_hi + p1_hi + (Q >> 32) + // p_lo = p0_lo + (Q << 32) + // + // But in this particular case here, the full p_lo is not required. + // Effectively we only need to add the highest bit in p_lo to p_hi (and + // Q_hi + 1 does not overflow). + + Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up + + const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u); + + return {h, x.e + y.e + 64}; + } + + /*! 
+ @brief normalize x such that the significand is >= 2^(q-1) + @pre x.f != 0 + */ + static diyfp normalize(diyfp x) noexcept + { + JSON_ASSERT(x.f != 0); + + while ((x.f >> 63u) == 0) + { + x.f <<= 1u; + x.e--; + } + + return x; + } + + /*! + @brief normalize x such that the result has the exponent E + @pre e >= x.e and the upper e - x.e bits of x.f must be zero. + */ + static diyfp normalize_to(const diyfp& x, const int target_exponent) noexcept + { + const int delta = x.e - target_exponent; + + JSON_ASSERT(delta >= 0); + JSON_ASSERT(((x.f << delta) >> delta) == x.f); + + return {x.f << delta, target_exponent}; + } +}; + +struct boundaries +{ + diyfp w; + diyfp minus; + diyfp plus; +}; + +/*! +Compute the (normalized) diyfp representing the input number 'value' and its +boundaries. + +@pre value must be finite and positive +*/ +template +boundaries compute_boundaries(FloatType value) +{ + JSON_ASSERT(std::isfinite(value)); + JSON_ASSERT(value > 0); + + // Convert the IEEE representation into a diyfp. + // + // If v is denormal: + // value = 0.F * 2^(1 - bias) = ( F) * 2^(1 - bias - (p-1)) + // If v is normalized: + // value = 1.F * 2^(E - bias) = (2^(p-1) + F) * 2^(E - bias - (p-1)) + + static_assert(std::numeric_limits::is_iec559, + "internal error: dtoa_short requires an IEEE-754 floating-point implementation"); + + constexpr int kPrecision = std::numeric_limits::digits; // = p (includes the hidden bit) + constexpr int kBias = std::numeric_limits::max_exponent - 1 + (kPrecision - 1); + constexpr int kMinExp = 1 - kBias; + constexpr std::uint64_t kHiddenBit = std::uint64_t{1} << (kPrecision - 1); // = 2^(p-1) + + using bits_type = typename std::conditional::type; + + const auto bits = static_cast(reinterpret_bits(value)); + const std::uint64_t E = bits >> (kPrecision - 1); + const std::uint64_t F = bits & (kHiddenBit - 1); + + const bool is_denormal = E == 0; + const diyfp v = is_denormal + ? diyfp(F, kMinExp) + : diyfp(F + kHiddenBit, static_cast(E) - kBias); + + // Compute the boundaries m- and m+ of the floating-point value + // v = f * 2^e. + // + // Determine v- and v+, the floating-point predecessor and successor if v, + // respectively. + // + // v- = v - 2^e if f != 2^(p-1) or e == e_min (A) + // = v - 2^(e-1) if f == 2^(p-1) and e > e_min (B) + // + // v+ = v + 2^e + // + // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. All real numbers _strictly_ + // between m- and m+ round to v, regardless of how the input rounding + // algorithm breaks ties. + // + // ---+-------------+-------------+-------------+-------------+--- (A) + // v- m- v m+ v+ + // + // -----------------+------+------+-------------+-------------+--- (B) + // v- m- v m+ v+ + + const bool lower_boundary_is_closer = F == 0 && E > 1; + const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1); + const diyfp m_minus = lower_boundary_is_closer + ? diyfp(4 * v.f - 1, v.e - 2) // (B) + : diyfp(2 * v.f - 1, v.e - 1); // (A) + + // Determine the normalized w+ = m+. + const diyfp w_plus = diyfp::normalize(m_plus); + + // Determine w- = m- such that e_(w-) = e_(w+). 
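+    // (normalize_to is used instead of normalize so that w- ends up with the
+    // same exponent as w+; diyfp::sub and the digit generation below require
+    // all three diyfp's to share one exponent.)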
+ const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e); + + return {diyfp::normalize(v), w_minus, w_plus}; +} + +// Given normalized diyfp w, Grisu needs to find a (normalized) cached +// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies +// within a certain range [alpha, gamma] (Definition 3.2 from [1]) +// +// alpha <= e = e_c + e_w + q <= gamma +// +// or +// +// f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q +// <= f_c * f_w * 2^gamma +// +// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies +// +// 2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma +// +// or +// +// 2^(q - 2 + alpha) <= c * w < 2^(q + gamma) +// +// The choice of (alpha,gamma) determines the size of the table and the form of +// the digit generation procedure. Using (alpha,gamma)=(-60,-32) works out well +// in practice: +// +// The idea is to cut the number c * w = f * 2^e into two parts, which can be +// processed independently: An integral part p1, and a fractional part p2: +// +// f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e +// = (f div 2^-e) + (f mod 2^-e) * 2^e +// = p1 + p2 * 2^e +// +// The conversion of p1 into decimal form requires a series of divisions and +// modulos by (a power of) 10. These operations are faster for 32-bit than for +// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be +// achieved by choosing +// +// -e >= 32 or e <= -32 := gamma +// +// In order to convert the fractional part +// +// p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ... +// +// into decimal form, the fraction is repeatedly multiplied by 10 and the digits +// d[-i] are extracted in order: +// +// (10 * p2) div 2^-e = d[-1] +// (10 * p2) mod 2^-e = d[-2] / 10^1 + ... +// +// The multiplication by 10 must not overflow. It is sufficient to choose +// +// 10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64. +// +// Since p2 = f mod 2^-e < 2^-e, +// +// -e <= 60 or e >= -60 := alpha + +constexpr int kAlpha = -60; +constexpr int kGamma = -32; + +struct cached_power // c = f * 2^e ~= 10^k +{ + std::uint64_t f; + int e; + int k; +}; + +/*! +For a normalized diyfp w = f * 2^e, this function returns a (normalized) cached +power-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c +satisfies (Definition 3.2 from [1]) + + alpha <= e_c + e + q <= gamma. +*/ +inline cached_power get_cached_power_for_binary_exponent(int e) +{ + // Now + // + // alpha <= e_c + e + q <= gamma (1) + // ==> f_c * 2^alpha <= c * 2^e * 2^q + // + // and since the c's are normalized, 2^(q-1) <= f_c, + // + // ==> 2^(q - 1 + alpha) <= c * 2^(e + q) + // ==> 2^(alpha - e - 1) <= c + // + // If c were an exact power of ten, i.e. c = 10^k, one may determine k as + // + // k = ceil( log_10( 2^(alpha - e - 1) ) ) + // = ceil( (alpha - e - 1) * log_10(2) ) + // + // From the paper: + // "In theory the result of the procedure could be wrong since c is rounded, + // and the computation itself is approximated [...]. In practice, however, + // this simple function is sufficient." 
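+    // (Worked example: e = -1137 gives k = ceil(1076 * log_10(2)) = 324 and
+    // e = 960 gives k = ceil(-1021 * log_10(2)) = -307, matching the decimal
+    // exponent range derived below.)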
+ // + // For IEEE double precision floating-point numbers converted into + // normalized diyfp's w = f * 2^e, with q = 64, + // + // e >= -1022 (min IEEE exponent) + // -52 (p - 1) + // -52 (p - 1, possibly normalize denormal IEEE numbers) + // -11 (normalize the diyfp) + // = -1137 + // + // and + // + // e <= +1023 (max IEEE exponent) + // -52 (p - 1) + // -11 (normalize the diyfp) + // = 960 + // + // This binary exponent range [-1137,960] results in a decimal exponent + // range [-307,324]. One does not need to store a cached power for each + // k in this range. For each such k it suffices to find a cached power + // such that the exponent of the product lies in [alpha,gamma]. + // This implies that the difference of the decimal exponents of adjacent + // table entries must be less than or equal to + // + // floor( (gamma - alpha) * log_10(2) ) = 8. + // + // (A smaller distance gamma-alpha would require a larger table.) + + // NB: + // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34. + + constexpr int kCachedPowersMinDecExp = -300; + constexpr int kCachedPowersDecStep = 8; + + static constexpr std::array kCachedPowers = + { + { + { 0xAB70FE17C79AC6CA, -1060, -300 }, + { 0xFF77B1FCBEBCDC4F, -1034, -292 }, + { 0xBE5691EF416BD60C, -1007, -284 }, + { 0x8DD01FAD907FFC3C, -980, -276 }, + { 0xD3515C2831559A83, -954, -268 }, + { 0x9D71AC8FADA6C9B5, -927, -260 }, + { 0xEA9C227723EE8BCB, -901, -252 }, + { 0xAECC49914078536D, -874, -244 }, + { 0x823C12795DB6CE57, -847, -236 }, + { 0xC21094364DFB5637, -821, -228 }, + { 0x9096EA6F3848984F, -794, -220 }, + { 0xD77485CB25823AC7, -768, -212 }, + { 0xA086CFCD97BF97F4, -741, -204 }, + { 0xEF340A98172AACE5, -715, -196 }, + { 0xB23867FB2A35B28E, -688, -188 }, + { 0x84C8D4DFD2C63F3B, -661, -180 }, + { 0xC5DD44271AD3CDBA, -635, -172 }, + { 0x936B9FCEBB25C996, -608, -164 }, + { 0xDBAC6C247D62A584, -582, -156 }, + { 0xA3AB66580D5FDAF6, -555, -148 }, + { 0xF3E2F893DEC3F126, -529, -140 }, + { 0xB5B5ADA8AAFF80B8, -502, -132 }, + { 0x87625F056C7C4A8B, -475, -124 }, + { 0xC9BCFF6034C13053, -449, -116 }, + { 0x964E858C91BA2655, -422, -108 }, + { 0xDFF9772470297EBD, -396, -100 }, + { 0xA6DFBD9FB8E5B88F, -369, -92 }, + { 0xF8A95FCF88747D94, -343, -84 }, + { 0xB94470938FA89BCF, -316, -76 }, + { 0x8A08F0F8BF0F156B, -289, -68 }, + { 0xCDB02555653131B6, -263, -60 }, + { 0x993FE2C6D07B7FAC, -236, -52 }, + { 0xE45C10C42A2B3B06, -210, -44 }, + { 0xAA242499697392D3, -183, -36 }, + { 0xFD87B5F28300CA0E, -157, -28 }, + { 0xBCE5086492111AEB, -130, -20 }, + { 0x8CBCCC096F5088CC, -103, -12 }, + { 0xD1B71758E219652C, -77, -4 }, + { 0x9C40000000000000, -50, 4 }, + { 0xE8D4A51000000000, -24, 12 }, + { 0xAD78EBC5AC620000, 3, 20 }, + { 0x813F3978F8940984, 30, 28 }, + { 0xC097CE7BC90715B3, 56, 36 }, + { 0x8F7E32CE7BEA5C70, 83, 44 }, + { 0xD5D238A4ABE98068, 109, 52 }, + { 0x9F4F2726179A2245, 136, 60 }, + { 0xED63A231D4C4FB27, 162, 68 }, + { 0xB0DE65388CC8ADA8, 189, 76 }, + { 0x83C7088E1AAB65DB, 216, 84 }, + { 0xC45D1DF942711D9A, 242, 92 }, + { 0x924D692CA61BE758, 269, 100 }, + { 0xDA01EE641A708DEA, 295, 108 }, + { 0xA26DA3999AEF774A, 322, 116 }, + { 0xF209787BB47D6B85, 348, 124 }, + { 0xB454E4A179DD1877, 375, 132 }, + { 0x865B86925B9BC5C2, 402, 140 }, + { 0xC83553C5C8965D3D, 428, 148 }, + { 0x952AB45CFA97A0B3, 455, 156 }, + { 0xDE469FBD99A05FE3, 481, 164 }, + { 0xA59BC234DB398C25, 508, 172 }, + { 0xF6C69A72A3989F5C, 534, 180 }, + { 0xB7DCBF5354E9BECE, 561, 188 }, + { 0x88FCF317F22241E2, 588, 196 }, + { 0xCC20CE9BD35C78A5, 614, 204 }, + { 0x98165AF37B2153DF, 641, 
212 }, + { 0xE2A0B5DC971F303A, 667, 220 }, + { 0xA8D9D1535CE3B396, 694, 228 }, + { 0xFB9B7CD9A4A7443C, 720, 236 }, + { 0xBB764C4CA7A44410, 747, 244 }, + { 0x8BAB8EEFB6409C1A, 774, 252 }, + { 0xD01FEF10A657842C, 800, 260 }, + { 0x9B10A4E5E9913129, 827, 268 }, + { 0xE7109BFBA19C0C9D, 853, 276 }, + { 0xAC2820D9623BF429, 880, 284 }, + { 0x80444B5E7AA7CF85, 907, 292 }, + { 0xBF21E44003ACDD2D, 933, 300 }, + { 0x8E679C2F5E44FF8F, 960, 308 }, + { 0xD433179D9C8CB841, 986, 316 }, + { 0x9E19DB92B4E31BA9, 1013, 324 }, + } + }; + + // This computation gives exactly the same results for k as + // k = ceil((kAlpha - e - 1) * 0.30102999566398114) + // for |e| <= 1500, but doesn't require floating-point operations. + // NB: log_10(2) ~= 78913 / 2^18 + JSON_ASSERT(e >= -1500); + JSON_ASSERT(e <= 1500); + const int f = kAlpha - e - 1; + const int k = (f * 78913) / (1 << 18) + static_cast(f > 0); + + const int index = (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / kCachedPowersDecStep; + JSON_ASSERT(index >= 0); + JSON_ASSERT(static_cast(index) < kCachedPowers.size()); + + const cached_power cached = kCachedPowers[static_cast(index)]; + JSON_ASSERT(kAlpha <= cached.e + e + 64); + JSON_ASSERT(kGamma >= cached.e + e + 64); + + return cached; +} + +/*! +For n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k. +For n == 0, returns 1 and sets pow10 := 1. +*/ +inline int find_largest_pow10(const std::uint32_t n, std::uint32_t& pow10) +{ + // LCOV_EXCL_START + if (n >= 1000000000) + { + pow10 = 1000000000; + return 10; + } + // LCOV_EXCL_STOP + if (n >= 100000000) + { + pow10 = 100000000; + return 9; + } + if (n >= 10000000) + { + pow10 = 10000000; + return 8; + } + if (n >= 1000000) + { + pow10 = 1000000; + return 7; + } + if (n >= 100000) + { + pow10 = 100000; + return 6; + } + if (n >= 10000) + { + pow10 = 10000; + return 5; + } + if (n >= 1000) + { + pow10 = 1000; + return 4; + } + if (n >= 100) + { + pow10 = 100; + return 3; + } + if (n >= 10) + { + pow10 = 10; + return 2; + } + + pow10 = 1; + return 1; +} + +inline void grisu2_round(char* buf, int len, std::uint64_t dist, std::uint64_t delta, + std::uint64_t rest, std::uint64_t ten_k) +{ + JSON_ASSERT(len >= 1); + JSON_ASSERT(dist <= delta); + JSON_ASSERT(rest <= delta); + JSON_ASSERT(ten_k > 0); + + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // ten_k + // <------> + // <---- rest ----> + // --------------[------------------+----+--------------]-------------- + // w V + // = buf * 10^k + // + // ten_k represents a unit-in-the-last-place in the decimal representation + // stored in buf. + // Decrement buf by ten_k while this takes buf closer to w. + + // The tests are written in this order to avoid overflow in unsigned + // integer arithmetic. + + while (rest < dist + && delta - rest >= ten_k + && (rest + ten_k < dist || dist - rest > rest + ten_k - dist)) + { + JSON_ASSERT(buf[len - 1] != '0'); + buf[len - 1]--; + rest += ten_k; + } +} + +/*! +Generates V = buffer * 10^decimal_exponent, such that M- <= V <= M+. +M- and M+ must be normalized and share the same exponent -60 <= e <= -32. 
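+(-60 and -32 are kAlpha and kGamma; this range is established beforehand by
+scaling with a cached power of ten, see get_cached_power_for_binary_exponent.)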
+*/ +inline void grisu2_digit_gen(char* buffer, int& length, int& decimal_exponent, + diyfp M_minus, diyfp w, diyfp M_plus) +{ + static_assert(kAlpha >= -60, "internal error"); + static_assert(kGamma <= -32, "internal error"); + + // Generates the digits (and the exponent) of a decimal floating-point + // number V = buffer * 10^decimal_exponent in the range [M-, M+]. The diyfp's + // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= gamma. + // + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // Grisu2 generates the digits of M+ from left to right and stops as soon as + // V is in [M-,M+]. + + JSON_ASSERT(M_plus.e >= kAlpha); + JSON_ASSERT(M_plus.e <= kGamma); + + std::uint64_t delta = diyfp::sub(M_plus, M_minus).f; // (significand of (M+ - M-), implicit exponent is e) + std::uint64_t dist = diyfp::sub(M_plus, w ).f; // (significand of (M+ - w ), implicit exponent is e) + + // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0): + // + // M+ = f * 2^e + // = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e + // = ((p1 ) * 2^-e + (p2 )) * 2^e + // = p1 + p2 * 2^e + + const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e); + + auto p1 = static_cast(M_plus.f >> -one.e); // p1 = f div 2^-e (Since -e >= 32, p1 fits into a 32-bit int.) + std::uint64_t p2 = M_plus.f & (one.f - 1); // p2 = f mod 2^-e + + // 1) + // + // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0] + + JSON_ASSERT(p1 > 0); + + std::uint32_t pow10{}; + const int k = find_largest_pow10(p1, pow10); + + // 10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1) + // + // p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1)) + // = (d[k-1] ) * 10^(k-1) + (p1 mod 10^(k-1)) + // + // M+ = p1 + p2 * 2^e + // = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1)) + p2 * 2^e + // = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e + // = d[k-1] * 10^(k-1) + ( rest) * 2^e + // + // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0) + // + // p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0] + // + // but stop as soon as + // + // rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e + + int n = k; + while (n > 0) + { + // Invariants: + // M+ = buffer * 10^n + (p1 + p2 * 2^e) (buffer = 0 for n = k) + // pow10 = 10^(n-1) <= p1 < 10^n + // + const std::uint32_t d = p1 / pow10; // d = p1 div 10^(n-1) + const std::uint32_t r = p1 % pow10; // r = p1 mod 10^(n-1) + // + // M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e + // = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e) + // + JSON_ASSERT(d <= 9); + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(n-1) + (r + p2 * 2^e) + // + p1 = r; + n--; + // + // M+ = buffer * 10^n + (p1 + p2 * 2^e) + // pow10 = 10^n + // + + // Now check if enough digits have been generated. + // Compute + // + // p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e + // + // Note: + // Since rest and delta share the same exponent e, it suffices to + // compare the significands. + const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2; + if (rest <= delta) + { + // V = buffer * 10^n, with M- <= V <= M+. + + decimal_exponent += n; + + // We may now just stop. But instead look if the buffer could be + // decremented to bring V closer to w. + // + // pow10 = 10^n is now 1 ulp in the decimal representation V. + // The rounding procedure works with diyfp's with an implicit + // exponent of e. 
+ // + // 10^n = (10^n * 2^-e) * 2^e = ulp * 2^e + // + const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e; + grisu2_round(buffer, length, dist, delta, rest, ten_n); + + return; + } + + pow10 /= 10; + // + // pow10 = 10^(n-1) <= p1 < 10^n + // Invariants restored. + } + + // 2) + // + // The digits of the integral part have been generated: + // + // M+ = d[k-1]...d[1]d[0] + p2 * 2^e + // = buffer + p2 * 2^e + // + // Now generate the digits of the fractional part p2 * 2^e. + // + // Note: + // No decimal point is generated: the exponent is adjusted instead. + // + // p2 actually represents the fraction + // + // p2 * 2^e + // = p2 / 2^-e + // = d[-1] / 10^1 + d[-2] / 10^2 + ... + // + // Now generate the digits d[-m] of p1 from left to right (m = 1,2,...) + // + // p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m + // + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...) + // + // using + // + // 10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e) + // = ( d) * 2^-e + ( r) + // + // or + // 10^m * p2 * 2^e = d + r * 2^e + // + // i.e. + // + // M+ = buffer + p2 * 2^e + // = buffer + 10^-m * (d + r * 2^e) + // = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e + // + // and stop as soon as 10^-m * r * 2^e <= delta * 2^e + + JSON_ASSERT(p2 > delta); + + int m = 0; + for (;;) + { + // Invariant: + // M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + ...) * 2^e + // = buffer * 10^-m + 10^-m * (p2 ) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * (10 * p2) ) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + (10*p2 mod 2^-e)) * 2^e + // + JSON_ASSERT(p2 <= (std::numeric_limits::max)() / 10); + p2 *= 10; + const std::uint64_t d = p2 >> -one.e; // d = (10 * p2) div 2^-e + const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e + // + // M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e)) + // = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + JSON_ASSERT(d <= 9); + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + p2 = r; + m++; + // + // M+ = buffer * 10^-m + 10^-m * p2 * 2^e + // Invariant restored. + + // Check if enough digits have been generated. + // + // 10^-m * p2 * 2^e <= delta * 2^e + // p2 * 2^e <= 10^m * delta * 2^e + // p2 <= 10^m * delta + delta *= 10; + dist *= 10; + if (p2 <= delta) + { + break; + } + } + + // V = buffer * 10^-m, with M- <= V <= M+. + + decimal_exponent -= m; + + // 1 ulp in the decimal representation is now 10^-m. + // Since delta and dist are now scaled by 10^m, we need to do the + // same with ulp in order to keep the units in sync. + // + // 10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e + // + const std::uint64_t ten_m = one.f; + grisu2_round(buffer, length, dist, delta, p2, ten_m); + + // By construction this algorithm generates the shortest possible decimal + // number (Loitsch, Theorem 6.2) which rounds back to w. + // For an input number of precision p, at least + // + // N = 1 + ceil(p * log_10(2)) + // + // decimal digits are sufficient to identify all binary floating-point + // numbers (Matula, "In-and-Out conversions"). + // This implies that the algorithm does not produce more than N decimal + // digits. + // + // N = 17 for p = 53 (IEEE double precision) + // N = 9 for p = 24 (IEEE single precision) +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. 
>= max_digits10. +*/ +JSON_HEDLEY_NON_NULL(1) +inline void grisu2(char* buf, int& len, int& decimal_exponent, + diyfp m_minus, diyfp v, diyfp m_plus) +{ + JSON_ASSERT(m_plus.e == m_minus.e); + JSON_ASSERT(m_plus.e == v.e); + + // --------(-----------------------+-----------------------)-------- (A) + // m- v m+ + // + // --------------------(-----------+-----------------------)-------- (B) + // m- v m+ + // + // First scale v (and m- and m+) such that the exponent is in the range + // [alpha, gamma]. + + const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e); + + const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k + + // The exponent of the products is = v.e + c_minus_k.e + q and is in the range [alpha,gamma] + const diyfp w = diyfp::mul(v, c_minus_k); + const diyfp w_minus = diyfp::mul(m_minus, c_minus_k); + const diyfp w_plus = diyfp::mul(m_plus, c_minus_k); + + // ----(---+---)---------------(---+---)---------------(---+---)---- + // w- w w+ + // = c*m- = c*v = c*m+ + // + // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and + // w+ are now off by a small amount. + // In fact: + // + // w - v * 10^k < 1 ulp + // + // To account for this inaccuracy, add resp. subtract 1 ulp. + // + // --------+---[---------------(---+---)---------------]---+-------- + // w- M- w M+ w+ + // + // Now any number in [M-, M+] (bounds included) will round to w when input, + // regardless of how the input rounding algorithm breaks ties. + // + // And digit_gen generates the shortest possible such number in [M-, M+]. + // Note that this does not mean that Grisu2 always generates the shortest + // possible number in the interval (m-, m+). + const diyfp M_minus(w_minus.f + 1, w_minus.e); + const diyfp M_plus (w_plus.f - 1, w_plus.e ); + + decimal_exponent = -cached.k; // = -(-k) = k + + grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus); +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. >= max_digits10. +*/ +template +JSON_HEDLEY_NON_NULL(1) +void grisu2(char* buf, int& len, int& decimal_exponent, FloatType value) +{ + static_assert(diyfp::kPrecision >= std::numeric_limits::digits + 3, + "internal error: not enough precision"); + + JSON_ASSERT(std::isfinite(value)); + JSON_ASSERT(value > 0); + + // If the neighbors (and boundaries) of 'value' are always computed for double-precision + // numbers, all float's can be recovered using strtod (and strtof). However, the resulting + // decimal representations are not exactly "short". + // + // The documentation for 'std::to_chars' (https://en.cppreference.com/w/cpp/utility/to_chars) + // says "value is converted to a string as if by std::sprintf in the default ("C") locale" + // and since sprintf promotes float's to double's, I think this is exactly what 'std::to_chars' + // does. + // On the other hand, the documentation for 'std::to_chars' requires that "parsing the + // representation using the corresponding std::from_chars function recovers value exactly". That + // indicates that single precision floating-point numbers should be recovered using + // 'std::strtof'. + // + // NB: If the neighbors are computed for single-precision numbers, there is a single float + // (7.0385307e-26f) which can't be recovered using strtod. The resulting double precision + // value is off by 1 ulp. 
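+    //
+    // Illustrative example: for the float 0.3f, single-precision boundaries let
+    // Grisu2 emit the short string "0.3" (recoverable via strtof), whereas
+    // double-precision boundaries would emit the longer "0.30000001192092896".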
+#if 0 + const boundaries w = compute_boundaries(static_cast(value)); +#else + const boundaries w = compute_boundaries(value); +#endif + + grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus); +} + +/*! +@brief appends a decimal representation of e to buf +@return a pointer to the element following the exponent. +@pre -1000 < e < 1000 +*/ +JSON_HEDLEY_NON_NULL(1) +JSON_HEDLEY_RETURNS_NON_NULL +inline char* append_exponent(char* buf, int e) +{ + JSON_ASSERT(e > -1000); + JSON_ASSERT(e < 1000); + + if (e < 0) + { + e = -e; + *buf++ = '-'; + } + else + { + *buf++ = '+'; + } + + auto k = static_cast(e); + if (k < 10) + { + // Always print at least two digits in the exponent. + // This is for compatibility with printf("%g"). + *buf++ = '0'; + *buf++ = static_cast('0' + k); + } + else if (k < 100) + { + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } + else + { + *buf++ = static_cast('0' + k / 100); + k %= 100; + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } + + return buf; +} + +/*! +@brief prettify v = buf * 10^decimal_exponent + +If v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point +notation. Otherwise it will be printed in exponential notation. + +@pre min_exp < 0 +@pre max_exp > 0 +*/ +JSON_HEDLEY_NON_NULL(1) +JSON_HEDLEY_RETURNS_NON_NULL +inline char* format_buffer(char* buf, int len, int decimal_exponent, + int min_exp, int max_exp) +{ + JSON_ASSERT(min_exp < 0); + JSON_ASSERT(max_exp > 0); + + const int k = len; + const int n = len + decimal_exponent; + + // v = buf * 10^(n-k) + // k is the length of the buffer (number of decimal digits) + // n is the position of the decimal point relative to the start of the buffer. + + if (k <= n && n <= max_exp) + { + // digits[000] + // len <= max_exp + 2 + + std::memset(buf + k, '0', static_cast(n) - static_cast(k)); + // Make it look like a floating-point number (#362, #378) + buf[n + 0] = '.'; + buf[n + 1] = '0'; + return buf + (static_cast(n) + 2); + } + + if (0 < n && n <= max_exp) + { + // dig.its + // len <= max_digits10 + 1 + + JSON_ASSERT(k > n); + + std::memmove(buf + (static_cast(n) + 1), buf + n, static_cast(k) - static_cast(n)); + buf[n] = '.'; + return buf + (static_cast(k) + 1U); + } + + if (min_exp < n && n <= 0) + { + // 0.[000]digits + // len <= 2 + (-min_exp - 1) + max_digits10 + + std::memmove(buf + (2 + static_cast(-n)), buf, static_cast(k)); + buf[0] = '0'; + buf[1] = '.'; + std::memset(buf + 2, '0', static_cast(-n)); + return buf + (2U + static_cast(-n) + static_cast(k)); + } + + if (k == 1) + { + // dE+123 + // len <= 1 + 5 + + buf += 1; + } + else + { + // d.igitsE+123 + // len <= max_digits10 + 1 + 5 + + std::memmove(buf + 2, buf + 1, static_cast(k) - 1); + buf[1] = '.'; + buf += 1 + static_cast(k); + } + + *buf++ = 'e'; + return append_exponent(buf, n - 1); +} + +} // namespace dtoa_impl + +/*! +@brief generates a decimal representation of the floating-point number value in [first, last). + +The format of the resulting decimal representation is similar to printf's %g +format. Returns an iterator pointing past-the-end of the decimal representation. + +@note The input number must be finite, i.e. NaN's and Inf's are not supported. +@note The buffer must be large enough. +@note The result is NOT null-terminated. 
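+
+@note A minimal usage sketch (the buffer size 64 is an arbitrary illustrative
+choice that satisfies the size preconditions asserted inside the function):
+@code {.cpp}
+char buf[64];
+char* end = nlohmann::detail::to_chars(buf, buf + sizeof(buf), 3.14);
+// the textual representation occupies [buf, end) and is not null-terminated
+@endcode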
+*/ +template +JSON_HEDLEY_NON_NULL(1, 2) +JSON_HEDLEY_RETURNS_NON_NULL +char* to_chars(char* first, const char* last, FloatType value) +{ + static_cast(last); // maybe unused - fix warning + JSON_ASSERT(std::isfinite(value)); + + // Use signbit(value) instead of (value < 0) since signbit works for -0. + if (std::signbit(value)) + { + value = -value; + *first++ = '-'; + } + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + if (value == 0) // +-0 + { + *first++ = '0'; + // Make it look like a floating-point number (#362, #378) + *first++ = '.'; + *first++ = '0'; + return first; + } +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + + JSON_ASSERT(last - first >= std::numeric_limits::max_digits10); + + // Compute v = buffer * 10^decimal_exponent. + // The decimal digits are stored in the buffer, which needs to be interpreted + // as an unsigned decimal integer. + // len is the length of the buffer, i.e. the number of decimal digits. + int len = 0; + int decimal_exponent = 0; + dtoa_impl::grisu2(first, len, decimal_exponent, value); + + JSON_ASSERT(len <= std::numeric_limits::max_digits10); + + // Format the buffer like printf("%.*g", prec, value) + constexpr int kMinExp = -4; + // Use digits10 here to increase compatibility with version 2. + constexpr int kMaxExp = std::numeric_limits::digits10; + + JSON_ASSERT(last - first >= kMaxExp + 2); + JSON_ASSERT(last - first >= 2 + (-kMinExp - 1) + std::numeric_limits::max_digits10); + JSON_ASSERT(last - first >= std::numeric_limits::max_digits10 + 6); + + return dtoa_impl::format_buffer(first, len, decimal_exponent, kMinExp, kMaxExp); +} + +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/////////////////// +// serialization // +/////////////////// + +/// how to treat decoding errors +enum class error_handler_t +{ + strict, ///< throw a type_error exception in case of invalid UTF-8 + replace, ///< replace invalid UTF-8 sequences with U+FFFD + ignore ///< ignore invalid UTF-8 sequences +}; + +template +class serializer +{ + using string_t = typename BasicJsonType::string_t; + using number_float_t = typename BasicJsonType::number_float_t; + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using binary_char_t = typename BasicJsonType::binary_t::value_type; + static constexpr std::uint8_t UTF8_ACCEPT = 0; + static constexpr std::uint8_t UTF8_REJECT = 1; + + public: + /*! + @param[in] s output stream to serialize to + @param[in] ichar indentation character to use + @param[in] error_handler_ how to react on decoding errors + */ + serializer(output_adapter_t s, const char ichar, + error_handler_t error_handler_ = error_handler_t::strict) + : o(std::move(s)) + , loc(std::localeconv()) + , thousands_sep(loc->thousands_sep == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->thousands_sep))) + , decimal_point(loc->decimal_point == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->decimal_point))) + , indent_char(ichar) + , indent_string(512, indent_char) + , error_handler(error_handler_) + {} + + // delete because of pointer members + serializer(const serializer&) = delete; + serializer& operator=(const serializer&) = delete; + serializer(serializer&&) = delete; + serializer& operator=(serializer&&) = delete; + ~serializer() = default; + + /*! 
+ @brief internal implementation of the serialization function + + This function is called by the public member function dump and organizes + the serialization internally. The indentation level is propagated as + additional parameter. In case of arrays and objects, the function is + called recursively. + + - strings and object keys are escaped using `escape_string()` + - integer numbers are converted implicitly via `operator<<` + - floating-point numbers are converted to a string using `"%g"` format + - binary values are serialized as objects containing the subtype and the + byte array + + @param[in] val value to serialize + @param[in] pretty_print whether the output shall be pretty-printed + @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters + in the output are escaped with `\uXXXX` sequences, and the result consists + of ASCII characters only. + @param[in] indent_step the indent level + @param[in] current_indent the current indent level (only used internally) + */ + void dump(const BasicJsonType& val, + const bool pretty_print, + const bool ensure_ascii, + const unsigned int indent_step, + const unsigned int current_indent = 0) + { + switch (val.m_type) + { + case value_t::object: + { + if (val.m_value.object->empty()) + { + o->write_characters("{}", 2); + return; + } + + if (pretty_print) + { + o->write_characters("{\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + auto i = val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + JSON_ASSERT(i != val.m_value.object->cend()); + JSON_ASSERT(std::next(i) == val.m_value.object->cend()); + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character('}'); + } + else + { + o->write_character('{'); + + // first n-1 elements + auto i = val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + JSON_ASSERT(i != val.m_value.object->cend()); + JSON_ASSERT(std::next(i) == val.m_value.object->cend()); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + + o->write_character('}'); + } + + return; + } + + case value_t::array: + { + if (val.m_value.array->empty()) + { + o->write_characters("[]", 2); + return; + } + + if (pretty_print) + { + o->write_characters("[\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if 
(JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + dump(*i, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + JSON_ASSERT(!val.m_value.array->empty()); + o->write_characters(indent_string.c_str(), new_indent); + dump(val.m_value.array->back(), true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character(']'); + } + else + { + o->write_character('['); + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + dump(*i, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + JSON_ASSERT(!val.m_value.array->empty()); + dump(val.m_value.array->back(), false, ensure_ascii, indent_step, current_indent); + + o->write_character(']'); + } + + return; + } + + case value_t::string: + { + o->write_character('\"'); + dump_escaped(*val.m_value.string, ensure_ascii); + o->write_character('\"'); + return; + } + + case value_t::binary: + { + if (pretty_print) + { + o->write_characters("{\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + o->write_characters(indent_string.c_str(), new_indent); + + o->write_characters("\"bytes\": [", 10); + + if (!val.m_value.binary->empty()) + { + for (auto i = val.m_value.binary->cbegin(); + i != val.m_value.binary->cend() - 1; ++i) + { + dump_integer(*i); + o->write_characters(", ", 2); + } + dump_integer(val.m_value.binary->back()); + } + + o->write_characters("],\n", 3); + o->write_characters(indent_string.c_str(), new_indent); + + o->write_characters("\"subtype\": ", 11); + if (val.m_value.binary->has_subtype()) + { + dump_integer(val.m_value.binary->subtype()); + } + else + { + o->write_characters("null", 4); + } + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character('}'); + } + else + { + o->write_characters("{\"bytes\":[", 10); + + if (!val.m_value.binary->empty()) + { + for (auto i = val.m_value.binary->cbegin(); + i != val.m_value.binary->cend() - 1; ++i) + { + dump_integer(*i); + o->write_character(','); + } + dump_integer(val.m_value.binary->back()); + } + + o->write_characters("],\"subtype\":", 12); + if (val.m_value.binary->has_subtype()) + { + dump_integer(val.m_value.binary->subtype()); + o->write_character('}'); + } + else + { + o->write_characters("null}", 5); + } + } + return; + } + + case value_t::boolean: + { + if (val.m_value.boolean) + { + o->write_characters("true", 4); + } + else + { + o->write_characters("false", 5); + } + return; + } + + case value_t::number_integer: + { + dump_integer(val.m_value.number_integer); + return; + } + + case value_t::number_unsigned: + { + dump_integer(val.m_value.number_unsigned); + return; + } + + case value_t::number_float: + { + dump_float(val.m_value.number_float); + return; + } + + case value_t::discarded: + { + o->write_characters("", 11); + return; + } + + case value_t::null: + { + o->write_characters("null", 4); + return; + } + + default: // LCOV_EXCL_LINE + 
JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + } + + JSON_PRIVATE_UNLESS_TESTED: + /*! + @brief dump escaped string + + Escape a string by replacing certain special characters by a sequence of an + escape character (backslash) and another character and other control + characters by a sequence of "\u" followed by a four-digit hex + representation. The escaped string is written to output stream @a o. + + @param[in] s the string to escape + @param[in] ensure_ascii whether to escape non-ASCII characters with + \uXXXX sequences + + @complexity Linear in the length of string @a s. + */ + void dump_escaped(const string_t& s, const bool ensure_ascii) + { + std::uint32_t codepoint{}; + std::uint8_t state = UTF8_ACCEPT; + std::size_t bytes = 0; // number of bytes written to string_buffer + + // number of bytes written at the point of the last valid byte + std::size_t bytes_after_last_accept = 0; + std::size_t undumped_chars = 0; + + for (std::size_t i = 0; i < s.size(); ++i) + { + const auto byte = static_cast(s[i]); + + switch (decode(state, codepoint, byte)) + { + case UTF8_ACCEPT: // decode found a new code point + { + switch (codepoint) + { + case 0x08: // backspace + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'b'; + break; + } + + case 0x09: // horizontal tab + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 't'; + break; + } + + case 0x0A: // newline + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'n'; + break; + } + + case 0x0C: // formfeed + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'f'; + break; + } + + case 0x0D: // carriage return + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'r'; + break; + } + + case 0x22: // quotation mark + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = '\"'; + break; + } + + case 0x5C: // reverse solidus + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = '\\'; + break; + } + + default: + { + // escape control characters (0x00..0x1F) or, if + // ensure_ascii parameter is used, non-ASCII characters + if ((codepoint <= 0x1F) || (ensure_ascii && (codepoint >= 0x7F))) + { + if (codepoint <= 0xFFFF) + { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + (std::snprintf)(string_buffer.data() + bytes, 7, "\\u%04x", + static_cast(codepoint)); + bytes += 6; + } + else + { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + (std::snprintf)(string_buffer.data() + bytes, 13, "\\u%04x\\u%04x", + static_cast(0xD7C0u + (codepoint >> 10u)), + static_cast(0xDC00u + (codepoint & 0x3FFu))); + bytes += 12; + } + } + else + { + // copy byte to buffer (all previous bytes + // been copied have in default case above) + string_buffer[bytes++] = s[i]; + } + break; + } + } + + // write buffer and reset index; there must be 13 bytes + // left, as this is the maximal number of bytes to be + // written ("\uxxxx\uxxxx\0") for one code point + if (string_buffer.size() - bytes < 13) + { + o->write_characters(string_buffer.data(), bytes); + bytes = 0; + } + + // remember the byte position of this accept + bytes_after_last_accept = bytes; + undumped_chars = 0; + break; + } + + case UTF8_REJECT: // decode found invalid UTF-8 byte + { + switch (error_handler) + { + case error_handler_t::strict: + { + std::stringstream ss; + ss << std::uppercase << std::setfill('0') << std::setw(2) << std::hex << (byte | 0); + JSON_THROW(type_error::create(316, "invalid UTF-8 byte at index " + std::to_string(i) + ": 
0x" + ss.str(), BasicJsonType())); + } + + case error_handler_t::ignore: + case error_handler_t::replace: + { + // in case we saw this character the first time, we + // would like to read it again, because the byte + // may be OK for itself, but just not OK for the + // previous sequence + if (undumped_chars > 0) + { + --i; + } + + // reset length buffer to the last accepted index; + // thus removing/ignoring the invalid characters + bytes = bytes_after_last_accept; + + if (error_handler == error_handler_t::replace) + { + // add a replacement character + if (ensure_ascii) + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'u'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'd'; + } + else + { + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xEF'); + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBF'); + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBD'); + } + + // write buffer and reset index; there must be 13 bytes + // left, as this is the maximal number of bytes to be + // written ("\uxxxx\uxxxx\0") for one code point + if (string_buffer.size() - bytes < 13) + { + o->write_characters(string_buffer.data(), bytes); + bytes = 0; + } + + bytes_after_last_accept = bytes; + } + + undumped_chars = 0; + + // continue processing the string + state = UTF8_ACCEPT; + break; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + break; + } + + default: // decode found yet incomplete multi-byte code point + { + if (!ensure_ascii) + { + // code point will not be escaped - copy byte to buffer + string_buffer[bytes++] = s[i]; + } + ++undumped_chars; + break; + } + } + } + + // we finished processing the string + if (JSON_HEDLEY_LIKELY(state == UTF8_ACCEPT)) + { + // write buffer + if (bytes > 0) + { + o->write_characters(string_buffer.data(), bytes); + } + } + else + { + // we finish reading, but do not accept: string was incomplete + switch (error_handler) + { + case error_handler_t::strict: + { + std::stringstream ss; + ss << std::uppercase << std::setfill('0') << std::setw(2) << std::hex << (static_cast(s.back()) | 0); + JSON_THROW(type_error::create(316, "incomplete UTF-8 string; last byte: 0x" + ss.str(), BasicJsonType())); + } + + case error_handler_t::ignore: + { + // write all accepted bytes + o->write_characters(string_buffer.data(), bytes_after_last_accept); + break; + } + + case error_handler_t::replace: + { + // write all accepted bytes + o->write_characters(string_buffer.data(), bytes_after_last_accept); + // add a replacement character + if (ensure_ascii) + { + o->write_characters("\\ufffd", 6); + } + else + { + o->write_characters("\xEF\xBF\xBD", 3); + } + break; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + } + } + + private: + /*! + @brief count digits + + Count the number of decimal (base 10) digits for an input unsigned integer. + + @param[in] x unsigned integer number to count its digits + @return number of decimal digits + */ + inline unsigned int count_digits(number_unsigned_t x) noexcept + { + unsigned int n_digits = 1; + for (;;) + { + if (x < 10) + { + return n_digits; + } + if (x < 100) + { + return n_digits + 1; + } + if (x < 1000) + { + return n_digits + 2; + } + if (x < 10000) + { + return n_digits + 3; + } + x = x / 10000u; + n_digits += 4; + } + } + + /*! 
+ @brief dump an integer + + Dump a given integer to output stream @a o. Works internally with + @a number_buffer. + + @param[in] x integer number (signed or unsigned) to dump + @tparam NumberType either @a number_integer_t or @a number_unsigned_t + */ + template < typename NumberType, detail::enable_if_t < + std::is_integral::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, + int > = 0 > + void dump_integer(NumberType x) + { + static constexpr std::array, 100> digits_to_99 + { + { + {{'0', '0'}}, {{'0', '1'}}, {{'0', '2'}}, {{'0', '3'}}, {{'0', '4'}}, {{'0', '5'}}, {{'0', '6'}}, {{'0', '7'}}, {{'0', '8'}}, {{'0', '9'}}, + {{'1', '0'}}, {{'1', '1'}}, {{'1', '2'}}, {{'1', '3'}}, {{'1', '4'}}, {{'1', '5'}}, {{'1', '6'}}, {{'1', '7'}}, {{'1', '8'}}, {{'1', '9'}}, + {{'2', '0'}}, {{'2', '1'}}, {{'2', '2'}}, {{'2', '3'}}, {{'2', '4'}}, {{'2', '5'}}, {{'2', '6'}}, {{'2', '7'}}, {{'2', '8'}}, {{'2', '9'}}, + {{'3', '0'}}, {{'3', '1'}}, {{'3', '2'}}, {{'3', '3'}}, {{'3', '4'}}, {{'3', '5'}}, {{'3', '6'}}, {{'3', '7'}}, {{'3', '8'}}, {{'3', '9'}}, + {{'4', '0'}}, {{'4', '1'}}, {{'4', '2'}}, {{'4', '3'}}, {{'4', '4'}}, {{'4', '5'}}, {{'4', '6'}}, {{'4', '7'}}, {{'4', '8'}}, {{'4', '9'}}, + {{'5', '0'}}, {{'5', '1'}}, {{'5', '2'}}, {{'5', '3'}}, {{'5', '4'}}, {{'5', '5'}}, {{'5', '6'}}, {{'5', '7'}}, {{'5', '8'}}, {{'5', '9'}}, + {{'6', '0'}}, {{'6', '1'}}, {{'6', '2'}}, {{'6', '3'}}, {{'6', '4'}}, {{'6', '5'}}, {{'6', '6'}}, {{'6', '7'}}, {{'6', '8'}}, {{'6', '9'}}, + {{'7', '0'}}, {{'7', '1'}}, {{'7', '2'}}, {{'7', '3'}}, {{'7', '4'}}, {{'7', '5'}}, {{'7', '6'}}, {{'7', '7'}}, {{'7', '8'}}, {{'7', '9'}}, + {{'8', '0'}}, {{'8', '1'}}, {{'8', '2'}}, {{'8', '3'}}, {{'8', '4'}}, {{'8', '5'}}, {{'8', '6'}}, {{'8', '7'}}, {{'8', '8'}}, {{'8', '9'}}, + {{'9', '0'}}, {{'9', '1'}}, {{'9', '2'}}, {{'9', '3'}}, {{'9', '4'}}, {{'9', '5'}}, {{'9', '6'}}, {{'9', '7'}}, {{'9', '8'}}, {{'9', '9'}}, + } + }; + + // special case for "0" + if (x == 0) + { + o->write_character('0'); + return; + } + + // use a pointer to fill the buffer + auto buffer_ptr = number_buffer.begin(); // NOLINT(llvm-qualified-auto,readability-qualified-auto,cppcoreguidelines-pro-type-vararg,hicpp-vararg) + + const bool is_negative = std::is_signed::value && !(x >= 0); // see issue #755 + number_unsigned_t abs_value; + + unsigned int n_chars{}; + + if (is_negative) + { + *buffer_ptr = '-'; + abs_value = remove_sign(static_cast(x)); + + // account one more byte for the minus sign + n_chars = 1 + count_digits(abs_value); + } + else + { + abs_value = static_cast(x); + n_chars = count_digits(abs_value); + } + + // spare 1 byte for '\0' + JSON_ASSERT(n_chars < number_buffer.size() - 1); + + // jump to the end to generate the string from backward + // so we later avoid reversing the result + buffer_ptr += n_chars; + + // Fast int2ascii implementation inspired by "Fastware" talk by Andrei Alexandrescu + // See: https://www.youtube.com/watch?v=o4-CwDo2zpg + while (abs_value >= 100) + { + const auto digits_index = static_cast((abs_value % 100)); + abs_value /= 100; + *(--buffer_ptr) = digits_to_99[digits_index][1]; + *(--buffer_ptr) = digits_to_99[digits_index][0]; + } + + if (abs_value >= 10) + { + const auto digits_index = static_cast(abs_value); + *(--buffer_ptr) = digits_to_99[digits_index][1]; + *(--buffer_ptr) = digits_to_99[digits_index][0]; + } + else + { + *(--buffer_ptr) = static_cast('0' + abs_value); + } + + o->write_characters(number_buffer.data(), n_chars); + } + + /*! 
+ @brief dump a floating-point number + + Dump a given floating-point number to output stream @a o. Works internally + with @a number_buffer. + + @param[in] x floating-point number to dump + */ + void dump_float(number_float_t x) + { + // NaN / inf + if (!std::isfinite(x)) + { + o->write_characters("null", 4); + return; + } + + // If number_float_t is an IEEE-754 single or double precision number, + // use the Grisu2 algorithm to produce short numbers which are + // guaranteed to round-trip, using strtof and strtod, resp. + // + // NB: The test below works if == . + static constexpr bool is_ieee_single_or_double + = (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 24 && std::numeric_limits::max_exponent == 128) || + (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 53 && std::numeric_limits::max_exponent == 1024); + + dump_float(x, std::integral_constant()); + } + + void dump_float(number_float_t x, std::true_type /*is_ieee_single_or_double*/) + { + auto* begin = number_buffer.data(); + auto* end = ::nlohmann::detail::to_chars(begin, begin + number_buffer.size(), x); + + o->write_characters(begin, static_cast(end - begin)); + } + + void dump_float(number_float_t x, std::false_type /*is_ieee_single_or_double*/) + { + // get number of digits for a float -> text -> float round-trip + static constexpr auto d = std::numeric_limits::max_digits10; + + // the actual conversion + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg) + std::ptrdiff_t len = (std::snprintf)(number_buffer.data(), number_buffer.size(), "%.*g", d, x); + + // negative value indicates an error + JSON_ASSERT(len > 0); + // check if buffer was large enough + JSON_ASSERT(static_cast(len) < number_buffer.size()); + + // erase thousands separator + if (thousands_sep != '\0') + { + // NOLINTNEXTLINE(readability-qualified-auto,llvm-qualified-auto): std::remove returns an iterator, see https://github.com/nlohmann/json/issues/3081 + const auto end = std::remove(number_buffer.begin(), number_buffer.begin() + len, thousands_sep); + std::fill(end, number_buffer.end(), '\0'); + JSON_ASSERT((end - number_buffer.begin()) <= len); + len = (end - number_buffer.begin()); + } + + // convert decimal point to '.' + if (decimal_point != '\0' && decimal_point != '.') + { + // NOLINTNEXTLINE(readability-qualified-auto,llvm-qualified-auto): std::find returns an iterator, see https://github.com/nlohmann/json/issues/3081 + const auto dec_pos = std::find(number_buffer.begin(), number_buffer.end(), decimal_point); + if (dec_pos != number_buffer.end()) + { + *dec_pos = '.'; + } + } + + o->write_characters(number_buffer.data(), static_cast(len)); + + // determine if need to append ".0" + const bool value_is_int_like = + std::none_of(number_buffer.begin(), number_buffer.begin() + len + 1, + [](char c) + { + return c == '.' || c == 'e'; + }); + + if (value_is_int_like) + { + o->write_characters(".0", 2); + } + } + + /*! + @brief check whether a string is UTF-8 encoded + + The function checks each byte of a string whether it is UTF-8 encoded. The + result of the check is stored in the @a state parameter. The function must + be called initially with state 0 (accept). State 1 means the string must + be rejected, because the current byte is not allowed. If the string is + completely processed, but the state is non-zero, the string ended + prematurely; that is, the last byte indicated more bytes should have + followed. 
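+
+    A typical driving loop (illustrative sketch; `s` stands for any UTF-8
+    encoded std::string-like input):
+    @code {.cpp}
+    std::uint8_t state = UTF8_ACCEPT;
+    std::uint32_t codepoint = 0;
+    for (const char c : s)
+    {
+        if (decode(state, codepoint, static_cast<std::uint8_t>(c)) == UTF8_REJECT)
+        {
+            break; // invalid byte for the current UTF-8 sequence
+        }
+    }
+    // state != UTF8_ACCEPT after the loop means the string ended prematurely
+    @endcode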
+ + @param[in,out] state the state of the decoding + @param[in,out] codep codepoint (valid only if resulting state is UTF8_ACCEPT) + @param[in] byte next byte to decode + @return new state + + @note The function has been edited: a std::array is used. + + @copyright Copyright (c) 2008-2009 Bjoern Hoehrmann + @sa http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + */ + static std::uint8_t decode(std::uint8_t& state, std::uint32_t& codep, const std::uint8_t byte) noexcept + { + static const std::array utf8d = + { + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9F + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // A0..BF + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C0..DF + 0xA, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // E0..EF + 0xB, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // F0..FF + 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 + 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 + 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 + 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 // s7..s8 + } + }; + + JSON_ASSERT(byte < utf8d.size()); + const std::uint8_t type = utf8d[byte]; + + codep = (state != UTF8_ACCEPT) + ? (byte & 0x3fu) | (codep << 6u) + : (0xFFu >> type) & (byte); + + std::size_t index = 256u + static_cast(state) * 16u + static_cast(type); + JSON_ASSERT(index < 400); + state = utf8d[index]; + return state; + } + + /* + * Overload to make the compiler happy while it is instantiating + * dump_integer for number_unsigned_t. + * Must never be called. + */ + number_unsigned_t remove_sign(number_unsigned_t x) + { + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + return x; // LCOV_EXCL_LINE + } + + /* + * Helper function for dump_integer + * + * This function takes a negative signed integer and returns its absolute + * value as unsigned integer. The plus/minus shuffling is necessary as we can + * not directly remove the sign of an arbitrary signed integer as the + * absolute values of INT_MIN and INT_MAX are usually not the same. See + * #1708 for details. 
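+     * For example, with a 64-bit number_integer_t and x == INT64_MIN, the
+     * intermediate -(x + 1) equals INT64_MAX and is still representable;
+     * adding 1 after the cast to number_unsigned_t yields the magnitude 2^63.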
+ */ + inline number_unsigned_t remove_sign(number_integer_t x) noexcept + { + JSON_ASSERT(x < 0 && x < (std::numeric_limits::max)()); // NOLINT(misc-redundant-expression) + return static_cast(-(x + 1)) + 1; + } + + private: + /// the output of the serializer + output_adapter_t o = nullptr; + + /// a (hopefully) large enough character buffer + std::array number_buffer{{}}; + + /// the locale + const std::lconv* loc = nullptr; + /// the locale's thousand separator character + const char thousands_sep = '\0'; + /// the locale's decimal point character + const char decimal_point = '\0'; + + /// string buffer + std::array string_buffer{{}}; + + /// the indentation character + const char indent_char; + /// the indentation string + string_t indent_string; + + /// error_handler how to react on decoding errors + const error_handler_t error_handler; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include // less +#include // initializer_list +#include // input_iterator_tag, iterator_traits +#include // allocator +#include // for out_of_range +#include // enable_if, is_convertible +#include // pair +#include // vector + +// #include + + +namespace nlohmann +{ + +/// ordered_map: a minimal map-like container that preserves insertion order +/// for use within nlohmann::basic_json +template , + class Allocator = std::allocator>> + struct ordered_map : std::vector, Allocator> +{ + using key_type = Key; + using mapped_type = T; + using Container = std::vector, Allocator>; + using typename Container::iterator; + using typename Container::const_iterator; + using typename Container::size_type; + using typename Container::value_type; + + // Explicit constructors instead of `using Container::Container` + // otherwise older compilers choke on it (GCC <= 5.5, xcode <= 9.4) + ordered_map(const Allocator& alloc = Allocator()) : Container{alloc} {} + template + ordered_map(It first, It last, const Allocator& alloc = Allocator()) + : Container{first, last, alloc} {} + ordered_map(std::initializer_list init, const Allocator& alloc = Allocator() ) + : Container{init, alloc} {} + + std::pair emplace(const key_type& key, T&& t) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return {it, false}; + } + } + Container::emplace_back(key, t); + return {--this->end(), true}; + } + + T& operator[](const Key& key) + { + return emplace(key, T{}).first->second; + } + + const T& operator[](const Key& key) const + { + return at(key); + } + + T& at(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it->second; + } + } + + JSON_THROW(std::out_of_range("key not found")); + } + + const T& at(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it->second; + } + } + + JSON_THROW(std::out_of_range("key not found")); + } + + size_type erase(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + // Since we cannot move const Keys, re-construct them in place + for (auto next = it; ++next != this->end(); ++it) + { + it->~value_type(); // Destroy but keep allocation + new (&*it) value_type{std::move(*next)}; + } + Container::pop_back(); + return 1; + } + } + return 0; + } + + iterator erase(iterator pos) + { + auto it = pos; + + // Since we cannot move const Keys, re-construct them in place + for (auto next = it; ++next != this->end(); ++it) + { + 
it->~value_type(); // Destroy but keep allocation + new (&*it) value_type{std::move(*next)}; + } + Container::pop_back(); + return pos; + } + + size_type count(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return 1; + } + } + return 0; + } + + iterator find(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it; + } + } + return Container::end(); + } + + const_iterator find(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it; + } + } + return Container::end(); + } + + std::pair insert( value_type&& value ) + { + return emplace(value.first, std::move(value.second)); + } + + std::pair insert( const value_type& value ) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == value.first) + { + return {it, false}; + } + } + Container::push_back(value); + return {--this->end(), true}; + } + + template + using require_input_iter = typename std::enable_if::iterator_category, + std::input_iterator_tag>::value>::type; + + template> + void insert(InputIt first, InputIt last) + { + for (auto it = first; it != last; ++it) + { + insert(*it); + } + } +}; + +} // namespace nlohmann + + +#if defined(JSON_HAS_CPP_17) + #include +#endif + +/*! +@brief namespace for Niels Lohmann +@see https://github.com/nlohmann +@since version 1.0.0 +*/ +namespace nlohmann +{ + +/*! +@brief a class to store JSON values + +@tparam ObjectType type for JSON objects (`std::map` by default; will be used +in @ref object_t) +@tparam ArrayType type for JSON arrays (`std::vector` by default; will be used +in @ref array_t) +@tparam StringType type for JSON strings and object keys (`std::string` by +default; will be used in @ref string_t) +@tparam BooleanType type for JSON booleans (`bool` by default; will be used +in @ref boolean_t) +@tparam NumberIntegerType type for JSON integer numbers (`int64_t` by +default; will be used in @ref number_integer_t) +@tparam NumberUnsignedType type for JSON unsigned integer numbers (@c +`uint64_t` by default; will be used in @ref number_unsigned_t) +@tparam NumberFloatType type for JSON floating-point numbers (`double` by +default; will be used in @ref number_float_t) +@tparam BinaryType type for packed binary data for compatibility with binary +serialization formats (`std::vector` by default; will be used in +@ref binary_t) +@tparam AllocatorType type of the allocator to use (`std::allocator` by +default) +@tparam JSONSerializer the serializer to resolve internal calls to `to_json()` +and `from_json()` (@ref adl_serializer by default) + +@requirement The class satisfies the following concept requirements: +- Basic + - [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible): + JSON values can be default constructed. The result will be a JSON null + value. + - [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible): + A JSON value can be constructed from an rvalue argument. + - [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible): + A JSON value can be copy-constructed from an lvalue expression. + - [MoveAssignable](https://en.cppreference.com/w/cpp/named_req/MoveAssignable): + A JSON value van be assigned from an rvalue argument. + - [CopyAssignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable): + A JSON value can be copy-assigned from an lvalue expression. 
+ - [Destructible](https://en.cppreference.com/w/cpp/named_req/Destructible): + JSON values can be destructed. +- Layout + - [StandardLayoutType](https://en.cppreference.com/w/cpp/named_req/StandardLayoutType): + JSON values have + [standard layout](https://en.cppreference.com/w/cpp/language/data_members#Standard_layout): + All non-static data members are private and standard layout types, the + class has no virtual functions or (virtual) base classes. +- Library-wide + - [EqualityComparable](https://en.cppreference.com/w/cpp/named_req/EqualityComparable): + JSON values can be compared with `==`, see @ref + operator==(const_reference,const_reference). + - [LessThanComparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable): + JSON values can be compared with `<`, see @ref + operator<(const_reference,const_reference). + - [Swappable](https://en.cppreference.com/w/cpp/named_req/Swappable): + Any JSON lvalue or rvalue of can be swapped with any lvalue or rvalue of + other compatible types, using unqualified function call @ref swap(). + - [NullablePointer](https://en.cppreference.com/w/cpp/named_req/NullablePointer): + JSON values can be compared against `std::nullptr_t` objects which are used + to model the `null` value. +- Container + - [Container](https://en.cppreference.com/w/cpp/named_req/Container): + JSON values can be used like STL containers and provide iterator access. + - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer); + JSON values can be used like STL containers and provide reverse iterator + access. + +@invariant The member variables @a m_value and @a m_type have the following +relationship: +- If `m_type == value_t::object`, then `m_value.object != nullptr`. +- If `m_type == value_t::array`, then `m_value.array != nullptr`. +- If `m_type == value_t::string`, then `m_value.string != nullptr`. +The invariants are checked by member function assert_invariant(). 
+ +@internal +@note ObjectType trick from https://stackoverflow.com/a/9860911 +@endinternal + +@see [RFC 8259: The JavaScript Object Notation (JSON) Data Interchange +Format](https://tools.ietf.org/html/rfc8259) + +@since version 1.0.0 + +@nosubgrouping +*/ +NLOHMANN_BASIC_JSON_TPL_DECLARATION +class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-special-member-functions) +{ + private: + template friend struct detail::external_constructor; + friend ::nlohmann::json_pointer; + + template + friend class ::nlohmann::detail::parser; + friend ::nlohmann::detail::serializer; + template + friend class ::nlohmann::detail::iter_impl; + template + friend class ::nlohmann::detail::binary_writer; + template + friend class ::nlohmann::detail::binary_reader; + template + friend class ::nlohmann::detail::json_sax_dom_parser; + template + friend class ::nlohmann::detail::json_sax_dom_callback_parser; + friend class ::nlohmann::detail::exception; + + /// workaround type for MSVC + using basic_json_t = NLOHMANN_BASIC_JSON_TPL; + + JSON_PRIVATE_UNLESS_TESTED: + // convenience aliases for types residing in namespace detail; + using lexer = ::nlohmann::detail::lexer_base; + + template + static ::nlohmann::detail::parser parser( + InputAdapterType adapter, + detail::parser_callback_tcb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false + ) + { + return ::nlohmann::detail::parser(std::move(adapter), + std::move(cb), allow_exceptions, ignore_comments); + } + + private: + using primitive_iterator_t = ::nlohmann::detail::primitive_iterator_t; + template + using internal_iterator = ::nlohmann::detail::internal_iterator; + template + using iter_impl = ::nlohmann::detail::iter_impl; + template + using iteration_proxy = ::nlohmann::detail::iteration_proxy; + template using json_reverse_iterator = ::nlohmann::detail::json_reverse_iterator; + + template + using output_adapter_t = ::nlohmann::detail::output_adapter_t; + + template + using binary_reader = ::nlohmann::detail::binary_reader; + template using binary_writer = ::nlohmann::detail::binary_writer; + + JSON_PRIVATE_UNLESS_TESTED: + using serializer = ::nlohmann::detail::serializer; + + public: + using value_t = detail::value_t; + /// JSON Pointer, see @ref nlohmann::json_pointer + using json_pointer = ::nlohmann::json_pointer; + template + using json_serializer = JSONSerializer; + /// how to treat decoding errors + using error_handler_t = detail::error_handler_t; + /// how to treat CBOR tags + using cbor_tag_handler_t = detail::cbor_tag_handler_t; + /// helper type for initializer lists of basic_json values + using initializer_list_t = std::initializer_list>; + + using input_format_t = detail::input_format_t; + /// SAX interface type, see @ref nlohmann::json_sax + using json_sax_t = json_sax; + + //////////////// + // exceptions // + //////////////// + + /// @name exceptions + /// Classes to implement user-defined exceptions. 
+ /// @{ + + /// @copydoc detail::exception + using exception = detail::exception; + /// @copydoc detail::parse_error + using parse_error = detail::parse_error; + /// @copydoc detail::invalid_iterator + using invalid_iterator = detail::invalid_iterator; + /// @copydoc detail::type_error + using type_error = detail::type_error; + /// @copydoc detail::out_of_range + using out_of_range = detail::out_of_range; + /// @copydoc detail::other_error + using other_error = detail::other_error; + + /// @} + + + ///////////////////// + // container types // + ///////////////////// + + /// @name container types + /// The canonic container types to use @ref basic_json like any other STL + /// container. + /// @{ + + /// the type of elements in a basic_json container + using value_type = basic_json; + + /// the type of an element reference + using reference = value_type&; + /// the type of an element const reference + using const_reference = const value_type&; + + /// a type to represent differences between iterators + using difference_type = std::ptrdiff_t; + /// a type to represent container sizes + using size_type = std::size_t; + + /// the allocator type + using allocator_type = AllocatorType; + + /// the type of an element pointer + using pointer = typename std::allocator_traits::pointer; + /// the type of an element const pointer + using const_pointer = typename std::allocator_traits::const_pointer; + + /// an iterator for a basic_json container + using iterator = iter_impl; + /// a const iterator for a basic_json container + using const_iterator = iter_impl; + /// a reverse iterator for a basic_json container + using reverse_iterator = json_reverse_iterator; + /// a const reverse iterator for a basic_json container + using const_reverse_iterator = json_reverse_iterator; + + /// @} + + + /*! + @brief returns the allocator associated with the container + */ + static allocator_type get_allocator() + { + return allocator_type(); + } + + /*! + @brief returns version information on the library + + This function returns a JSON object with information about the library, + including the version number and information on the platform and compiler. + + @return JSON object holding version information + key | description + ----------- | --------------- + `compiler` | Information on the used compiler. It is an object with the following keys: `c++` (the used C++ standard), `family` (the compiler family; possible values are `clang`, `icc`, `gcc`, `ilecpp`, `msvc`, `pgcpp`, `sunpro`, and `unknown`), and `version` (the compiler version). + `copyright` | The copyright line for the library as string. + `name` | The name of the library as string. + `platform` | The used platform as string. Possible values are `win32`, `linux`, `apple`, `unix`, and `unknown`. + `url` | The URL of the project as string. + `version` | The version of the library. It is an object with the following keys: `major`, `minor`, and `patch` as defined by [Semantic Versioning](http://semver.org), and `string` (the version string). + + @liveexample{The following code shows an example output of the `meta()` + function.,meta} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @complexity Constant. 
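+
+    A quick way to inspect the result (illustrative only):
+    @code {.cpp}
+    std::cout << nlohmann::json::meta().dump(4) << std::endl;
+    @endcode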
+ + @since 2.1.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json meta() + { + basic_json result; + + result["copyright"] = "(C) 2013-2021 Niels Lohmann"; + result["name"] = "JSON for Modern C++"; + result["url"] = "https://github.com/nlohmann/json"; + result["version"]["string"] = + std::to_string(NLOHMANN_JSON_VERSION_MAJOR) + "." + + std::to_string(NLOHMANN_JSON_VERSION_MINOR) + "." + + std::to_string(NLOHMANN_JSON_VERSION_PATCH); + result["version"]["major"] = NLOHMANN_JSON_VERSION_MAJOR; + result["version"]["minor"] = NLOHMANN_JSON_VERSION_MINOR; + result["version"]["patch"] = NLOHMANN_JSON_VERSION_PATCH; + +#ifdef _WIN32 + result["platform"] = "win32"; +#elif defined __linux__ + result["platform"] = "linux"; +#elif defined __APPLE__ + result["platform"] = "apple"; +#elif defined __unix__ + result["platform"] = "unix"; +#else + result["platform"] = "unknown"; +#endif + +#if defined(__ICC) || defined(__INTEL_COMPILER) + result["compiler"] = {{"family", "icc"}, {"version", __INTEL_COMPILER}}; +#elif defined(__clang__) + result["compiler"] = {{"family", "clang"}, {"version", __clang_version__}}; +#elif defined(__GNUC__) || defined(__GNUG__) + result["compiler"] = {{"family", "gcc"}, {"version", std::to_string(__GNUC__) + "." + std::to_string(__GNUC_MINOR__) + "." + std::to_string(__GNUC_PATCHLEVEL__)}}; +#elif defined(__HP_cc) || defined(__HP_aCC) + result["compiler"] = "hp" +#elif defined(__IBMCPP__) + result["compiler"] = {{"family", "ilecpp"}, {"version", __IBMCPP__}}; +#elif defined(_MSC_VER) + result["compiler"] = {{"family", "msvc"}, {"version", _MSC_VER}}; +#elif defined(__PGI) + result["compiler"] = {{"family", "pgcpp"}, {"version", __PGI}}; +#elif defined(__SUNPRO_CC) + result["compiler"] = {{"family", "sunpro"}, {"version", __SUNPRO_CC}}; +#else + result["compiler"] = {{"family", "unknown"}, {"version", "unknown"}}; +#endif + +#ifdef __cplusplus + result["compiler"]["c++"] = std::to_string(__cplusplus); +#else + result["compiler"]["c++"] = "unknown"; +#endif + return result; + } + + + /////////////////////////// + // JSON value data types // + /////////////////////////// + + /// @name JSON value data types + /// The data types to store a JSON value. These types are derived from + /// the template arguments passed to class @ref basic_json. + /// @{ + +#if defined(JSON_HAS_CPP_14) + // Use transparent comparator if possible, combined with perfect forwarding + // on find() and count() calls prevents unnecessary string construction. + using object_comparator_t = std::less<>; +#else + using object_comparator_t = std::less; +#endif + + /*! + @brief a type for an object + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes JSON objects as follows: + > An object is an unordered collection of zero or more name/value pairs, + > where a name is a string and a value is a string, number, boolean, null, + > object, or array. + + To store objects in C++, a type is defined by the template parameters + described below. + + @tparam ObjectType the container to store objects (e.g., `std::map` or + `std::unordered_map`) + @tparam StringType the type of the keys or names (e.g., `std::string`). + The comparison function `std::less` is used to order elements + inside the container. 
+ @tparam AllocatorType the allocator to use for objects (e.g., + `std::allocator`) + + #### Default type + + With the default values for @a ObjectType (`std::map`), @a StringType + (`std::string`), and @a AllocatorType (`std::allocator`), the default + value for @a object_t is: + + @code {.cpp} + std::map< + std::string, // key_type + basic_json, // value_type + std::less, // key_compare + std::allocator> // allocator_type + > + @endcode + + #### Behavior + + The choice of @a object_t influences the behavior of the JSON class. With + the default type, objects have the following behavior: + + - When all names are unique, objects will be interoperable in the sense + that all software implementations receiving that object will agree on + the name-value mappings. + - When the names within an object are not unique, it is unspecified which + one of the values for a given key will be chosen. For instance, + `{"key": 2, "key": 1}` could be equal to either `{"key": 1}` or + `{"key": 2}`. + - Internally, name/value pairs are stored in lexicographical order of the + names. Objects will also be serialized (see @ref dump) in this order. + For instance, `{"b": 1, "a": 2}` and `{"a": 2, "b": 1}` will be stored + and serialized as `{"a": 2, "b": 1}`. + - When comparing objects, the order of the name/value pairs is irrelevant. + This makes objects interoperable in the sense that they will not be + affected by these differences. For instance, `{"b": 1, "a": 2}` and + `{"a": 2, "b": 1}` will be treated as equal. + + #### Limits + + [RFC 8259](https://tools.ietf.org/html/rfc8259) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the object's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON object. + + #### Storage + + Objects are stored as pointers in a @ref basic_json type. That is, for any + access to object values, a pointer of type `object_t*` must be + dereferenced. + + @sa see @ref array_t -- type for an array value + + @since version 1.0.0 + + @note The order name/value pairs are added to the object is *not* + preserved by the library. Therefore, iterating an object may return + name/value pairs in a different order than they were originally stored. In + fact, keys will be traversed in alphabetical order as `std::map` with + `std::less` is used by default. Please note this behavior conforms to [RFC + 8259](https://tools.ietf.org/html/rfc8259), because any order implements the + specified "unordered" nature of JSON objects. + */ + using object_t = ObjectType>>; + + /*! + @brief a type for an array + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes JSON arrays as follows: + > An array is an ordered sequence of zero or more values. + + To store objects in C++, a type is defined by the template parameters + explained below. 
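+
+    An aside illustrating the @ref object_t behavior described above with the
+    default `std::map`-backed `nlohmann::json` specialization (illustrative
+    sketch only):
+
+    @code {.cpp}
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    using nlohmann::json;
+
+    int main()
+    {
+        // keys are stored and serialized in lexicographical order
+        json j1 = {{"b", 1}, {"a", 2}};
+        std::cout << j1.dump() << '\n';                       // {"a":2,"b":1}
+
+        // comparison ignores the insertion order of name/value pairs
+        json j2 = {{"a", 2}, {"b", 1}};
+        std::cout << std::boolalpha << (j1 == j2) << '\n';    // true
+    }
+    @endcode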
+ + @tparam ArrayType container type to store arrays (e.g., `std::vector` or + `std::list`) + @tparam AllocatorType allocator to use for arrays (e.g., `std::allocator`) + + #### Default type + + With the default values for @a ArrayType (`std::vector`) and @a + AllocatorType (`std::allocator`), the default value for @a array_t is: + + @code {.cpp} + std::vector< + basic_json, // value_type + std::allocator // allocator_type + > + @endcode + + #### Limits + + [RFC 8259](https://tools.ietf.org/html/rfc8259) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the array's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON array. + + #### Storage + + Arrays are stored as pointers in a @ref basic_json type. That is, for any + access to array values, a pointer of type `array_t*` must be dereferenced. + + @sa see @ref object_t -- type for an object value + + @since version 1.0.0 + */ + using array_t = ArrayType>; + + /*! + @brief a type for a string + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes JSON strings as follows: + > A string is a sequence of zero or more Unicode characters. + + To store objects in C++, a type is defined by the template parameter + described below. Unicode values are split by the JSON class into + byte-sized characters during deserialization. + + @tparam StringType the container to store strings (e.g., `std::string`). + Note this container is used for keys/names in objects, see @ref object_t. + + #### Default type + + With the default values for @a StringType (`std::string`), the default + value for @a string_t is: + + @code {.cpp} + std::string + @endcode + + #### Encoding + + Strings are stored in UTF-8 encoding. Therefore, functions like + `std::string::size()` or `std::string::length()` return the number of + bytes in the string rather than the number of characters or glyphs. + + #### String comparison + + [RFC 8259](https://tools.ietf.org/html/rfc8259) states: + > Software implementations are typically required to test names of object + > members for equality. Implementations that transform the textual + > representation into sequences of Unicode code units and then perform the + > comparison numerically, code unit by code unit, are interoperable in the + > sense that implementations will agree in all cases on equality or + > inequality of two strings. For example, implementations that compare + > strings with escaped characters unconverted may incorrectly find that + > `"a\\b"` and `"a\u005Cb"` are not equal. + + This implementation is interoperable as it does compare strings code unit + by code unit. + + #### Storage + + String values are stored as pointers in a @ref basic_json type. That is, + for any access to string values, a pointer of type `string_t*` must be + dereferenced. + + @since version 1.0.0 + */ + using string_t = StringType; + + /*! + @brief a type for a boolean + + [RFC 8259](https://tools.ietf.org/html/rfc8259) implicitly describes a boolean as a + type which differentiates the two literals `true` and `false`. + + To store objects in C++, a type is defined by the template parameter @a + BooleanType which chooses the type to use. 
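+
+    An aside illustrating the @ref string_t encoding note above: `size()`
+    reports the number of bytes of the UTF-8 encoding, not the number of
+    characters (illustrative sketch, assuming a UTF-8 execution character set):
+
+    @code {.cpp}
+    #include <iostream>
+    #include <string>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        nlohmann::json j = "\u00fcber";                        // "über"
+        std::cout << j.get<std::string>().size() << '\n';      // 5 bytes for 4 characters
+    }
+    @endcode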
+ + #### Default type + + With the default values for @a BooleanType (`bool`), the default value for + @a boolean_t is: + + @code {.cpp} + bool + @endcode + + #### Storage + + Boolean values are stored directly inside a @ref basic_json type. + + @since version 1.0.0 + */ + using boolean_t = BooleanType; + + /*! + @brief a type for a number (integer) + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store integer numbers in C++, a type is defined by the template + parameter @a NumberIntegerType which chooses the type to use. + + #### Default type + + With the default values for @a NumberIntegerType (`int64_t`), the default + value for @a number_integer_t is: + + @code {.cpp} + int64_t + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 8259](https://tools.ietf.org/html/rfc8259) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `9223372036854775807` (INT64_MAX) and the minimal integer number + that can be stored is `-9223372036854775808` (INT64_MIN). Integer numbers + that are out of range will yield over/underflow when used in a + constructor. During deserialization, too large or small integer numbers + will be automatically be stored as @ref number_unsigned_t or @ref + number_float_t. + + [RFC 8259](https://tools.ietf.org/html/rfc8259) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange of the exactly supported range [INT64_MIN, + INT64_MAX], this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa see @ref number_float_t -- type for number values (floating-point) + + @sa see @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_integer_t = NumberIntegerType; + + /*! + @brief a type for a number (unsigned) + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. 
A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store unsigned integer numbers in C++, a type is defined by the + template parameter @a NumberUnsignedType which chooses the type to use. + + #### Default type + + With the default values for @a NumberUnsignedType (`uint64_t`), the + default value for @a number_unsigned_t is: + + @code {.cpp} + uint64_t + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 8259](https://tools.ietf.org/html/rfc8259) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `18446744073709551615` (UINT64_MAX) and the minimal integer + number that can be stored is `0`. Integer numbers that are out of range + will yield over/underflow when used in a constructor. During + deserialization, too large or small integer numbers will be automatically + be stored as @ref number_integer_t or @ref number_float_t. + + [RFC 8259](https://tools.ietf.org/html/rfc8259) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange (when considered in conjunction with the + number_integer_t type) of the exactly supported range [0, UINT64_MAX], + this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa see @ref number_float_t -- type for number values (floating-point) + @sa see @ref number_integer_t -- type for number values (integer) + + @since version 2.0.0 + */ + using number_unsigned_t = NumberUnsignedType; + + /*! + @brief a type for a number (floating-point) + + [RFC 8259](https://tools.ietf.org/html/rfc8259) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. 
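+
+    An aside illustrating the integer limits described above: a value just
+    beyond `INT64_MAX` is parsed into @ref number_unsigned_t rather than
+    @ref number_integer_t (illustrative sketch with the default number types):
+
+    @code {.cpp}
+    #include <cstdint>
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        // 9223372036854775808 == INT64_MAX + 1 does not fit number_integer_t
+        auto j = nlohmann::json::parse("9223372036854775808");
+        std::cout << std::boolalpha << j.is_number_unsigned() << '\n';   // true
+        std::cout << j.get<std::uint64_t>() << '\n';                     // 9223372036854775808
+    }
+    @endcode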
+ However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store floating-point numbers in C++, a type is defined by the template + parameter @a NumberFloatType which chooses the type to use. + + #### Default type + + With the default values for @a NumberFloatType (`double`), the default + value for @a number_float_t is: + + @code {.cpp} + double + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in floating-point literals will be ignored. Internally, + the value will be stored as decimal number. For instance, the C++ + floating-point literal `01.2` will be serialized to `1.2`. During + deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 8259](https://tools.ietf.org/html/rfc8259) states: + > This specification allows implementations to set limits on the range and + > precision of numbers accepted. Since software that implements IEEE + > 754-2008 binary64 (double precision) numbers is generally available and + > widely used, good interoperability can be achieved by implementations + > that expect no more precision or range than these provide, in the sense + > that implementations will approximate JSON numbers within the expected + > precision. + + This implementation does exactly follow this approach, as it uses double + precision floating-point numbers. Note values smaller than + `-1.79769313486232e+308` and values greater than `1.79769313486232e+308` + will be stored as NaN internally and be serialized to `null`. + + #### Storage + + Floating-point number values are stored directly inside a @ref basic_json + type. + + @sa see @ref number_integer_t -- type for number values (integer) + + @sa see @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_float_t = NumberFloatType; + + /*! + @brief a type for a packed binary type + + This type is a type designed to carry binary data that appears in various + serialized formats, such as CBOR's Major Type 2, MessagePack's bin, and + BSON's generic binary subtype. This type is NOT a part of standard JSON and + exists solely for compatibility with these binary types. As such, it is + simply defined as an ordered sequence of zero or more byte values. + + Additionally, as an implementation detail, the subtype of the binary data is + carried around as a `std::uint8_t`, which is compatible with both of the + binary data formats that use binary subtyping, (though the specific + numbering is incompatible with each other, and it is up to the user to + translate between them). + + [CBOR's RFC 7049](https://tools.ietf.org/html/rfc7049) describes this type + as: + > Major type 2: a byte string. The string's length in bytes is represented + > following the rules for positive integers (major type 0). + + [MessagePack's documentation on the bin type + family](https://github.com/msgpack/msgpack/blob/master/spec.md#bin-format-family) + describes this type as: + > Bin format family stores an byte array in 2, 3, or 5 bytes of extra bytes + > in addition to the size of the byte array. 
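+
+    An aside illustrating the @ref number_float_t behavior described above:
+    NaN and infinities have no JSON representation and serialize to `null`
+    (illustrative sketch with the default `double` type):
+
+    @code {.cpp}
+    #include <cmath>
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    int main()
+    {
+        nlohmann::json j = std::nan("");
+        std::cout << j.dump() << '\n';   // null
+    }
+    @endcode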
+ + [BSON's specifications](http://bsonspec.org/spec.html) describe several + binary types; however, this type is intended to represent the generic binary + type which has the description: + > Generic binary subtype - This is the most commonly used binary subtype and + > should be the 'default' for drivers and tools. + + None of these impose any limitations on the internal representation other + than the basic unit of storage be some type of array whose parts are + decomposable into bytes. + + The default representation of this binary format is a + `std::vector`, which is a very common way to represent a byte + array in modern C++. + + #### Default type + + The default values for @a BinaryType is `std::vector` + + #### Storage + + Binary Arrays are stored as pointers in a @ref basic_json type. That is, + for any access to array values, a pointer of the type `binary_t*` must be + dereferenced. + + #### Notes on subtypes + + - CBOR + - Binary values are represented as byte strings. Subtypes are serialized + as tagged values. + - MessagePack + - If a subtype is given and the binary array contains exactly 1, 2, 4, 8, + or 16 elements, the fixext family (fixext1, fixext2, fixext4, fixext8) + is used. For other sizes, the ext family (ext8, ext16, ext32) is used. + The subtype is then added as singed 8-bit integer. + - If no subtype is given, the bin family (bin8, bin16, bin32) is used. + - BSON + - If a subtype is given, it is used and added as unsigned 8-bit integer. + - If no subtype is given, the generic binary subtype 0x00 is used. + + @sa see @ref binary -- create a binary array + + @since version 3.8.0 + */ + using binary_t = nlohmann::byte_container_with_subtype; + /// @} + + private: + + /// helper for exception-safe object creation + template + JSON_HEDLEY_RETURNS_NON_NULL + static T* create(Args&& ... args) + { + AllocatorType alloc; + using AllocatorTraits = std::allocator_traits>; + + auto deleter = [&](T * obj) + { + AllocatorTraits::deallocate(alloc, obj, 1); + }; + std::unique_ptr obj(AllocatorTraits::allocate(alloc, 1), deleter); + AllocatorTraits::construct(alloc, obj.get(), std::forward(args)...); + JSON_ASSERT(obj != nullptr); + return obj.release(); + } + + //////////////////////// + // JSON value storage // + //////////////////////// + + JSON_PRIVATE_UNLESS_TESTED: + /*! + @brief a JSON value + + The actual storage for a JSON value of the @ref basic_json class. This + union combines the different storage types for the JSON value types + defined in @ref value_t. + + JSON type | value_t type | used type + --------- | --------------- | ------------------------ + object | object | pointer to @ref object_t + array | array | pointer to @ref array_t + string | string | pointer to @ref string_t + boolean | boolean | @ref boolean_t + number | number_integer | @ref number_integer_t + number | number_unsigned | @ref number_unsigned_t + number | number_float | @ref number_float_t + binary | binary | pointer to @ref binary_t + null | null | *no value is stored* + + @note Variable-length types (objects, arrays, and strings) are stored as + pointers. The size of the union should not exceed 64 bits if the default + value types are used. 
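+
+    The exception-safe allocation pattern used by the `create()` helper above
+    can be sketched in isolation as follows (the name `make_with_allocator` is
+    illustrative only, not library API):
+
+    @code {.cpp}
+    #include <memory>
+    #include <string>
+    #include <utility>
+
+    // allocate with an allocator, construct, and hand over ownership only on success
+    template<typename T, typename Alloc = std::allocator<T>, typename... Args>
+    T* make_with_allocator(Args&& ... args)
+    {
+        Alloc alloc;
+        using Traits = std::allocator_traits<Alloc>;
+
+        // if construct() throws, the unique_ptr deallocates the raw storage
+        auto deleter = [&](T* p) { Traits::deallocate(alloc, p, 1); };
+        std::unique_ptr<T, decltype(deleter)> guard(Traits::allocate(alloc, 1), deleter);
+        Traits::construct(alloc, guard.get(), std::forward<Args>(args)...);
+        return guard.release();
+    }
+
+    int main()
+    {
+        std::string* s = make_with_allocator<std::string>(3, 'x');   // "xxx"
+
+        // tear down through the same allocator that created the object
+        std::allocator<std::string> alloc;
+        std::allocator_traits<std::allocator<std::string>>::destroy(alloc, s);
+        std::allocator_traits<std::allocator<std::string>>::deallocate(alloc, s, 1);
+    }
+    @endcode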
+ + @since version 1.0.0 + */ + union json_value + { + /// object (stored with pointer to save storage) + object_t* object; + /// array (stored with pointer to save storage) + array_t* array; + /// string (stored with pointer to save storage) + string_t* string; + /// binary (stored with pointer to save storage) + binary_t* binary; + /// boolean + boolean_t boolean; + /// number (integer) + number_integer_t number_integer; + /// number (unsigned integer) + number_unsigned_t number_unsigned; + /// number (floating-point) + number_float_t number_float; + + /// default constructor (for null values) + json_value() = default; + /// constructor for booleans + json_value(boolean_t v) noexcept : boolean(v) {} + /// constructor for numbers (integer) + json_value(number_integer_t v) noexcept : number_integer(v) {} + /// constructor for numbers (unsigned) + json_value(number_unsigned_t v) noexcept : number_unsigned(v) {} + /// constructor for numbers (floating-point) + json_value(number_float_t v) noexcept : number_float(v) {} + /// constructor for empty values of a given type + json_value(value_t t) + { + switch (t) + { + case value_t::object: + { + object = create(); + break; + } + + case value_t::array: + { + array = create(); + break; + } + + case value_t::string: + { + string = create(""); + break; + } + + case value_t::binary: + { + binary = create(); + break; + } + + case value_t::boolean: + { + boolean = boolean_t(false); + break; + } + + case value_t::number_integer: + { + number_integer = number_integer_t(0); + break; + } + + case value_t::number_unsigned: + { + number_unsigned = number_unsigned_t(0); + break; + } + + case value_t::number_float: + { + number_float = number_float_t(0.0); + break; + } + + case value_t::null: + { + object = nullptr; // silence warning, see #821 + break; + } + + case value_t::discarded: + default: + { + object = nullptr; // silence warning, see #821 + if (JSON_HEDLEY_UNLIKELY(t == value_t::null)) + { + JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 3.10.4", basic_json())); // LCOV_EXCL_LINE + } + break; + } + } + } + + /// constructor for strings + json_value(const string_t& value) : string(create(value)) {} + + /// constructor for rvalue strings + json_value(string_t&& value) : string(create(std::move(value))) {} + + /// constructor for objects + json_value(const object_t& value) : object(create(value)) {} + + /// constructor for rvalue objects + json_value(object_t&& value) : object(create(std::move(value))) {} + + /// constructor for arrays + json_value(const array_t& value) : array(create(value)) {} + + /// constructor for rvalue arrays + json_value(array_t&& value) : array(create(std::move(value))) {} + + /// constructor for binary arrays + json_value(const typename binary_t::container_type& value) : binary(create(value)) {} + + /// constructor for rvalue binary arrays + json_value(typename binary_t::container_type&& value) : binary(create(std::move(value))) {} + + /// constructor for binary arrays (internal type) + json_value(const binary_t& value) : binary(create(value)) {} + + /// constructor for rvalue binary arrays (internal type) + json_value(binary_t&& value) : binary(create(std::move(value))) {} + + void destroy(value_t t) + { + if (t == value_t::array || t == value_t::object) + { + // flatten the current json_value to a heap-allocated stack + std::vector stack; + + // move the top-level items to stack + if (t == value_t::array) + { + stack.reserve(array->size()); + std::move(array->begin(), array->end(), 
std::back_inserter(stack)); + } + else + { + stack.reserve(object->size()); + for (auto&& it : *object) + { + stack.push_back(std::move(it.second)); + } + } + + while (!stack.empty()) + { + // move the last item to local variable to be processed + basic_json current_item(std::move(stack.back())); + stack.pop_back(); + + // if current_item is array/object, move + // its children to the stack to be processed later + if (current_item.is_array()) + { + std::move(current_item.m_value.array->begin(), current_item.m_value.array->end(), std::back_inserter(stack)); + + current_item.m_value.array->clear(); + } + else if (current_item.is_object()) + { + for (auto&& it : *current_item.m_value.object) + { + stack.push_back(std::move(it.second)); + } + + current_item.m_value.object->clear(); + } + + // it's now safe that current_item get destructed + // since it doesn't have any children + } + } + + switch (t) + { + case value_t::object: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, object); + std::allocator_traits::deallocate(alloc, object, 1); + break; + } + + case value_t::array: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, array); + std::allocator_traits::deallocate(alloc, array, 1); + break; + } + + case value_t::string: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, string); + std::allocator_traits::deallocate(alloc, string, 1); + break; + } + + case value_t::binary: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, binary); + std::allocator_traits::deallocate(alloc, binary, 1); + break; + } + + case value_t::null: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::discarded: + default: + { + break; + } + } + } + }; + + private: + /*! + @brief checks the class invariants + + This function asserts the class invariants. It needs to be called at the + end of every constructor to make sure that created objects respect the + invariant. Furthermore, it has to be called each time the type of a JSON + value is changed, because the invariant expresses a relationship between + @a m_type and @a m_value. + + Furthermore, the parent relation is checked for arrays and objects: If + @a check_parents true and the value is an array or object, then the + container's elements must have the current value as parent. + + @param[in] check_parents whether the parent relation should be checked. + The value is true by default and should only be set to false + during destruction of objects when the invariant does not + need to hold. + */ + void assert_invariant(bool check_parents = true) const noexcept + { + JSON_ASSERT(m_type != value_t::object || m_value.object != nullptr); + JSON_ASSERT(m_type != value_t::array || m_value.array != nullptr); + JSON_ASSERT(m_type != value_t::string || m_value.string != nullptr); + JSON_ASSERT(m_type != value_t::binary || m_value.binary != nullptr); + +#if JSON_DIAGNOSTICS + JSON_TRY + { + // cppcheck-suppress assertWithSideEffect + JSON_ASSERT(!check_parents || !is_structured() || std::all_of(begin(), end(), [this](const basic_json & j) + { + return j.m_parent == this; + })); + } + JSON_CATCH(...) 
{} // LCOV_EXCL_LINE +#endif + static_cast(check_parents); + } + + void set_parents() + { +#if JSON_DIAGNOSTICS + switch (m_type) + { + case value_t::array: + { + for (auto& element : *m_value.array) + { + element.m_parent = this; + } + break; + } + + case value_t::object: + { + for (auto& element : *m_value.object) + { + element.second.m_parent = this; + } + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + break; + } +#endif + } + + iterator set_parents(iterator it, typename iterator::difference_type count) + { +#if JSON_DIAGNOSTICS + for (typename iterator::difference_type i = 0; i < count; ++i) + { + (it + i)->m_parent = this; + } +#else + static_cast(count); +#endif + return it; + } + + reference set_parent(reference j, std::size_t old_capacity = std::size_t(-1)) + { +#if JSON_DIAGNOSTICS + if (old_capacity != std::size_t(-1)) + { + // see https://github.com/nlohmann/json/issues/2838 + JSON_ASSERT(type() == value_t::array); + if (JSON_HEDLEY_UNLIKELY(m_value.array->capacity() != old_capacity)) + { + // capacity has changed: update all parents + set_parents(); + return j; + } + } + + // ordered_json uses a vector internally, so pointers could have + // been invalidated; see https://github.com/nlohmann/json/issues/2962 +#ifdef JSON_HEDLEY_MSVC_VERSION +#pragma warning(push ) +#pragma warning(disable : 4127) // ignore warning to replace if with if constexpr +#endif + if (detail::is_ordered_map::value) + { + set_parents(); + return j; + } +#ifdef JSON_HEDLEY_MSVC_VERSION +#pragma warning( pop ) +#endif + + j.m_parent = this; +#else + static_cast(j); + static_cast(old_capacity); +#endif + return j; + } + + public: + ////////////////////////// + // JSON parser callback // + ////////////////////////// + + /*! + @brief parser event types + + The parser callback distinguishes the following events: + - `object_start`: the parser read `{` and started to process a JSON object + - `key`: the parser read a key of a value in an object + - `object_end`: the parser read `}` and finished processing a JSON object + - `array_start`: the parser read `[` and started to process a JSON array + - `array_end`: the parser read `]` and finished processing a JSON array + - `value`: the parser finished reading a JSON value + + @image html callback_events.png "Example when certain parse events are triggered" + + @sa see @ref parser_callback_t for more information and examples + */ + using parse_event_t = detail::parse_event_t; + + /*! + @brief per-element parser callback type + + With a parser callback function, the result of parsing a JSON text can be + influenced. When passed to @ref parse, it is called on certain events + (passed as @ref parse_event_t via parameter @a event) with a set recursion + depth @a depth and context JSON value @a parsed. The return value of the + callback function is a boolean indicating whether the element that emitted + the callback shall be kept or not. + + We distinguish six scenarios (determined by the event type) in which the + callback function can be called. The following table describes the values + of the parameters @a depth, @a event, and @a parsed. 
+ + parameter @a event | description | parameter @a depth | parameter @a parsed + ------------------ | ----------- | ------------------ | ------------------- + parse_event_t::object_start | the parser read `{` and started to process a JSON object | depth of the parent of the JSON object | a JSON value with type discarded + parse_event_t::key | the parser read a key of a value in an object | depth of the currently parsed JSON object | a JSON string containing the key + parse_event_t::object_end | the parser read `}` and finished processing a JSON object | depth of the parent of the JSON object | the parsed JSON object + parse_event_t::array_start | the parser read `[` and started to process a JSON array | depth of the parent of the JSON array | a JSON value with type discarded + parse_event_t::array_end | the parser read `]` and finished processing a JSON array | depth of the parent of the JSON array | the parsed JSON array + parse_event_t::value | the parser finished reading a JSON value | depth of the value | the parsed JSON value + + @image html callback_events.png "Example when certain parse events are triggered" + + Discarding a value (i.e., returning `false`) has different effects + depending on the context in which function was called: + + - Discarded values in structured types are skipped. That is, the parser + will behave as if the discarded value was never read. + - In case a value outside a structured type is skipped, it is replaced + with `null`. This case happens if the top-level element is skipped. + + @param[in] depth the depth of the recursion during parsing + + @param[in] event an event of type parse_event_t indicating the context in + the callback function has been called + + @param[in,out] parsed the current intermediate parse result; note that + writing to this value has no effect for parse_event_t::key events + + @return Whether the JSON value which called the function during parsing + should be kept (`true`) or not (`false`). In the latter case, it is either + skipped completely or replaced by an empty discarded object. + + @sa see @ref parse for examples + + @since version 1.0.0 + */ + using parser_callback_t = detail::parser_callback_t; + + ////////////////// + // constructors // + ////////////////// + + /// @name constructors and destructors + /// Constructors of class @ref basic_json, copy/move constructor, copy + /// assignment, static functions creating objects, and the destructor. + /// @{ + + /*! + @brief create an empty value with a given type + + Create an empty JSON value with a given type. The value will be default + initialized with an empty value which depends on the type: + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + object | `{}` + array | `[]` + binary | empty array + + @param[in] v the type of the value to create + + @complexity Constant. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows the constructor for different @ref + value_t values,basic_json__value_t} + + @sa see @ref clear() -- restores the postcondition of this constructor + + @since version 1.0.0 + */ + basic_json(const value_t v) + : m_type(v), m_value(v) + { + assert_invariant(); + } + + /*! + @brief create a null object + + Create a `null` JSON value. It either takes a null pointer as parameter + (explicitly creating `null`) or no parameter (implicitly creating `null`). 
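+
+    An aside illustrating the @ref parser_callback_t mechanism described
+    above: discarding every value stored under the key `"metadata"`
+    (illustrative sketch, assuming the default `nlohmann::json`
+    specialization):
+
+    @code {.cpp}
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    using nlohmann::json;
+
+    int main()
+    {
+        json::parser_callback_t cb = [](int depth, json::parse_event_t event, json& parsed)
+        {
+            static_cast<void>(depth);   // depth is not needed for this filter
+            // returning false for a key event discards the key and its value
+            return !(event == json::parse_event_t::key && parsed == json("metadata"));
+        };
+
+        json j = json::parse(R"({"name":"demo","metadata":{"internal":true}})", cb);
+        std::cout << j.dump() << '\n';   // {"name":"demo"}
+    }
+    @endcode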
+ The passed null pointer itself is not read -- it is only used to choose + the right constructor. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. + + @liveexample{The following code shows the constructor with and without a + null pointer parameter.,basic_json__nullptr_t} + + @since version 1.0.0 + */ + basic_json(std::nullptr_t = nullptr) noexcept + : basic_json(value_t::null) + { + assert_invariant(); + } + + /*! + @brief create a JSON value + + This is a "catch all" constructor for all compatible JSON types; that is, + types for which a `to_json()` method exists. The constructor forwards the + parameter @a val to that method (to `json_serializer::to_json` method + with `U = uncvref_t`, to be exact). + + Template type @a CompatibleType includes, but is not limited to, the + following types: + - **arrays**: @ref array_t and all kinds of compatible containers such as + `std::vector`, `std::deque`, `std::list`, `std::forward_list`, + `std::array`, `std::valarray`, `std::set`, `std::unordered_set`, + `std::multiset`, and `std::unordered_multiset` with a `value_type` from + which a @ref basic_json value can be constructed. + - **objects**: @ref object_t and all kinds of compatible associative + containers such as `std::map`, `std::unordered_map`, `std::multimap`, + and `std::unordered_multimap` with a `key_type` compatible to + @ref string_t and a `value_type` from which a @ref basic_json value can + be constructed. + - **strings**: @ref string_t, string literals, and all compatible string + containers can be used. + - **numbers**: @ref number_integer_t, @ref number_unsigned_t, + @ref number_float_t, and all convertible number types such as `int`, + `size_t`, `int64_t`, `float` or `double` can be used. + - **boolean**: @ref boolean_t / `bool` can be used. + - **binary**: @ref binary_t / `std::vector` may be used, + unfortunately because string literals cannot be distinguished from binary + character arrays by the C++ type system, all types compatible with `const + char*` will be directed to the string constructor instead. This is both + for backwards compatibility, and due to the fact that a binary type is not + a standard JSON type. + + See the examples below. + + @tparam CompatibleType a type such that: + - @a CompatibleType is not derived from `std::istream`, + - @a CompatibleType is not @ref basic_json (to avoid hijacking copy/move + constructors), + - @a CompatibleType is not a different @ref basic_json type (i.e. with different template arguments) + - @a CompatibleType is not a @ref basic_json nested type (e.g., + @ref json_pointer, @ref iterator, etc ...) + - `json_serializer` has a `to_json(basic_json_t&, CompatibleType&&)` method + + @tparam U = `uncvref_t` + + @param[in] val the value to be forwarded to the respective constructor + + @complexity Usually linear in the size of the passed @a val, also + depending on the implementation of the called `to_json()` + method. + + @exceptionsafety Depends on the called constructor. For types directly + supported by the library (i.e., all types for which no `to_json()` function + was provided), strong guarantee holds: if an exception is thrown, there are + no changes to any JSON value. 
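+
+    A few of the conversions this catch-all constructor accepts (illustrative
+    sketch, assuming the default `nlohmann::json` specialization):
+
+    @code {.cpp}
+    #include <map>
+    #include <string>
+    #include <vector>
+    #include <nlohmann/json.hpp>
+
+    using nlohmann::json;
+
+    int main()
+    {
+        json from_vector = std::vector<int> {1, 2, 3};                      // array
+        json from_map    = std::map<std::string, double> {{"pi", 3.141}};   // object
+        json from_cstr   = "hello";                                         // string
+        json from_bool   = true;                                            // boolean
+        json from_number = 42u;                                             // number (unsigned)
+    }
+    @endcode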
+ + @liveexample{The following code shows the constructor with several + compatible types.,basic_json__CompatibleType} + + @since version 2.1.0 + */ + template < typename CompatibleType, + typename U = detail::uncvref_t, + detail::enable_if_t < + !detail::is_basic_json::value && detail::is_compatible_type::value, int > = 0 > + basic_json(CompatibleType && val) noexcept(noexcept( // NOLINT(bugprone-forwarding-reference-overload,bugprone-exception-escape) + JSONSerializer::to_json(std::declval(), + std::forward(val)))) + { + JSONSerializer::to_json(*this, std::forward(val)); + set_parents(); + assert_invariant(); + } + + /*! + @brief create a JSON value from an existing one + + This is a constructor for existing @ref basic_json types. + It does not hijack copy/move constructors, since the parameter has different + template arguments than the current ones. + + The constructor tries to convert the internal @ref m_value of the parameter. + + @tparam BasicJsonType a type such that: + - @a BasicJsonType is a @ref basic_json type. + - @a BasicJsonType has different template arguments than @ref basic_json_t. + + @param[in] val the @ref basic_json value to be converted. + + @complexity Usually linear in the size of the passed @a val, also + depending on the implementation of the called `to_json()` + method. + + @exceptionsafety Depends on the called constructor. For types directly + supported by the library (i.e., all types for which no `to_json()` function + was provided), strong guarantee holds: if an exception is thrown, there are + no changes to any JSON value. + + @since version 3.2.0 + */ + template < typename BasicJsonType, + detail::enable_if_t < + detail::is_basic_json::value&& !std::is_same::value, int > = 0 > + basic_json(const BasicJsonType& val) + { + using other_boolean_t = typename BasicJsonType::boolean_t; + using other_number_float_t = typename BasicJsonType::number_float_t; + using other_number_integer_t = typename BasicJsonType::number_integer_t; + using other_number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using other_string_t = typename BasicJsonType::string_t; + using other_object_t = typename BasicJsonType::object_t; + using other_array_t = typename BasicJsonType::array_t; + using other_binary_t = typename BasicJsonType::binary_t; + + switch (val.type()) + { + case value_t::boolean: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_float: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_integer: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_unsigned: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::string: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::object: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::array: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::binary: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::null: + *this = nullptr; + break; + case value_t::discarded: + m_type = value_t::discarded; + break; + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + set_parents(); + assert_invariant(); + } + + /*! + @brief create a container (array or object) from an initializer list + + Creates a JSON value of type array or object from the passed initializer + list @a init. 
In case @a type_deduction is `true` (default), the type of + the JSON value to be created is deducted from the initializer list @a init + according to the following rules: + + 1. If the list is empty, an empty JSON object value `{}` is created. + 2. If the list consists of pairs whose first element is a string, a JSON + object value is created where the first elements of the pairs are + treated as keys and the second elements are as values. + 3. In all other cases, an array is created. + + The rules aim to create the best fit between a C++ initializer list and + JSON values. The rationale is as follows: + + 1. The empty initializer list is written as `{}` which is exactly an empty + JSON object. + 2. C++ has no way of describing mapped types other than to list a list of + pairs. As JSON requires that keys must be of type string, rule 2 is the + weakest constraint one can pose on initializer lists to interpret them + as an object. + 3. In all other cases, the initializer list could not be interpreted as + JSON object type, so interpreting it as JSON array type is safe. + + With the rules described above, the following JSON values cannot be + expressed by an initializer list: + + - the empty array (`[]`): use @ref array(initializer_list_t) + with an empty initializer list in this case + - arrays whose elements satisfy rule 2: use @ref + array(initializer_list_t) with the same initializer list + in this case + + @note When used without parentheses around an empty initializer list, @ref + basic_json() is called instead of this function, yielding the JSON null + value. + + @param[in] init initializer list with JSON values + + @param[in] type_deduction internal parameter; when set to `true`, the type + of the JSON value is deducted from the initializer list @a init; when set + to `false`, the type provided via @a manual_type is forced. This mode is + used by the functions @ref array(initializer_list_t) and + @ref object(initializer_list_t). + + @param[in] manual_type internal parameter; when @a type_deduction is set + to `false`, the created JSON value will use the provided type (only @ref + value_t::array and @ref value_t::object are valid); when @a type_deduction + is set to `true`, this parameter has no effect + + @throw type_error.301 if @a type_deduction is `false`, @a manual_type is + `value_t::object`, but @a init contains an element which is not a pair + whose first element is a string. In this case, the constructor could not + create an object. If @a type_deduction would have be `true`, an array + would have been created. See @ref object(initializer_list_t) + for an example. + + @complexity Linear in the size of the initializer list @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. 
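+
+    The deduction rules in practice (illustrative sketch, assuming the
+    default `nlohmann::json` specialization):
+
+    @code {.cpp}
+    #include <iostream>
+    #include <nlohmann/json.hpp>
+
+    using nlohmann::json;
+
+    int main()
+    {
+        json empty  = json({});                     // rule 1: empty list -> empty object {}
+        json object = { {"one", 1}, {"two", 2} };   // rule 2: all elements are (string, value) pairs
+        json array  = { 1, 2, 3 };                  // rule 3: anything else becomes an array
+        json mixed  = { {"one", 1}, 2 };            // rule 3: [["one",1],2]
+
+        std::cout << object.dump() << '\n' << array.dump() << '\n' << mixed.dump() << '\n';
+    }
+    @endcode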
+ + @liveexample{The example below shows how JSON values are created from + initializer lists.,basic_json__list_init_t} + + @sa see @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + @sa see @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + basic_json(initializer_list_t init, + bool type_deduction = true, + value_t manual_type = value_t::array) + { + // check if each element is an array with two elements whose first + // element is a string + bool is_an_object = std::all_of(init.begin(), init.end(), + [](const detail::json_ref& element_ref) + { + return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[0].is_string(); + }); + + // adjust type if type deduction is not wanted + if (!type_deduction) + { + // if array is wanted, do not create an object though possible + if (manual_type == value_t::array) + { + is_an_object = false; + } + + // if object is wanted but impossible, throw an exception + if (JSON_HEDLEY_UNLIKELY(manual_type == value_t::object && !is_an_object)) + { + JSON_THROW(type_error::create(301, "cannot create object from initializer list", basic_json())); + } + } + + if (is_an_object) + { + // the initializer list is a list of pairs -> create object + m_type = value_t::object; + m_value = value_t::object; + + for (auto& element_ref : init) + { + auto element = element_ref.moved_or_copied(); + m_value.object->emplace( + std::move(*((*element.m_value.array)[0].m_value.string)), + std::move((*element.m_value.array)[1])); + } + } + else + { + // the initializer list describes an array -> create array + m_type = value_t::array; + m_value.array = create(init.begin(), init.end()); + } + + set_parents(); + assert_invariant(); + } + + /*! + @brief explicitly create a binary array (without subtype) + + Creates a JSON binary array value from a given binary container. Binary + values are part of various binary formats, such as CBOR, MessagePack, and + BSON. This constructor is used to create a value for serialization to those + formats. + + @note Note, this function exists because of the difficulty in correctly + specifying the correct template overload in the standard value ctor, as both + JSON arrays and JSON binary arrays are backed with some form of a + `std::vector`. Because JSON binary arrays are a non-standard extension it + was decided that it would be best to prevent automatic initialization of a + binary array type, for backwards compatibility and so it does not happen on + accident. + + @param[in] init container containing bytes to use as binary type + + @return JSON binary array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @since version 3.8.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(const typename binary_t::container_type& init) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = init; + return res; + } + + /*! + @brief explicitly create a binary array (with subtype) + + Creates a JSON binary array value from a given binary container. Binary + values are part of various binary formats, such as CBOR, MessagePack, and + BSON. This constructor is used to create a value for serialization to those + formats. 
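+
+    A short sketch of creating a subtyped binary value and serializing it
+    (illustrative only; `to_msgpack` is the library's MessagePack writer):
+
+    @code {.cpp}
+    #include <cstdint>
+    #include <vector>
+    #include <nlohmann/json.hpp>
+
+    using nlohmann::json;
+
+    int main()
+    {
+        // two bytes with subtype 42
+        json bin = json::binary({0xCA, 0xFE}, 42);
+
+        // with a subtype and exactly 2 bytes, MessagePack uses the fixext2 format
+        std::vector<std::uint8_t> msgpack = json::to_msgpack(bin);
+    }
+    @endcode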
+ + @note Note, this function exists because of the difficulty in correctly + specifying the correct template overload in the standard value ctor, as both + JSON arrays and JSON binary arrays are backed with some form of a + `std::vector`. Because JSON binary arrays are a non-standard extension it + was decided that it would be best to prevent automatic initialization of a + binary array type, for backwards compatibility and so it does not happen on + accident. + + @param[in] init container containing bytes to use as binary type + @param[in] subtype subtype to use in MessagePack and BSON + + @return JSON binary array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @since version 3.8.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(const typename binary_t::container_type& init, typename binary_t::subtype_type subtype) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = binary_t(init, subtype); + return res; + } + + /// @copydoc binary(const typename binary_t::container_type&) + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(typename binary_t::container_type&& init) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = std::move(init); + return res; + } + + /// @copydoc binary(const typename binary_t::container_type&, typename binary_t::subtype_type) + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(typename binary_t::container_type&& init, typename binary_t::subtype_type subtype) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = binary_t(std::move(init), subtype); + return res; + } + + /*! + @brief explicitly create an array from an initializer list + + Creates a JSON array value from a given initializer list. That is, given a + list of values `a, b, c`, creates the JSON value `[a, b, c]`. If the + initializer list is empty, the empty array `[]` is created. + + @note This function is only needed to express two edge cases that cannot + be realized with the initializer list constructor (@ref + basic_json(initializer_list_t, bool, value_t)). These cases + are: + 1. creating an array whose elements are all pairs whose first element is a + string -- in this case, the initializer list constructor would create an + object, taking the first elements as keys + 2. creating an empty array -- passing the empty initializer list to the + initializer list constructor yields an empty object + + @param[in] init initializer list with JSON values to create an array from + (optional) + + @return JSON array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `array` + function.,array} + + @sa see @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa see @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json array(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::array); + } + + /*! + @brief explicitly create an object from an initializer list + + Creates a JSON object value from a given initializer list. The initializer + lists elements must be pairs, and their first elements must be strings. 
If + the initializer list is empty, the empty object `{}` is created. + + @note This function is only added for symmetry reasons. In contrast to the + related function @ref array(initializer_list_t), there are + no cases which can only be expressed by this function. That is, any + initializer list @a init can also be passed to the initializer list + constructor @ref basic_json(initializer_list_t, bool, value_t). + + @param[in] init initializer list to create an object from (optional) + + @return JSON object value + + @throw type_error.301 if @a init is not a list of pairs whose first + elements are strings. In this case, no object can be created. When such a + value is passed to @ref basic_json(initializer_list_t, bool, value_t), + an array would have been created from the passed initializer list @a init. + See example below. + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `object` + function.,object} + + @sa see @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa see @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + + @since version 1.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json object(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::object); + } + + /*! + @brief construct an array with count copies of given value + + Constructs a JSON array value by creating @a cnt copies of a passed value. + In case @a cnt is `0`, an empty array is created. + + @param[in] cnt the number of JSON copies of @a val to create + @param[in] val the JSON value to copy + + @post `std::distance(begin(),end()) == cnt` holds. + + @complexity Linear in @a cnt. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows examples for the @ref + basic_json(size_type\, const basic_json&) + constructor.,basic_json__size_type_basic_json} + + @since version 1.0.0 + */ + basic_json(size_type cnt, const basic_json& val) + : m_type(value_t::array) + { + m_value.array = create(cnt, val); + set_parents(); + assert_invariant(); + } + + /*! + @brief construct a JSON container given an iterator range + + Constructs the JSON value with the contents of the range `[first, last)`. + The semantics depends on the different types a JSON value can have: + - In case of a null type, invalid_iterator.206 is thrown. + - In case of other primitive types (number, boolean, or string), @a first + must be `begin()` and @a last must be `end()`. In this case, the value is + copied. Otherwise, invalid_iterator.204 is thrown. + - In case of structured types (array, object), the constructor behaves as + similar versions for `std::vector` or `std::map`; that is, a JSON array + or object is constructed from the values in the range. + + @tparam InputIT an input iterator type (@ref iterator or @ref + const_iterator) + + @param[in] first begin of the range to copy from (included) + @param[in] last end of the range to copy from (excluded) + + @pre Iterators @a first and @a last must be initialized. **This + precondition is enforced with an assertion (see warning).** If + assertions are switched off, a violation of this precondition yields + undefined behavior. + + @pre Range `[first, last)` is valid. Usually, this precondition cannot be + checked efficiently. 
Only certain edge cases are detected; see the + description of the exceptions below. A violation of this precondition + yields undefined behavior. + + @warning A precondition is enforced with a runtime assertion that will + result in calling `std::abort` if this precondition is not met. + Assertions can be disabled by defining `NDEBUG` at compile time. + See https://en.cppreference.com/w/cpp/error/assert for more + information. + + @throw invalid_iterator.201 if iterators @a first and @a last are not + compatible (i.e., do not belong to the same JSON value). In this case, + the range `[first, last)` is undefined. + @throw invalid_iterator.204 if iterators @a first and @a last belong to a + primitive type (number, boolean, or string), but @a first does not point + to the first element any more. In this case, the range `[first, last)` is + undefined. See example code below. + @throw invalid_iterator.206 if iterators @a first and @a last belong to a + null value. In this case, the range `[first, last)` is undefined. + + @complexity Linear in distance between @a first and @a last. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The example below shows several ways to create JSON values by + specifying a subrange with iterators.,basic_json__InputIt_InputIt} + + @since version 1.0.0 + */ + template < class InputIT, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type = 0 > + basic_json(InputIT first, InputIT last) + { + JSON_ASSERT(first.m_object != nullptr); + JSON_ASSERT(last.m_object != nullptr); + + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(201, "iterators are not compatible", basic_json())); + } + + // copy type from first iterator + m_type = first.m_object->m_type; + + // check if iterator range is complete for primitive values + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + { + if (JSON_HEDLEY_UNLIKELY(!first.m_it.primitive_iterator.is_begin() + || !last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range", *first.m_object)); + } + break; + } + + case value_t::null: + case value_t::object: + case value_t::array: + case value_t::binary: + case value_t::discarded: + default: + break; + } + + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = first.m_object->m_value.number_integer; + break; + } + + case value_t::number_unsigned: + { + m_value.number_unsigned = first.m_object->m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value.number_float = first.m_object->m_value.number_float; + break; + } + + case value_t::boolean: + { + m_value.boolean = first.m_object->m_value.boolean; + break; + } + + case value_t::string: + { + m_value = *first.m_object->m_value.string; + break; + } + + case value_t::object: + { + m_value.object = create(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + m_value.array = create(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + case value_t::binary: + { + m_value = *first.m_object->m_value.binary; + break; + } + + case value_t::null: + case value_t::discarded: + default: + JSON_THROW(invalid_iterator::create(206, "cannot construct with iterators from " 
+ std::string(first.m_object->type_name()), *first.m_object)); + } + + set_parents(); + assert_invariant(); + } + + + /////////////////////////////////////// + // other constructors and destructor // + /////////////////////////////////////// + + template, + std::is_same>::value, int> = 0 > + basic_json(const JsonRef& ref) : basic_json(ref.moved_or_copied()) {} + + /*! + @brief copy constructor + + Creates a copy of a given JSON value. + + @param[in] other the JSON value to copy + + @post `*this == other` + + @complexity Linear in the size of @a other. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + - As postcondition, it holds: `other == basic_json(other)`. + + @liveexample{The following code shows an example for the copy + constructor.,basic_json__basic_json} + + @since version 1.0.0 + */ + basic_json(const basic_json& other) + : m_type(other.m_type) + { + // check of passed value is valid + other.assert_invariant(); + + switch (m_type) + { + case value_t::object: + { + m_value = *other.m_value.object; + break; + } + + case value_t::array: + { + m_value = *other.m_value.array; + break; + } + + case value_t::string: + { + m_value = *other.m_value.string; + break; + } + + case value_t::boolean: + { + m_value = other.m_value.boolean; + break; + } + + case value_t::number_integer: + { + m_value = other.m_value.number_integer; + break; + } + + case value_t::number_unsigned: + { + m_value = other.m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value = other.m_value.number_float; + break; + } + + case value_t::binary: + { + m_value = *other.m_value.binary; + break; + } + + case value_t::null: + case value_t::discarded: + default: + break; + } + + set_parents(); + assert_invariant(); + } + + /*! + @brief move constructor + + Move constructor. Constructs a JSON value with the contents of the given + value @a other using move semantics. It "steals" the resources from @a + other and leaves it as JSON null value. + + @param[in,out] other value to move to this object + + @post `*this` has the same value as @a other before the call. + @post @a other is a JSON null value. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. + + @requirement This function helps `basic_json` satisfying the + [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible) + requirements. + + @liveexample{The code below shows the move constructor explicitly called + via std::move.,basic_json__moveconstructor} + + @since version 1.0.0 + */ + basic_json(basic_json&& other) noexcept + : m_type(std::move(other.m_type)), + m_value(std::move(other.m_value)) + { + // check that passed value is valid + other.assert_invariant(false); + + // invalidate payload + other.m_type = value_t::null; + other.m_value = {}; + + set_parents(); + assert_invariant(); + } + + /*! + @brief copy assignment + + Copy assignment operator. Copies a JSON value via the "copy and swap" + strategy: It is expressed in terms of the copy constructor, destructor, + and the `swap()` member function. + + @param[in] other value to copy from + + @complexity Linear. 
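+
+ A minimal usage sketch (assuming the default `nlohmann::json`
+ specialization; the variable names are illustrative):
+ @code {.cpp}
+ json a = 23;
+ json b = {1, 2, 3};
+ a = b;   // a now holds a copy of the array [1,2,3]; b is unchanged
+ @endcode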
+ + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + + @liveexample{The code below shows and example for the copy assignment. It + creates a copy of value `a` which is then swapped with `b`. Finally\, the + copy of `a` (which is the null value after the swap) is + destroyed.,basic_json__copyassignment} + + @since version 1.0.0 + */ + basic_json& operator=(basic_json other) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + // check that passed value is valid + other.assert_invariant(); + + using std::swap; + swap(m_type, other.m_type); + swap(m_value, other.m_value); + + set_parents(); + assert_invariant(); + return *this; + } + + /*! + @brief destructor + + Destroys the JSON value and frees all allocated memory. + + @complexity Linear. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + - All stored elements are destroyed and all memory is freed. + + @since version 1.0.0 + */ + ~basic_json() noexcept + { + assert_invariant(false); + m_value.destroy(m_type); + } + + /// @} + + public: + /////////////////////// + // object inspection // + /////////////////////// + + /// @name object inspection + /// Functions to inspect the type of a JSON value. + /// @{ + + /*! + @brief serialization + + Serialization function for JSON values. The function tries to mimic + Python's `json.dumps()` function, and currently supports its @a indent + and @a ensure_ascii parameters. + + @param[in] indent If indent is nonnegative, then array elements and object + members will be pretty-printed with that indent level. An indent level of + `0` will only insert newlines. `-1` (the default) selects the most compact + representation. + @param[in] indent_char The character to use for indentation if @a indent is + greater than `0`. The default is ` ` (space). + @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters + in the output are escaped with `\uXXXX` sequences, and the result consists + of ASCII characters only. + @param[in] error_handler how to react on decoding errors; there are three + possible values: `strict` (throws and exception in case a decoding error + occurs; default), `replace` (replace invalid UTF-8 sequences with U+FFFD), + and `ignore` (ignore invalid UTF-8 sequences during serialization; all + bytes are copied to the output unchanged). + + @return string containing the serialization of the JSON value + + @throw type_error.316 if a string stored inside the JSON value is not + UTF-8 encoded and @a error_handler is set to strict + + @note Binary values are serialized as object containing two keys: + - "bytes": an array of bytes as integers + - "subtype": the subtype as integer or "null" if the binary has no subtype + + @complexity Linear. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. 
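+
+ A short sketch of the parameters in action (assuming the default
+ `nlohmann::json` specialization; the values are illustrative):
+ @code {.cpp}
+ json j = {{"name", "Ada"}, {"ids", {1, 2}}};
+ std::string compact = j.dump();               // no extra whitespace
+ std::string pretty  = j.dump(4);              // 4-space indentation
+ std::string escaped = j.dump(-1, ' ', true);  // non-ASCII escaped as \uXXXX
+ @endcode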
+ + @liveexample{The following example shows the effect of different @a indent\, + @a indent_char\, and @a ensure_ascii parameters to the result of the + serialization.,dump} + + @see https://docs.python.org/2/library/json.html#json.dump + + @since version 1.0.0; indentation character @a indent_char, option + @a ensure_ascii and exceptions added in version 3.0.0; error + handlers added in version 3.4.0; serialization of binary values added + in version 3.8.0. + */ + string_t dump(const int indent = -1, + const char indent_char = ' ', + const bool ensure_ascii = false, + const error_handler_t error_handler = error_handler_t::strict) const + { + string_t result; + serializer s(detail::output_adapter(result), indent_char, error_handler); + + if (indent >= 0) + { + s.dump(*this, true, ensure_ascii, static_cast(indent)); + } + else + { + s.dump(*this, false, ensure_ascii, 0); + } + + return result; + } + + /*! + @brief return the type of the JSON value (explicit) + + Return the type of the JSON value as a value from the @ref value_t + enumeration. + + @return the type of the JSON value + Value type | return value + ------------------------- | ------------------------- + null | value_t::null + boolean | value_t::boolean + string | value_t::string + number (integer) | value_t::number_integer + number (unsigned integer) | value_t::number_unsigned + number (floating-point) | value_t::number_float + object | value_t::object + array | value_t::array + binary | value_t::binary + discarded | value_t::discarded + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `type()` for all JSON + types.,type} + + @sa see @ref operator value_t() -- return the type of the JSON value (implicit) + @sa see @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr value_t type() const noexcept + { + return m_type; + } + + /*! + @brief return whether type is primitive + + This function returns true if and only if the JSON type is primitive + (string, number, boolean, or null). + + @return `true` if type is primitive (string, number, boolean, or null), + `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_primitive()` for all JSON + types.,is_primitive} + + @sa see @ref is_structured() -- returns whether JSON value is structured + @sa see @ref is_null() -- returns whether JSON value is `null` + @sa see @ref is_string() -- returns whether JSON value is a string + @sa see @ref is_boolean() -- returns whether JSON value is a boolean + @sa see @ref is_number() -- returns whether JSON value is a number + @sa see @ref is_binary() -- returns whether JSON value is a binary array + + @since version 1.0.0 + */ + constexpr bool is_primitive() const noexcept + { + return is_null() || is_string() || is_boolean() || is_number() || is_binary(); + } + + /*! + @brief return whether type is structured + + This function returns true if and only if the JSON type is structured + (array or object). + + @return `true` if type is structured (array or object), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. 
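+
+ A brief sketch (illustrative values):
+ @code {.cpp}
+ json j_arr = {1, 2, 3};
+ json j_num = 17;
+ bool b1 = j_arr.is_structured();   // true: array
+ bool b2 = j_num.is_structured();   // false: a number is primitive
+ @endcode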
+ + @liveexample{The following code exemplifies `is_structured()` for all JSON + types.,is_structured} + + @sa see @ref is_primitive() -- returns whether value is primitive + @sa see @ref is_array() -- returns whether value is an array + @sa see @ref is_object() -- returns whether value is an object + + @since version 1.0.0 + */ + constexpr bool is_structured() const noexcept + { + return is_array() || is_object(); + } + + /*! + @brief return whether value is null + + This function returns true if and only if the JSON value is null. + + @return `true` if type is null, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_null()` for all JSON + types.,is_null} + + @since version 1.0.0 + */ + constexpr bool is_null() const noexcept + { + return m_type == value_t::null; + } + + /*! + @brief return whether value is a boolean + + This function returns true if and only if the JSON value is a boolean. + + @return `true` if type is boolean, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_boolean()` for all JSON + types.,is_boolean} + + @since version 1.0.0 + */ + constexpr bool is_boolean() const noexcept + { + return m_type == value_t::boolean; + } + + /*! + @brief return whether value is a number + + This function returns true if and only if the JSON value is a number. This + includes both integer (signed and unsigned) and floating-point values. + + @return `true` if type is number (regardless whether integer, unsigned + integer or floating-type), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number()` for all JSON + types.,is_number} + + @sa see @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa see @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa see @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number() const noexcept + { + return is_number_integer() || is_number_float(); + } + + /*! + @brief return whether value is an integer number + + This function returns true if and only if the JSON value is a signed or + unsigned integer number. This excludes floating-point values. + + @return `true` if type is an integer or unsigned integer number, `false` + otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_integer()` for all + JSON types.,is_number_integer} + + @sa see @ref is_number() -- check if value is a number + @sa see @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa see @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number_integer() const noexcept + { + return m_type == value_t::number_integer || m_type == value_t::number_unsigned; + } + + /*! + @brief return whether value is an unsigned integer number + + This function returns true if and only if the JSON value is an unsigned + integer number. This excludes floating-point and signed integer values. 
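+
+ For instance (illustrative values; an unsigned literal is stored as an
+ unsigned integer, a signed literal as a signed integer):
+ @code {.cpp}
+ json j_uns = 5u;     // stored as number_unsigned
+ json j_int = -5;     // stored as number_integer
+ json j_flt = 3.25;   // stored as number_float
+ bool b1 = j_uns.is_number_unsigned();   // true
+ bool b2 = j_int.is_number_unsigned();   // false
+ bool b3 = j_flt.is_number_unsigned();   // false
+ @endcode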
+ + @return `true` if type is an unsigned integer number, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_unsigned()` for all + JSON types.,is_number_unsigned} + + @sa see @ref is_number() -- check if value is a number + @sa see @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa see @ref is_number_float() -- check if value is a floating-point number + + @since version 2.0.0 + */ + constexpr bool is_number_unsigned() const noexcept + { + return m_type == value_t::number_unsigned; + } + + /*! + @brief return whether value is a floating-point number + + This function returns true if and only if the JSON value is a + floating-point number. This excludes signed and unsigned integer values. + + @return `true` if type is a floating-point number, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_float()` for all + JSON types.,is_number_float} + + @sa see @ref is_number() -- check if value is number + @sa see @ref is_number_integer() -- check if value is an integer number + @sa see @ref is_number_unsigned() -- check if value is an unsigned integer + number + + @since version 1.0.0 + */ + constexpr bool is_number_float() const noexcept + { + return m_type == value_t::number_float; + } + + /*! + @brief return whether value is an object + + This function returns true if and only if the JSON value is an object. + + @return `true` if type is object, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_object()` for all JSON + types.,is_object} + + @since version 1.0.0 + */ + constexpr bool is_object() const noexcept + { + return m_type == value_t::object; + } + + /*! + @brief return whether value is an array + + This function returns true if and only if the JSON value is an array. + + @return `true` if type is array, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_array()` for all JSON + types.,is_array} + + @since version 1.0.0 + */ + constexpr bool is_array() const noexcept + { + return m_type == value_t::array; + } + + /*! + @brief return whether value is a string + + This function returns true if and only if the JSON value is a string. + + @return `true` if type is string, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_string()` for all JSON + types.,is_string} + + @since version 1.0.0 + */ + constexpr bool is_string() const noexcept + { + return m_type == value_t::string; + } + + /*! + @brief return whether value is a binary array + + This function returns true if and only if the JSON value is a binary array. + + @return `true` if type is binary array, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. 
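+
+ A brief sketch (illustrative values):
+ @code {.cpp}
+ json j_bin = json::binary({0xCA, 0xFE});
+ json j_str = "CAFE";
+ bool b1 = j_bin.is_binary();   // true
+ bool b2 = j_str.is_binary();   // false
+ @endcode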
+ + @liveexample{The following code exemplifies `is_binary()` for all JSON + types.,is_binary} + + @since version 3.8.0 + */ + constexpr bool is_binary() const noexcept + { + return m_type == value_t::binary; + } + + /*! + @brief return whether value is discarded + + This function returns true if and only if the JSON value was discarded + during parsing with a callback function (see @ref parser_callback_t). + + @note This function will always be `false` for JSON values after parsing. + That is, discarded values can only occur during parsing, but will be + removed when inside a structured value or replaced by null in other cases. + + @return `true` if type is discarded, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_discarded()` for all JSON + types.,is_discarded} + + @since version 1.0.0 + */ + constexpr bool is_discarded() const noexcept + { + return m_type == value_t::discarded; + } + + /*! + @brief return the type of the JSON value (implicit) + + Implicitly return the type of the JSON value as a value from the @ref + value_t enumeration. + + @return the type of the JSON value + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies the @ref value_t operator for + all JSON types.,operator__value_t} + + @sa see @ref type() -- return the type of the JSON value (explicit) + @sa see @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr operator value_t() const noexcept + { + return m_type; + } + + /// @} + + private: + ////////////////// + // value access // + ////////////////// + + /// get a boolean (explicit) + boolean_t get_impl(boolean_t* /*unused*/) const + { + if (JSON_HEDLEY_LIKELY(is_boolean())) + { + return m_value.boolean; + } + + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(type_name()), *this)); + } + + /// get a pointer to the value (object) + object_t* get_impl_ptr(object_t* /*unused*/) noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (object) + constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (array) + array_t* get_impl_ptr(array_t* /*unused*/) noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (array) + constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (string) + string_t* get_impl_ptr(string_t* /*unused*/) noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (string) + constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (boolean) + boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept + { + return is_boolean() ? &m_value.boolean : nullptr; + } + + /// get a pointer to the value (boolean) + constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) const noexcept + { + return is_boolean() ? 
&m_value.boolean : nullptr; + } + + /// get a pointer to the value (integer number) + number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (integer number) + constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /*unused*/) const noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (unsigned number) + number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (unsigned number) + constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t* /*unused*/) const noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (floating-point number) + number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /// get a pointer to the value (floating-point number) + constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unused*/) const noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /// get a pointer to the value (binary) + binary_t* get_impl_ptr(binary_t* /*unused*/) noexcept + { + return is_binary() ? m_value.binary : nullptr; + } + + /// get a pointer to the value (binary) + constexpr const binary_t* get_impl_ptr(const binary_t* /*unused*/) const noexcept + { + return is_binary() ? m_value.binary : nullptr; + } + + /*! + @brief helper function to implement get_ref() + + This function helps to implement get_ref() without code duplication for + const and non-const overloads + + @tparam ThisType will be deduced as `basic_json` or `const basic_json` + + @throw type_error.303 if ReferenceType does not match underlying value + type of the current JSON + */ + template + static ReferenceType get_ref_impl(ThisType& obj) + { + // delegate the call to get_ptr<>() + auto* ptr = obj.template get_ptr::type>(); + + if (JSON_HEDLEY_LIKELY(ptr != nullptr)) + { + return *ptr; + } + + JSON_THROW(type_error::create(303, "incompatible ReferenceType for get_ref, actual type is " + std::string(obj.type_name()), obj)); + } + + public: + /// @name value access + /// Direct access to the stored value of a JSON value. + /// @{ + + /*! + @brief get a pointer value (implicit) + + Implicit pointer access to the internally stored JSON value. No copies are + made. + + @warning Writing data to the pointee of the result yields an undefined + state. + + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t. Enforced by a static + assertion. + + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. 
Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get_ptr} + + @since version 1.0.0 + */ + template::value, int>::type = 0> + auto get_ptr() noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) + { + // delegate the call to get_impl_ptr<>() + return get_impl_ptr(static_cast(nullptr)); + } + + /*! + @brief get a pointer value (implicit) + @copydoc get_ptr() + */ + template < typename PointerType, typename std::enable_if < + std::is_pointer::value&& + std::is_const::type>::value, int >::type = 0 > + constexpr auto get_ptr() const noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) + { + // delegate the call to get_impl_ptr<>() const + return get_impl_ptr(static_cast(nullptr)); + } + + private: + /*! + @brief get a value (explicit) + + Explicit type conversion between the JSON value and a compatible value + which is [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) + and [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + ValueType ret; + JSONSerializer::from_json(*this, ret); + return ret; + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json, + - @ref json_serializer has a `from_json()` method of the form + `void from_json(const basic_json&, ValueType&)`, and + - @ref json_serializer does not have a `from_json()` method of + the form `ValueType from_json(const basic_json&)` + + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,get__ValueType_const} + + @since version 2.1.0 + */ + template < typename ValueType, + detail::enable_if_t < + detail::is_default_constructible::value&& + detail::has_from_json::value, + int > = 0 > + ValueType get_impl(detail::priority_tag<0> /*unused*/) const noexcept(noexcept( + JSONSerializer::from_json(std::declval(), std::declval()))) + { + auto ret = ValueType(); + JSONSerializer::from_json(*this, ret); + return ret; + } + + /*! + @brief get a value (explicit); special case + + Explicit type conversion between the JSON value and a compatible value + which is **not** [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) + and **not** [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + return JSONSerializer::from_json(*this); + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json and + - @ref json_serializer has a `from_json()` method of the form + `ValueType from_json(const basic_json&)` + + @note If @ref json_serializer has both overloads of + `from_json()`, this one is chosen. 
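+
+ One way a type can opt into this overload is a specialization of
+ `adl_serializer` whose `from_json()` returns by value (the type
+ `legacy_id` below is purely illustrative):
+ @code {.cpp}
+ struct legacy_id
+ {
+     explicit legacy_id(int v) : value(v) {}   // no default constructor
+     int value;
+ };
+
+ namespace nlohmann
+ {
+     template <>
+     struct adl_serializer<legacy_id>
+     {
+         static legacy_id from_json(const json& j)
+         {
+             return legacy_id(j.get<int>());
+         }
+     };
+ }
+
+ // legacy_id id = json(42).get<legacy_id>();
+ @endcode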
+ + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @since version 2.1.0 + */ + template < typename ValueType, + detail::enable_if_t < + detail::has_non_default_from_json::value, + int > = 0 > + ValueType get_impl(detail::priority_tag<1> /*unused*/) const noexcept(noexcept( + JSONSerializer::from_json(std::declval()))) + { + return JSONSerializer::from_json(*this); + } + + /*! + @brief get special-case overload + + This overloads converts the current @ref basic_json in a different + @ref basic_json type + + @tparam BasicJsonType == @ref basic_json + + @return a copy of *this, converted into @a BasicJsonType + + @complexity Depending on the implementation of the called `from_json()` + method. + + @since version 3.2.0 + */ + template < typename BasicJsonType, + detail::enable_if_t < + detail::is_basic_json::value, + int > = 0 > + BasicJsonType get_impl(detail::priority_tag<2> /*unused*/) const + { + return *this; + } + + /*! + @brief get special-case overload + + This overloads avoids a lot of template boilerplate, it can be seen as the + identity method + + @tparam BasicJsonType == @ref basic_json + + @return a copy of *this + + @complexity Constant. + + @since version 2.1.0 + */ + template::value, + int> = 0> + basic_json get_impl(detail::priority_tag<3> /*unused*/) const + { + return *this; + } + + /*! + @brief get a pointer value (explicit) + @copydoc get() + */ + template::value, + int> = 0> + constexpr auto get_impl(detail::priority_tag<4> /*unused*/) const noexcept + -> decltype(std::declval().template get_ptr()) + { + // delegate the call to get_ptr + return get_ptr(); + } + + public: + /*! + @brief get a (pointer) value (explicit) + + Performs explicit type conversion between the JSON value and a compatible value if required. + + - If the requested type is a pointer to the internally stored JSON value that pointer is returned. + No copies are made. + + - If the requested type is the current @ref basic_json, or a different @ref basic_json convertible + from the current @ref basic_json. + + - Otherwise the value is converted by calling the @ref json_serializer `from_json()` + method. + + @tparam ValueTypeCV the provided value type + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @tparam ValueType if necessary + + @throw what @ref json_serializer `from_json()` method throws if conversion is required + + @since version 2.1.0 + */ + template < typename ValueTypeCV, typename ValueType = detail::uncvref_t> +#if defined(JSON_HAS_CPP_14) + constexpr +#endif + auto get() const noexcept( + noexcept(std::declval().template get_impl(detail::priority_tag<4> {}))) + -> decltype(std::declval().template get_impl(detail::priority_tag<4> {})) + { + // we cannot static_assert on ValueTypeCV being non-const, because + // there is support for get(), which is why we + // still need the uncvref + static_assert(!std::is_reference::value, + "get() cannot be used with reference types, you might want to use get_ref()"); + return get_impl(detail::priority_tag<4> {}); + } + + /*! + @brief get a pointer value (explicit) + + Explicit pointer access to the internally stored JSON value. No copies are + made. + + @warning The pointer becomes invalid if the underlying JSON object + changes. 
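+
+ A short sketch of pointer access (illustrative; assumes the default
+ `nlohmann::json` specialization):
+ @code {.cpp}
+ json j = "example string";
+ auto* sp = j.get<json::string_t*>();        // points to the stored string
+ auto* np = j.get<json::number_float_t*>();  // nullptr: type does not match
+ @endcode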
+ + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t. + + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get__PointerType} + + @sa see @ref get_ptr() for explicit pointer-member access + + @since version 1.0.0 + */ + template::value, int>::type = 0> + auto get() noexcept -> decltype(std::declval().template get_ptr()) + { + // delegate the call to get_ptr + return get_ptr(); + } + + /*! + @brief get a value (explicit) + + Explicit type conversion between the JSON value and a compatible value. + The value is filled into the input parameter by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + ValueType v; + JSONSerializer::from_json(*this, v); + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json, + - @ref json_serializer has a `from_json()` method of the form + `void from_json(const basic_json&, ValueType&)`, and + + @tparam ValueType the input parameter type. + + @return the input parameter, allowing chaining calls. + + @throw what @ref json_serializer `from_json()` method throws + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,get_to} + + @since version 3.3.0 + */ + template < typename ValueType, + detail::enable_if_t < + !detail::is_basic_json::value&& + detail::has_from_json::value, + int > = 0 > + ValueType & get_to(ValueType& v) const noexcept(noexcept( + JSONSerializer::from_json(std::declval(), v))) + { + JSONSerializer::from_json(*this, v); + return v; + } + + // specialization to allow to call get_to with a basic_json value + // see https://github.com/nlohmann/json/issues/2175 + template::value, + int> = 0> + ValueType & get_to(ValueType& v) const + { + v = *this; + return v; + } + + template < + typename T, std::size_t N, + typename Array = T (&)[N], // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) + detail::enable_if_t < + detail::has_from_json::value, int > = 0 > + Array get_to(T (&v)[N]) const // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) + noexcept(noexcept(JSONSerializer::from_json( + std::declval(), v))) + { + JSONSerializer::from_json(*this, v); + return v; + } + + /*! + @brief get a reference value (implicit) + + Implicit reference access to the internally stored JSON value. No copies + are made. + + @warning Writing data to the referee of the result yields an undefined + state. + + @tparam ReferenceType reference type; must be a reference to @ref array_t, + @ref object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, or + @ref number_float_t. Enforced by static assertion. 
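+
+ A short sketch (illustrative):
+ @code {.cpp}
+ json j = "hello";
+ auto& s = j.get_ref<json::string_t&>();   // reference to the stored string
+ // j.get_ref<json::array_t&>();           // would throw type_error.303
+ @endcode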
+ + @return reference to the internally stored JSON value if the requested + reference type @a ReferenceType fits to the JSON value; throws + type_error.303 otherwise + + @throw type_error.303 in case passed type @a ReferenceType is incompatible + with the stored JSON value; see example below + + @complexity Constant. + + @liveexample{The example shows several calls to `get_ref()`.,get_ref} + + @since version 1.1.0 + */ + template::value, int>::type = 0> + ReferenceType get_ref() + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a reference value (implicit) + @copydoc get_ref() + */ + template < typename ReferenceType, typename std::enable_if < + std::is_reference::value&& + std::is_const::type>::value, int >::type = 0 > + ReferenceType get_ref() const + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a value (implicit) + + Implicit type conversion between the JSON value and a compatible value. + The call is realized by calling @ref get() const. + + @tparam ValueType non-pointer type compatible to the JSON value, for + instance `int` for JSON integer numbers, `bool` for JSON booleans, or + `std::vector` types for JSON arrays. The character type of @ref string_t + as well as an initializer list of this type is excluded to avoid + ambiguities as these types implicitly convert to `std::string`. + + @return copy of the JSON value, converted to type @a ValueType + + @throw type_error.302 in case passed type @a ValueType is incompatible + to the JSON value type (e.g., the JSON value is of type boolean, but a + string is requested); see example below + + @complexity Linear in the size of the JSON value. + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,operator__ValueType} + + @since version 1.0.0 + */ + template < typename ValueType, typename std::enable_if < + detail::conjunction < + detail::negation>, + detail::negation>>, + detail::negation>, + detail::negation>, + detail::negation>>, + +#if defined(JSON_HAS_CPP_17) && (defined(__GNUC__) || (defined(_MSC_VER) && _MSC_VER >= 1910 && _MSC_VER <= 1914)) + detail::negation>, +#endif + detail::is_detected_lazy + >::value, int >::type = 0 > + JSON_EXPLICIT operator ValueType() const + { + // delegate the call to get<>() const + return get(); + } + + /*! + @return reference to the binary value + + @throw type_error.302 if the value is not binary + + @sa see @ref is_binary() to check if the value is binary + + @since version 3.8.0 + */ + binary_t& get_binary() + { + if (!is_binary()) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()), *this)); + } + + return *get_ptr(); + } + + /// @copydoc get_binary() + const binary_t& get_binary() const + { + if (!is_binary()) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()), *this)); + } + + return *get_ptr(); + } + + /// @} + + + //////////////////// + // element access // + //////////////////// + + /// @name element access + /// Access to the JSON value. + /// @{ + + /*! + @brief access specified array element with bounds checking + + Returns a reference to the element at specified location @a idx, with + bounds checking. 
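+
+ A short sketch (illustrative values):
+ @code {.cpp}
+ json j = {1, 2, 3};
+ j.at(1) = 20;        // j is now [1,20,3]
+ // j.at(7);          // would throw out_of_range.401
+ // json(true).at(0); // would throw type_error.304 (not an array)
+ @endcode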
+ + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read and + written using `at()`. It also demonstrates the different exceptions that + can be thrown.,at__size_type} + */ + reference at(size_type idx) + { + // at only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + JSON_TRY + { + return set_parent(m_value.array->at(idx)); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range", *this)); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief access specified array element with bounds checking + + Returns a const reference to the element at specified location @a idx, + with bounds checking. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__size_type_const} + */ + const_reference at(size_type idx) const + { + // at only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + JSON_TRY + { + return m_value.array->at(idx); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range", *this)); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief access specified object element with bounds checking + + Returns a reference to the element at with specified key @a key, with + bounds checking. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below. + @throw out_of_range.403 if the key @a key is is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa see @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa see @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read and + written using `at()`. 
It also demonstrates the different exceptions that + can be thrown.,at__object_t_key_type} + */ + reference at(const typename object_t::key_type& key) + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_TRY + { + return set_parent(m_value.object->at(key)); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found", *this)); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief access specified object element with bounds checking + + Returns a const reference to the element at with specified key @a key, + with bounds checking. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below. + @throw out_of_range.403 if the key @a key is is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa see @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa see @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__object_t_key_type_const} + */ + const_reference at(const typename object_t::key_type& key) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_TRY + { + return m_value.object->at(key); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found", *this)); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief access specified array element + + Returns a reference to the element at specified location @a idx. + + @note If @a idx is beyond the range of the array (i.e., `idx >= size()`), + then the array is silently filled up with `null` values to make `idx` a + valid reference to the last stored element. + + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array or null; in that + cases, using the [] operator with an index makes no sense. + + @complexity Constant if @a idx is in the range of the array. Otherwise + linear in `idx - size()`. + + @liveexample{The example below shows how array elements can be read and + written using `[]` operator. 
Note the addition of `null` + values.,operatorarray__size_type} + + @since version 1.0.0 + */ + reference operator[](size_type idx) + { + // implicitly convert null value to an empty array + if (is_null()) + { + m_type = value_t::array; + m_value.array = create(); + assert_invariant(); + } + + // operator[] only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // fill up array with null values if given idx is outside range + if (idx >= m_value.array->size()) + { +#if JSON_DIAGNOSTICS + // remember array size & capacity before resizing + const auto old_size = m_value.array->size(); + const auto old_capacity = m_value.array->capacity(); +#endif + m_value.array->resize(idx + 1); + +#if JSON_DIAGNOSTICS + if (JSON_HEDLEY_UNLIKELY(m_value.array->capacity() != old_capacity)) + { + // capacity has changed: update all parents + set_parents(); + } + else + { + // set parent for values added above + set_parents(begin() + static_cast(old_size), static_cast(idx + 1 - old_size)); + } +#endif + assert_invariant(); + } + + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()), *this)); + } + + /*! + @brief access specified array element + + Returns a const reference to the element at specified location @a idx. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array; in that case, + using the [] operator with an index makes no sense. + + @complexity Constant. + + @liveexample{The example below shows how array elements can be read using + the `[]` operator.,operatorarray__size_type_const} + + @since version 1.0.0 + */ + const_reference operator[](size_type idx) const + { + // const operator[] only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()), *this)); + } + + /*! + @brief access specified object element + + Returns a reference to the element at with specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + cases, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa see @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa see @ref value() for access by value with a default value + + @since version 1.0.0 + */ + reference operator[](const typename object_t::key_type& key) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + // operator[] only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return set_parent(m_value.object->operator[](key)); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()), *this)); + } + + /*! 
+ @brief read-only access specified object element + + Returns a const reference to the element at with specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that case, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa see @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa see @ref value() for access by value with a default value + + @since version 1.0.0 + */ + const_reference operator[](const typename object_t::key_type& key) const + { + // const operator[] only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()), *this)); + } + + /*! + @brief access specified object element + + Returns a reference to the element at with specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + cases, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa see @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa see @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + reference operator[](T* key) + { + // implicitly convert null to object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return set_parent(m_value.object->operator[](key)); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()), *this)); + } + + /*! + @brief read-only access specified object element + + Returns a const reference to the element at with specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that case, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. 
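+
+ A short sketch (illustrative; the key must exist in the const object):
+ @code {.cpp}
+ const json j = {{"answer", 42}};
+ int a = j["answer"].get<int>();   // read-only access, no insertion
+ @endcode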
+ + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa see @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa see @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + const_reference operator[](T* key) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()), *this)); + } + + /*! + @brief access specified object element with default value + + Returns either a copy of an object's element at the specified key @a key + or a given default value if no element with key @a key exists. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(key); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const typename object_t::key_type&), this function + does not throw if the given key @a key was not found. + + @note Unlike @ref operator[](const typename object_t::key_type& key), this + function does not implicitly add an element to the position defined by @a + key. This function is furthermore also applicable to const objects. + + @param[in] key key of the element to access + @param[in] default_value the value to return if @a key is not found + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays. Note the type of the expected value at @a key and the default + value @a default_value must be compatible. + + @return copy of the element at key @a key or @a default_value if @a key + is not found + + @throw type_error.302 if @a default_value does not match the type of the + value at @a key + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value} + + @sa see @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa see @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + + @since version 1.0.0 + */ + // using std::is_convertible in a std::enable_if will fail when using explicit conversions + template < class ValueType, typename std::enable_if < + detail::is_getable::value + && !std::is_same::value, int >::type = 0 > + ValueType value(const typename object_t::key_type& key, const ValueType& default_value) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + // if key is found, return value and given default value otherwise + const auto it = find(key); + if (it != end()) + { + return it->template get(); + } + + return default_value; + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()), *this)); + } + + /*! 
+ @brief overload for a default value of type const char* + @copydoc basic_json::value(const typename object_t::key_type&, const ValueType&) const + */ + string_t value(const typename object_t::key_type& key, const char* default_value) const + { + return value(key, string_t(default_value)); + } + + /*! + @brief access specified object element via JSON Pointer with default value + + Returns either a copy of an object's element at the specified key @a key + or a given default value if no element with key @a key exists. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(ptr); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const json_pointer&), this function does not throw + if the given key @a key was not found. + + @param[in] ptr a JSON pointer to the element to access + @param[in] default_value the value to return if @a ptr found no value + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays. Note the type of the expected value at @a key and the default + value @a default_value must be compatible. + + @return copy of the element at key @a key or @a default_value if @a key + is not found + + @throw type_error.302 if @a default_value does not match the type of the + value at @a ptr + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value_ptr} + + @sa see @ref operator[](const json_pointer&) for unchecked access by reference + + @since version 2.0.2 + */ + template::value, int>::type = 0> + ValueType value(const json_pointer& ptr, const ValueType& default_value) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + // if pointer resolves a value, return it or use default value + JSON_TRY + { + return ptr.get_checked(this).template get(); + } + JSON_INTERNAL_CATCH (out_of_range&) + { + return default_value; + } + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()), *this)); + } + + /*! + @brief overload for a default value of type const char* + @copydoc basic_json::value(const json_pointer&, ValueType) const + */ + JSON_HEDLEY_NON_NULL(3) + string_t value(const json_pointer& ptr, const char* default_value) const + { + return value(ptr, string_t(default_value)); + } + + /*! + @brief access the first element + + Returns a reference to the first element in the container. For a JSON + container `c`, the expression `c.front()` is equivalent to `*c.begin()`. + + @return In case of a structured type (array or object), a reference to the + first element is returned. In case of number, string, boolean, or binary + values, a reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `std::out_of_range`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged. + + @throw invalid_iterator.214 when called on `null` value + + @liveexample{The following code shows an example for `front()`.,front} + + @sa see @ref back() -- access the last element + + @since version 1.0.0 + */ + reference front() + { + return *begin(); + } + + /*! 
+ @copydoc basic_json::front() + */ + const_reference front() const + { + return *cbegin(); + } + + /*! + @brief access the last element + + Returns a reference to the last element in the container. For a JSON + container `c`, the expression `c.back()` is equivalent to + @code {.cpp} + auto tmp = c.end(); + --tmp; + return *tmp; + @endcode + + @return In case of a structured type (array or object), a reference to the + last element is returned. In case of number, string, boolean, or binary + values, a reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `std::out_of_range`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged. + + @throw invalid_iterator.214 when called on a `null` value. See example + below. + + @liveexample{The following code shows an example for `back()`.,back} + + @sa see @ref front() -- access the first element + + @since version 1.0.0 + */ + reference back() + { + auto tmp = end(); + --tmp; + return *tmp; + } + + /*! + @copydoc basic_json::back() + */ + const_reference back() const + { + auto tmp = cend(); + --tmp; + return *tmp; + } + + /*! + @brief remove element given an iterator + + Removes the element specified by iterator @a pos. The iterator @a pos must + be valid and dereferenceable. Thus the `end()` iterator (which is valid, + but is not dereferenceable) cannot be used as a value for @a pos. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] pos iterator to the element to remove + @return Iterator following the last removed element. If the iterator @a + pos refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator. 
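+
+ @note A minimal sketch of erasing by iterator (editorial example; assumes
+ the usual `json` alias and hypothetical values):
+ @code {.cpp}
+ json j_arr = {1, 2, 3, 4};
+ j_arr.erase(j_arr.begin() + 1); // j_arr == [1, 3, 4]
+
+ json j_obj = {{"one", 1}, {"two", 2}};
+ j_obj.erase(j_obj.find("one")); // j_obj == {"two": 2}
+ @endcode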
+ + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.202 if called on an iterator which does not belong + to the current JSON value; example: `"iterator does not fit current + value"` + @throw invalid_iterator.205 if called on a primitive type with invalid + iterator (i.e., any iterator which is not `begin()`); example: `"iterator + out of range"` + + @complexity The complexity depends on the type: + - objects: amortized constant + - arrays: linear in distance between @a pos and the end of the container + - strings and binary: linear in the length of the member + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType} + + @sa see @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa see @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa see @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template < class IteratorType, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type + = 0 > + IteratorType erase(IteratorType pos) + { + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(this != pos.m_object)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value", *this)); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + case value_t::binary: + { + if (JSON_HEDLEY_UNLIKELY(!pos.m_it.primitive_iterator.is_begin())) + { + JSON_THROW(invalid_iterator::create(205, "iterator out of range", *this)); + } + + if (is_string()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.string); + std::allocator_traits::deallocate(alloc, m_value.string, 1); + m_value.string = nullptr; + } + else if (is_binary()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.binary); + std::allocator_traits::deallocate(alloc, m_value.binary, 1); + m_value.binary = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(pos.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(pos.m_it.array_iterator); + break; + } + + case value_t::null: + case value_t::discarded: + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()), *this)); + } + + return result; + } + + /*! + @brief remove elements given an iterator range + + Removes the element specified by the range `[first; last)`. The iterator + @a first does not need to be dereferenceable if `first == last`: erasing + an empty range is a no-op. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] first iterator to the beginning of the range to remove + @param[in] last iterator past the end of the range to remove + @return Iterator following the last removed element. If the iterator @a + second refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator. 
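+
+ @note A short sketch of erasing a range (editorial example with
+ hypothetical values):
+ @code {.cpp}
+ json j = {1, 2, 3, 4, 5};
+ j.erase(j.begin() + 1, j.begin() + 3); // j == [1, 4, 5]
+ @endcode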
+ + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.203 if called on iterators which does not belong + to the current JSON value; example: `"iterators do not fit current value"` + @throw invalid_iterator.204 if called on a primitive type with invalid + iterators (i.e., if `first != begin()` and `last != end()`); example: + `"iterators out of range"` + + @complexity The complexity depends on the type: + - objects: `log(size()) + std::distance(first, last)` + - arrays: linear in the distance between @a first and @a last, plus linear + in the distance between @a last and end of the container + - strings and binary: linear in the length of the member + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType_IteratorType} + + @sa see @ref erase(IteratorType) -- removes the element at a given position + @sa see @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa see @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template < class IteratorType, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type + = 0 > + IteratorType erase(IteratorType first, IteratorType last) + { + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(this != first.m_object || this != last.m_object)) + { + JSON_THROW(invalid_iterator::create(203, "iterators do not fit current value", *this)); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + case value_t::binary: + { + if (JSON_HEDLEY_LIKELY(!first.m_it.primitive_iterator.is_begin() + || !last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range", *this)); + } + + if (is_string()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.string); + std::allocator_traits::deallocate(alloc, m_value.string, 1); + m_value.string = nullptr; + } + else if (is_binary()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.binary); + std::allocator_traits::deallocate(alloc, m_value.binary, 1); + m_value.binary = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + case value_t::null: + case value_t::discarded: + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()), *this)); + } + + return result; + } + + /*! + @brief remove element from a JSON object given a key + + Removes elements from a JSON object with the key value @a key. + + @param[in] key value of the elements to remove + + @return Number of elements removed. If @a ObjectType is the default + `std::map` type, the return value will always be `0` (@a key was not + found) or `1` (@a key was found). + + @post References and iterators to the erased elements are invalidated. + Other references and iterators are not affected. 
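+
+ @note A short sketch of erasing by key and checking the returned count
+ (editorial example with hypothetical values):
+ @code {.cpp}
+ json j = {{"one", 1}, {"two", 2}};
+ auto removed = j.erase("one");   // removed == 1, j == {"two": 2}
+ auto missing = j.erase("three"); // missing == 0, nothing was removed
+ @endcode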
+ + @throw type_error.307 when called on a type other than JSON object; + example: `"cannot use erase() with null"` + + @complexity `log(size()) + count(key)` + + @liveexample{The example shows the effect of `erase()`.,erase__key_type} + + @sa see @ref erase(IteratorType) -- removes the element at a given position + @sa see @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa see @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + size_type erase(const typename object_t::key_type& key) + { + // this erase only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return m_value.object->erase(key); + } + + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()), *this)); + } + + /*! + @brief remove element from a JSON array given an index + + Removes element from a JSON array at the index @a idx. + + @param[in] idx index of the element to remove + + @throw type_error.307 when called on a type other than JSON object; + example: `"cannot use erase() with null"` + @throw out_of_range.401 when `idx >= size()`; example: `"array index 17 + is out of range"` + + @complexity Linear in distance between @a idx and the end of the container. + + @liveexample{The example shows the effect of `erase()`.,erase__size_type} + + @sa see @ref erase(IteratorType) -- removes the element at a given position + @sa see @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa see @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + + @since version 1.0.0 + */ + void erase(const size_type idx) + { + // this erase only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + if (JSON_HEDLEY_UNLIKELY(idx >= size())) + { + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range", *this)); + } + + m_value.array->erase(m_value.array->begin() + static_cast(idx)); + } + else + { + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()), *this)); + } + } + + /// @} + + + //////////// + // lookup // + //////////// + + /// @name lookup + /// @{ + + /*! + @brief find an element in a JSON object + + Finds an element in a JSON object with key equivalent to @a key. If the + element is not found or the JSON value is not an object, end() is + returned. + + @note This method always returns @ref end() when executed on a JSON type + that is not an object. + + @param[in] key key value of the element to search for. + + @return Iterator to an element with key equivalent to @a key. If no such + element is found or the JSON value is not an object, past-the-end (see + @ref end()) iterator is returned. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The example shows how `find()` is used.,find__key_type} + + @sa see @ref contains(KeyT&&) const -- checks whether a key exists + + @since version 1.0.0 + */ + template + iterator find(KeyT&& key) + { + auto result = end(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(std::forward(key)); + } + + return result; + } + + /*! + @brief find an element in a JSON object + @copydoc find(KeyT&&) + */ + template + const_iterator find(KeyT&& key) const + { + auto result = cend(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(std::forward(key)); + } + + return result; + } + + /*! 
+ @brief returns the number of occurrences of a key in a JSON object + + Returns the number of elements with key @a key. If ObjectType is the + default `std::map` type, the return value will always be `0` (@a key was + not found) or `1` (@a key was found). + + @note This method always returns `0` when executed on a JSON type that is + not an object. + + @param[in] key key value of the element to count + + @return Number of elements with key @a key. If the JSON value is not an + object, the return value will be `0`. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The example shows how `count()` is used.,count} + + @since version 1.0.0 + */ + template + size_type count(KeyT&& key) const + { + // return 0 for all nonobject types + return is_object() ? m_value.object->count(std::forward(key)) : 0; + } + + /*! + @brief check the existence of an element in a JSON object + + Check whether an element exists in a JSON object with key equivalent to + @a key. If the element is not found or the JSON value is not an object, + false is returned. + + @note This method always returns false when executed on a JSON type + that is not an object. + + @param[in] key key value to check its existence. + + @return true if an element with specified @a key exists. If no such + element with such key is found or the JSON value is not an object, + false is returned. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The following code shows an example for `contains()`.,contains} + + @sa see @ref find(KeyT&&) -- returns an iterator to an object element + @sa see @ref contains(const json_pointer&) const -- checks the existence for a JSON pointer + + @since version 3.6.0 + */ + template < typename KeyT, typename std::enable_if < + !std::is_same::type, json_pointer>::value, int >::type = 0 > + bool contains(KeyT && key) const + { + return is_object() && m_value.object->find(std::forward(key)) != m_value.object->end(); + } + + /*! + @brief check the existence of an element in a JSON object given a JSON pointer + + Check whether the given JSON pointer @a ptr can be resolved in the current + JSON value. + + @note This method can be executed on any JSON value type. + + @param[in] ptr JSON pointer to check its existence. + + @return true if the JSON pointer can be resolved to a stored value, false + otherwise. + + @post If `j.contains(ptr)` returns true, it is safe to call `j[ptr]`. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The following code shows an example for `contains()`.,contains_json_pointer} + + @sa see @ref contains(KeyT &&) const -- checks the existence of a key + + @since version 3.7.0 + */ + bool contains(const json_pointer& ptr) const + { + return ptr.contains(this); + } + + /// @} + + + /////////////// + // iterators // + /////////////// + + /// @name iterators + /// @{ + + /*! + @brief returns an iterator to the first element + + Returns an iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. 
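+
+ @note The sketch below illustrates the lookup members documented above
+ (`count()`, `contains()`, and `contains()` with a JSON pointer); it is an
+ editorial example with hypothetical values:
+ @code {.cpp}
+ json j = {{"user", {{"name", "Ada"}}}, {"version", 1}};
+ bool has_user = j.contains("user");                           // true
+ bool has_name = j.contains(json::json_pointer("/user/name")); // true
+ auto n = j.count("version");                                  // 1
+ @endcode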
+ + @liveexample{The following code shows an example for `begin()`.,begin} + + @sa see @ref cbegin() -- returns a const iterator to the beginning + @sa see @ref end() -- returns an iterator to the end + @sa see @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + iterator begin() noexcept + { + iterator result(this); + result.set_begin(); + return result; + } + + /*! + @copydoc basic_json::cbegin() + */ + const_iterator begin() const noexcept + { + return cbegin(); + } + + /*! + @brief returns a const iterator to the first element + + Returns a const iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).begin()`. + + @liveexample{The following code shows an example for `cbegin()`.,cbegin} + + @sa see @ref begin() -- returns an iterator to the beginning + @sa see @ref end() -- returns an iterator to the end + @sa see @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + const_iterator cbegin() const noexcept + { + const_iterator result(this); + result.set_begin(); + return result; + } + + /*! + @brief returns an iterator to one past the last element + + Returns an iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + + @liveexample{The following code shows an example for `end()`.,end} + + @sa see @ref cend() -- returns a const iterator to the end + @sa see @ref begin() -- returns an iterator to the beginning + @sa see @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + iterator end() noexcept + { + iterator result(this); + result.set_end(); + return result; + } + + /*! + @copydoc basic_json::cend() + */ + const_iterator end() const noexcept + { + return cend(); + } + + /*! + @brief returns a const iterator to one past the last element + + Returns a const iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).end()`. + + @liveexample{The following code shows an example for `cend()`.,cend} + + @sa see @ref end() -- returns an iterator to the end + @sa see @ref begin() -- returns an iterator to the beginning + @sa see @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + const_iterator cend() const noexcept + { + const_iterator result(this); + result.set_end(); + return result; + } + + /*! + @brief returns an iterator to the reverse-beginning + + Returns an iterator to the reverse-beginning; that is, the last element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. 
+ + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(end())`. + + @liveexample{The following code shows an example for `rbegin()`.,rbegin} + + @sa see @ref crbegin() -- returns a const reverse iterator to the beginning + @sa see @ref rend() -- returns a reverse iterator to the end + @sa see @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + reverse_iterator rbegin() noexcept + { + return reverse_iterator(end()); + } + + /*! + @copydoc basic_json::crbegin() + */ + const_reverse_iterator rbegin() const noexcept + { + return crbegin(); + } + + /*! + @brief returns an iterator to the reverse-end + + Returns an iterator to the reverse-end; that is, one before the first + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(begin())`. + + @liveexample{The following code shows an example for `rend()`.,rend} + + @sa see @ref crend() -- returns a const reverse iterator to the end + @sa see @ref rbegin() -- returns a reverse iterator to the beginning + @sa see @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + reverse_iterator rend() noexcept + { + return reverse_iterator(begin()); + } + + /*! + @copydoc basic_json::crend() + */ + const_reverse_iterator rend() const noexcept + { + return crend(); + } + + /*! + @brief returns a const reverse iterator to the last element + + Returns a const iterator to the reverse-beginning; that is, the last + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).rbegin()`. + + @liveexample{The following code shows an example for `crbegin()`.,crbegin} + + @sa see @ref rbegin() -- returns a reverse iterator to the beginning + @sa see @ref rend() -- returns a reverse iterator to the end + @sa see @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + const_reverse_iterator crbegin() const noexcept + { + return const_reverse_iterator(cend()); + } + + /*! + @brief returns a const reverse iterator to one before the first + + Returns a const reverse iterator to the reverse-end; that is, one before + the first element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).rend()`. 
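+
+ @note A minimal sketch of reverse iteration (editorial example with
+ hypothetical values):
+ @code {.cpp}
+ json j = {1, 2, 3};
+ for (auto it = j.rbegin(); it != j.rend(); ++it)
+ {
+ std::cout << *it << '\n'; // prints 3, 2, 1
+ }
+ @endcode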
+ + @liveexample{The following code shows an example for `crend()`.,crend} + + @sa see @ref rend() -- returns a reverse iterator to the end + @sa see @ref rbegin() -- returns a reverse iterator to the beginning + @sa see @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + const_reverse_iterator crend() const noexcept + { + return const_reverse_iterator(cbegin()); + } + + public: + /*! + @brief wrapper to access iterator member functions in range-based for + + This function allows to access @ref iterator::key() and @ref + iterator::value() during range-based for loops. In these loops, a + reference to the JSON values is returned, so there is no access to the + underlying iterator. + + For loop without iterator_wrapper: + + @code{cpp} + for (auto it = j_object.begin(); it != j_object.end(); ++it) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + Range-based for loop without iterator proxy: + + @code{cpp} + for (auto it : j_object) + { + // "it" is of type json::reference and has no key() member + std::cout << "value: " << it << '\n'; + } + @endcode + + Range-based for loop with iterator proxy: + + @code{cpp} + for (auto it : json::iterator_wrapper(j_object)) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + @note When iterating over an array, `key()` will return the index of the + element as string (see example). + + @param[in] ref reference to a JSON value + @return iteration proxy object wrapping @a ref with an interface to use in + range-based for loops + + @liveexample{The following code shows how the wrapper is used,iterator_wrapper} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @note The name of this function is not yet final and may change in the + future. + + @deprecated This stream operator is deprecated and will be removed in + future 4.0.0 of the library. Please use @ref items() instead; + that is, replace `json::iterator_wrapper(j)` with `j.items()`. + */ + JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) + static iteration_proxy iterator_wrapper(reference ref) noexcept + { + return ref.items(); + } + + /*! + @copydoc iterator_wrapper(reference) + */ + JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) + static iteration_proxy iterator_wrapper(const_reference ref) noexcept + { + return ref.items(); + } + + /*! + @brief helper to access iterator member functions in range-based for + + This function allows to access @ref iterator::key() and @ref + iterator::value() during range-based for loops. In these loops, a + reference to the JSON values is returned, so there is no access to the + underlying iterator. 
+ + For loop without `items()` function: + + @code{cpp} + for (auto it = j_object.begin(); it != j_object.end(); ++it) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + Range-based for loop without `items()` function: + + @code{cpp} + for (auto it : j_object) + { + // "it" is of type json::reference and has no key() member + std::cout << "value: " << it << '\n'; + } + @endcode + + Range-based for loop with `items()` function: + + @code{cpp} + for (auto& el : j_object.items()) + { + std::cout << "key: " << el.key() << ", value:" << el.value() << '\n'; + } + @endcode + + The `items()` function also allows to use + [structured bindings](https://en.cppreference.com/w/cpp/language/structured_binding) + (C++17): + + @code{cpp} + for (auto& [key, val] : j_object.items()) + { + std::cout << "key: " << key << ", value:" << val << '\n'; + } + @endcode + + @note When iterating over an array, `key()` will return the index of the + element as string (see example). For primitive types (e.g., numbers), + `key()` returns an empty string. + + @warning Using `items()` on temporary objects is dangerous. Make sure the + object's lifetime exeeds the iteration. See + for more + information. + + @return iteration proxy object wrapping @a ref with an interface to use in + range-based for loops + + @liveexample{The following code shows how the function is used.,items} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 3.1.0, structured bindings support since 3.5.0. + */ + iteration_proxy items() noexcept + { + return iteration_proxy(*this); + } + + /*! + @copydoc items() + */ + iteration_proxy items() const noexcept + { + return iteration_proxy(*this); + } + + /// @} + + + ////////////// + // capacity // + ////////////// + + /// @name capacity + /// @{ + + /*! + @brief checks whether the container is empty. + + Checks if a JSON value has no elements (i.e. whether its @ref size is `0`). + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `true` + boolean | `false` + string | `false` + number | `false` + binary | `false` + object | result of function `object_t::empty()` + array | result of function `array_t::empty()` + + @liveexample{The following code uses `empty()` to check if a JSON + object contains any elements.,empty} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `empty()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return whether a string stored as JSON value + is empty - it returns whether the JSON container itself is empty which is + false in the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `begin() == end()`. 
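+
+ @note A short sketch illustrating the table above (editorial example with
+ hypothetical values):
+ @code {.cpp}
+ json j_null;                 // null
+ json j_str = "";             // empty string
+ json j_arr = json::array();  // empty array
+ j_null.empty(); // true
+ j_str.empty();  // false: primitive values are never "empty"
+ j_arr.empty();  // true: no elements
+ @endcode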
+ + @sa see @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + bool empty() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return true; + } + + case value_t::array: + { + // delegate call to array_t::empty() + return m_value.array->empty(); + } + + case value_t::object: + { + // delegate call to object_t::empty() + return m_value.object->empty(); + } + + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + // all other types are nonempty + return false; + } + } + } + + /*! + @brief returns the number of elements + + Returns the number of elements in a JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` + boolean | `1` + string | `1` + number | `1` + binary | `1` + object | result of function object_t::size() + array | result of function array_t::size() + + @liveexample{The following code calls `size()` on the different value + types.,size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their size() functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return the length of a string stored as JSON + value - it returns the number of elements in the JSON value which is 1 in + the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `std::distance(begin(), end())`. + + @sa see @ref empty() -- checks whether the container is empty + @sa see @ref max_size() -- returns the maximal number of elements + + @since version 1.0.0 + */ + size_type size() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return 0; + } + + case value_t::array: + { + // delegate call to array_t::size() + return m_value.array->size(); + } + + case value_t::object: + { + // delegate call to object_t::size() + return m_value.object->size(); + } + + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + // all other types have size 1 + return 1; + } + } + } + + /*! + @brief returns the maximum possible number of elements + + Returns the maximum number of elements a JSON value is able to hold due to + system or library implementation limitations, i.e. `std::distance(begin(), + end())` for the JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` (same as `size()`) + boolean | `1` (same as `size()`) + string | `1` (same as `size()`) + number | `1` (same as `size()`) + binary | `1` (same as `size()`) + object | result of function `object_t::max_size()` + array | result of function `array_t::max_size()` + + @liveexample{The following code calls `max_size()` on the different value + types. 
Note the output is implementation specific.,max_size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `max_size()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of returning `b.size()` where `b` is the largest + possible JSON value. + + @sa see @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + size_type max_size() const noexcept + { + switch (m_type) + { + case value_t::array: + { + // delegate call to array_t::max_size() + return m_value.array->max_size(); + } + + case value_t::object: + { + // delegate call to object_t::max_size() + return m_value.object->max_size(); + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + // all other types have max_size() == size() + return size(); + } + } + } + + /// @} + + + /////////////// + // modifiers // + /////////////// + + /// @name modifiers + /// @{ + + /*! + @brief clears the contents + + Clears the content of a JSON value and resets it to the default value as + if @ref basic_json(value_t) would have been called with the current value + type from @ref type(): + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + binary | An empty byte vector + object | `{}` + array | `[]` + + @post Has the same effect as calling + @code {.cpp} + *this = basic_json(type()); + @endcode + + @liveexample{The example below shows the effect of `clear()` to different + JSON types.,clear} + + @complexity Linear in the size of the JSON value. + + @iterators All iterators, pointers and references related to this container + are invalidated. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @sa see @ref basic_json(value_t) -- constructor that creates an object with the + same value than calling `clear()` + + @since version 1.0.0 + */ + void clear() noexcept + { + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = 0; + break; + } + + case value_t::number_unsigned: + { + m_value.number_unsigned = 0; + break; + } + + case value_t::number_float: + { + m_value.number_float = 0.0; + break; + } + + case value_t::boolean: + { + m_value.boolean = false; + break; + } + + case value_t::string: + { + m_value.string->clear(); + break; + } + + case value_t::binary: + { + m_value.binary->clear(); + break; + } + + case value_t::array: + { + m_value.array->clear(); + break; + } + + case value_t::object: + { + m_value.object->clear(); + break; + } + + case value_t::null: + case value_t::discarded: + default: + break; + } + } + + /*! + @brief add an object to an array + + Appends the given element @a val to the end of the JSON value. If the + function is called on a JSON null value, an empty array is created before + appending @a val. + + @param[in] val the value to add to the JSON array + + @throw type_error.308 when called on a type other than JSON array or + null; example: `"cannot use push_back() with number"` + + @complexity Amortized constant. 
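+
+ @note A minimal sketch of appending to an array, including the implicit
+ null-to-array conversion (editorial example with hypothetical values):
+ @code {.cpp}
+ json j;                           // null
+ j.push_back(1);                   // j == [1]
+ j += 2;                           // j == [1, 2]
+ j.push_back(json::array({3, 4})); // j == [1, 2, [3, 4]]
+ @endcode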
+ + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON array. Note how the `null` value was silently + converted to a JSON array.,push_back} + + @since version 1.0.0 + */ + void push_back(basic_json&& val) + { + // push_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()), *this)); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (move semantics) + const auto old_capacity = m_value.array->capacity(); + m_value.array->push_back(std::move(val)); + set_parent(m_value.array->back(), old_capacity); + // if val is moved from, basic_json move constructor marks it null so we do not call the destructor + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(basic_json&& val) + { + push_back(std::move(val)); + return *this; + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + void push_back(const basic_json& val) + { + // push_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()), *this)); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array + const auto old_capacity = m_value.array->capacity(); + m_value.array->push_back(val); + set_parent(m_value.array->back(), old_capacity); + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(const basic_json& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + Inserts the given element @a val to the JSON object. If the function is + called on a JSON null value, an empty object is created before inserting + @a val. + + @param[in] val the value to add to the JSON object + + @throw type_error.308 when called on a type other than JSON object or + null; example: `"cannot use push_back() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON object. Note how the `null` value was silently + converted to a JSON object.,push_back__object_t__value} + + @since version 1.0.0 + */ + void push_back(const typename object_t::value_type& val) + { + // push_back only works for null objects or objects + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()), *this)); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to object + auto res = m_value.object->insert(val); + set_parent(res.first->second); + } + + /*! + @brief add an object to an object + @copydoc push_back(const typename object_t::value_type&) + */ + reference operator+=(const typename object_t::value_type& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + This function allows to use `push_back` with an initializer list. In case + + 1. 
the current value is an object, + 2. the initializer list @a init contains only two elements, and + 3. the first element of @a init is a string, + + @a init is converted into an object element and added using + @ref push_back(const typename object_t::value_type&). Otherwise, @a init + is converted to a JSON value and added using @ref push_back(basic_json&&). + + @param[in] init an initializer list + + @complexity Linear in the size of the initializer list @a init. + + @note This function is required to resolve an ambiguous overload error, + because pairs like `{"key", "value"}` can be both interpreted as + `object_t::value_type` or `std::initializer_list`, see + https://github.com/nlohmann/json/issues/235 for more information. + + @liveexample{The example shows how initializer lists are treated as + objects when possible.,push_back__initializer_list} + */ + void push_back(initializer_list_t init) + { + if (is_object() && init.size() == 2 && (*init.begin())->is_string()) + { + basic_json&& key = init.begin()->moved_or_copied(); + push_back(typename object_t::value_type( + std::move(key.get_ref()), (init.begin() + 1)->moved_or_copied())); + } + else + { + push_back(basic_json(init)); + } + } + + /*! + @brief add an object to an object + @copydoc push_back(initializer_list_t) + */ + reference operator+=(initializer_list_t init) + { + push_back(init); + return *this; + } + + /*! + @brief add an object to an array + + Creates a JSON value from the passed parameters @a args to the end of the + JSON value. If the function is called on a JSON null value, an empty array + is created before appending the value created from @a args. + + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @return reference to the inserted element + + @throw type_error.311 when called on a type other than JSON array or + null; example: `"cannot use emplace_back() with number"` + + @complexity Amortized constant. + + @liveexample{The example shows how `push_back()` can be used to add + elements to a JSON array. Note how the `null` value was silently converted + to a JSON array.,emplace_back} + + @since version 2.0.8, returns reference since 3.7.0 + */ + template + reference emplace_back(Args&& ... args) + { + // emplace_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace_back() with " + std::string(type_name()), *this)); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (perfect forwarding) + const auto old_capacity = m_value.array->capacity(); + m_value.array->emplace_back(std::forward(args)...); + return set_parent(m_value.array->back(), old_capacity); + } + + /*! + @brief add an object to an object if key does not exist + + Inserts a new element into a JSON object constructed in-place with the + given @a args if there is no element with the key in the container. If the + function is called on a JSON null value, an empty object is created before + appending the value created from @a args. 
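+
+ @note A short sketch of `emplace_back()` and `emplace()` (editorial
+ example with hypothetical values; assumes the usual `json` alias):
+ @code {.cpp}
+ json j_arr = {1, 2};
+ j_arr.emplace_back(3);                      // j_arr == [1, 2, 3]
+
+ json j_obj;
+ auto res1 = j_obj.emplace("key", "first");  // res1.second == true (inserted)
+ auto res2 = j_obj.emplace("key", "second"); // res2.second == false (key exists)
+ // j_obj == {"key": "first"}
+ @endcode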
+ + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @return a pair consisting of an iterator to the inserted element, or the + already-existing element if no insertion happened, and a bool + denoting whether the insertion took place. + + @throw type_error.311 when called on a type other than JSON object or + null; example: `"cannot use emplace() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `emplace()` can be used to add elements + to a JSON object. Note how the `null` value was silently converted to a + JSON object. Further note how no value is added if there was already one + value stored with the same key.,emplace} + + @since version 2.0.8 + */ + template + std::pair emplace(Args&& ... args) + { + // emplace only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace() with " + std::string(type_name()), *this)); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to array (perfect forwarding) + auto res = m_value.object->emplace(std::forward(args)...); + set_parent(res.first->second); + + // create result iterator and set iterator to the result of emplace + auto it = begin(); + it.m_it.object_iterator = res.first; + + // return pair of iterator and boolean + return {it, res.second}; + } + + /// Helper for insertion of an iterator + /// @note: This uses std::distance to support GCC 4.8, + /// see https://github.com/nlohmann/json/pull/1257 + template + iterator insert_iterator(const_iterator pos, Args&& ... args) + { + iterator result(this); + JSON_ASSERT(m_value.array != nullptr); + + auto insert_pos = std::distance(m_value.array->begin(), pos.m_it.array_iterator); + m_value.array->insert(pos.m_it.array_iterator, std::forward(args)...); + result.m_it.array_iterator = m_value.array->begin() + insert_pos; + + // This could have been written as: + // result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, cnt, val); + // but the return value of insert is missing in GCC 4.8, so it is written this way instead. + + set_parents(); + return result; + } + + /*! + @brief inserts element + + Inserts element @a val before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] val element to insert + @return iterator pointing to the inserted @a val. + + @throw type_error.309 if called on JSON values other than arrays; + example: `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Constant plus linear in the distance between @a pos and end of + the container. 
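+
+ @note A minimal sketch of inserting a single element (editorial example
+ with hypothetical values):
+ @code {.cpp}
+ json j = {1, 2, 4};
+ auto it = j.insert(j.begin() + 2, 3); // j == [1, 2, 3, 4]; it points to the inserted 3
+ @endcode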
+ + @liveexample{The example shows how `insert()` is used.,insert} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const basic_json& val) + { + // insert only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value", *this)); + } + + // insert to array and return iterator + return insert_iterator(pos, val); + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()), *this)); + } + + /*! + @brief inserts element + @copydoc insert(const_iterator, const basic_json&) + */ + iterator insert(const_iterator pos, basic_json&& val) + { + return insert(pos, val); + } + + /*! + @brief inserts elements + + Inserts @a cnt copies of @a val before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] cnt number of copies of @a val to insert + @param[in] val element to insert + @return iterator pointing to the first element inserted, or @a pos if + `cnt==0` + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Linear in @a cnt plus linear in the distance between @a pos + and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__count} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, size_type cnt, const basic_json& val) + { + // insert only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value", *this)); + } + + // insert to array and return iterator + return insert_iterator(pos, cnt, val); + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()), *this)); + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)` before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + @throw invalid_iterator.211 if @a first or @a last are iterators into + container for which insert is called; example: `"passed iterators may not + belong to container"` + + @return iterator pointing to the first element inserted, or @a pos if + `first==last` + + @complexity Linear in `std::distance(first, last)` plus linear in the + distance between @a pos and end of the container. 
+ + @liveexample{The example shows how `insert()` is used.,insert__range} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const_iterator first, const_iterator last) + { + // insert only works for arrays + if (JSON_HEDLEY_UNLIKELY(!is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()), *this)); + } + + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value", *this)); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit", *this)); + } + + if (JSON_HEDLEY_UNLIKELY(first.m_object == this)) + { + JSON_THROW(invalid_iterator::create(211, "passed iterators may not belong to container", *this)); + } + + // insert to array and return iterator + return insert_iterator(pos, first.m_it.array_iterator, last.m_it.array_iterator); + } + + /*! + @brief inserts elements + + Inserts elements from initializer list @a ilist before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] ilist initializer list to insert the values from + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @return iterator pointing to the first element inserted, or @a pos if + `ilist` is empty + + @complexity Linear in `ilist.size()` plus linear in the distance between + @a pos and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__ilist} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, initializer_list_t ilist) + { + // insert only works for arrays + if (JSON_HEDLEY_UNLIKELY(!is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()), *this)); + } + + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value", *this)); + } + + // insert to array and return iterator + return insert_iterator(pos, ilist.begin(), ilist.end()); + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)`. + + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than objects; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if iterator @a first or @a last does does not + point to an object; example: `"iterators first and last must point to + objects"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity Logarithmic: `O(N*log(size() + N))`, where `N` is the number + of elements to insert. 
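+
+ @note A short sketch of inserting a range from another object; keys that
+ already exist are left untouched (editorial example with hypothetical
+ values):
+ @code {.cpp}
+ json j1 = {{"one", 1}};
+ json j2 = {{"two", 2}, {"three", 3}};
+ j1.insert(j2.begin(), j2.end()); // j1 == {"one": 1, "three": 3, "two": 2}
+ @endcode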
+ + @liveexample{The example shows how `insert()` is used.,insert__range_object} + + @since version 3.0.0 + */ + void insert(const_iterator first, const_iterator last) + { + // insert only works for objects + if (JSON_HEDLEY_UNLIKELY(!is_object())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()), *this)); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit", *this)); + } + + // passed iterators must belong to objects + if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object())) + { + JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects", *this)); + } + + m_value.object->insert(first.m_it.object_iterator, last.m_it.object_iterator); + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from JSON object @a j and overwrites existing keys. + + @param[in] j JSON object to read values from + @param[in] merge_objects when true, existing keys are not overwritten, but + contents of objects are merged recursively + (default: false) + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. + + @liveexample{The example shows how `update()` is used.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0, `merge_objects` parameter added in 3.10.4. + */ + void update(const_reference j, bool merge_objects = false) + { + update(j.begin(), j.end(), merge_objects); + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from from range `[first, last)` and overwrites existing + keys. + + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + @param[in] merge_objects when true, existing keys are not overwritten, but + contents of objects are merged recursively + (default: false) + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + @throw type_error.312 if iterator @a first or @a last does does not + point to an object; example: `"cannot use update() with string"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. + + @liveexample{The example shows how `update()` is used__range.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0, `merge_objects` parameter added in 3.10.4. 
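+
+ @note A short sketch contrasting plain `update()` with `merge_objects`
+ (editorial example with hypothetical values):
+ @code {.cpp}
+ json j1    = {{"color", "red"},  {"size", {{"width", 10}}}};
+ json j2    = j1;
+ json other = {{"color", "blue"}, {"size", {{"height", 20}}}};
+
+ j1.update(other);       // j1["size"] == {"height": 20} (value replaced)
+ j2.update(other, true); // j2["size"] == {"height": 20, "width": 10} (objects merged)
+ @endcode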
+ */ + void update(const_iterator first, const_iterator last, bool merge_objects = false) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + if (JSON_HEDLEY_UNLIKELY(!is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()), *this)); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit", *this)); + } + + // passed iterators must belong to objects + if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(first.m_object->type_name()), *first.m_object)); + } + + for (auto it = first; it != last; ++it) + { + if (merge_objects && it.value().is_object()) + { + auto it2 = m_value.object->find(it.key()); + if (it2 != m_value.object->end()) + { + it2->second.update(it.value(), true); + continue; + } + } + m_value.object->operator[](it.key()) = it.value(); +#if JSON_DIAGNOSTICS + m_value.object->operator[](it.key()).m_parent = this; +#endif + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of the JSON value with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other JSON value to exchange the contents with + + @complexity Constant. + + @liveexample{The example below shows how JSON values can be swapped with + `swap()`.,swap__reference} + + @since version 1.0.0 + */ + void swap(reference other) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + std::swap(m_type, other.m_type); + std::swap(m_value, other.m_value); + + set_parents(); + other.set_parents(); + assert_invariant(); + } + + /*! + @brief exchanges the values + + Exchanges the contents of the JSON value from @a left with those of @a right. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. implemented as a friend function callable via ADL. + + @param[in,out] left JSON value to exchange the contents with + @param[in,out] right JSON value to exchange the contents with + + @complexity Constant. + + @liveexample{The example below shows how JSON values can be swapped with + `swap()`.,swap__reference} + + @since version 1.0.0 + */ + friend void swap(reference left, reference right) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + left.swap(right); + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON array with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other array to exchange the contents with + + @throw type_error.310 when JSON value is not an array; example: `"cannot + use swap() with string"` + + @complexity Constant. 
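+
+ @note A minimal sketch covering the value-level `swap()` overloads above
+ as well as swapping with an `array_t` (editorial example with
+ hypothetical values):
+ @code {.cpp}
+ json a = {1, 2, 3};
+ json b = {{"key", "value"}};
+ a.swap(b);                     // a is now the object, b the array
+
+ json::array_t arr = {4, 5, 6};
+ b.swap(arr);                   // b == [4, 5, 6], arr == [1, 2, 3]
+ @endcode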
+ + @liveexample{The example below shows how arrays can be swapped with + `swap()`.,swap__array_t} + + @since version 1.0.0 + */ + void swap(array_t& other) // NOLINT(bugprone-exception-escape) + { + // swap only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + std::swap(*(m_value.array), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON object with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other object to exchange the contents with + + @throw type_error.310 when JSON value is not an object; example: + `"cannot use swap() with string"` + + @complexity Constant. + + @liveexample{The example below shows how objects can be swapped with + `swap()`.,swap__object_t} + + @since version 1.0.0 + */ + void swap(object_t& other) // NOLINT(bugprone-exception-escape) + { + // swap only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + std::swap(*(m_value.object), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON string with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other string to exchange the contents with + + @throw type_error.310 when JSON value is not a string; example: `"cannot + use swap() with boolean"` + + @complexity Constant. + + @liveexample{The example below shows how strings can be swapped with + `swap()`.,swap__string_t} + + @since version 1.0.0 + */ + void swap(string_t& other) // NOLINT(bugprone-exception-escape) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_string())) + { + std::swap(*(m_value.string), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()), *this)); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON string with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other binary to exchange the contents with + + @throw type_error.310 when JSON value is not a string; example: `"cannot + use swap() with boolean"` + + @complexity Constant. 
+ + @liveexample{The example below shows how strings can be swapped with + `swap()`.,swap__binary_t} + + @since version 3.8.0 + */ + void swap(binary_t& other) // NOLINT(bugprone-exception-escape) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_binary())) + { + std::swap(*(m_value.binary), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()), *this)); + } + } + + /// @copydoc swap(binary_t&) + void swap(typename binary_t::container_type& other) // NOLINT(bugprone-exception-escape) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_binary())) + { + std::swap(*(m_value.binary), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()), *this)); + } + } + + /// @} + + public: + ////////////////////////////////////////// + // lexicographical comparison operators // + ////////////////////////////////////////// + + /// @name lexicographical comparison operators + /// @{ + + /*! + @brief comparison: equal + + Compares two JSON values for equality according to the following rules: + - Two JSON values are equal if (1) they are from the same type and (2) + their stored values are the same according to their respective + `operator==`. + - Integer and floating-point numbers are automatically converted before + comparison. Note that two NaN values are always treated as unequal. + - Two JSON null values are equal. + + @note Floating-point inside JSON values numbers are compared with + `json::number_float_t::operator==` which is `double::operator==` by + default. To compare floating-point while respecting an epsilon, an alternative + [comparison function](https://github.com/mariokonrad/marnav/blob/master/include/marnav/math/floatingpoint.hpp#L34-#L39) + could be used, for instance + @code {.cpp} + template::value, T>::type> + inline bool is_same(T a, T b, T epsilon = std::numeric_limits::epsilon()) noexcept + { + return std::abs(a - b) <= epsilon; + } + @endcode + Or you can self-defined operator equal function like this: + @code {.cpp} + bool my_equal(const_reference lhs, const_reference rhs) { + const auto lhs_type lhs.type(); + const auto rhs_type rhs.type(); + if (lhs_type == rhs_type) { + switch(lhs_type) + // self_defined case + case value_t::number_float: + return std::abs(lhs - rhs) <= std::numeric_limits::epsilon(); + // other cases remain the same with the original + ... + } + ... + } + @endcode + + @note NaN values never compare equal to themselves or to other NaN values. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are equal + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Linear. 
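+
+    A short sketch of these rules (assuming the common `json` alias and
+    illustrative values):
+    @code {.cpp}
+    json a = 42;       // number_integer
+    json b = 42.0;     // number_float
+    json c = "42";     // string
+
+    bool same_number = (a == b);   // true: numbers are converted before comparison
+    bool mixed_types = (a == c);   // false: different non-number types are never equal
+    @endcode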
+ + @liveexample{The example demonstrates comparing several JSON + types.,operator__equal} + + @since version 1.0.0 + */ + friend bool operator==(const_reference lhs, const_reference rhs) noexcept + { +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + return *lhs.m_value.array == *rhs.m_value.array; + + case value_t::object: + return *lhs.m_value.object == *rhs.m_value.object; + + case value_t::null: + return true; + + case value_t::string: + return *lhs.m_value.string == *rhs.m_value.string; + + case value_t::boolean: + return lhs.m_value.boolean == rhs.m_value.boolean; + + case value_t::number_integer: + return lhs.m_value.number_integer == rhs.m_value.number_integer; + + case value_t::number_unsigned: + return lhs.m_value.number_unsigned == rhs.m_value.number_unsigned; + + case value_t::number_float: + return lhs.m_value.number_float == rhs.m_value.number_float; + + case value_t::binary: + return *lhs.m_value.binary == *rhs.m_value.binary; + + case value_t::discarded: + default: + return false; + } + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_integer) == rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) + { + return lhs.m_value.number_float == static_cast(rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_float == static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) + { + return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_integer; + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_integer == static_cast(rhs.m_value.number_unsigned); + } + + return false; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(const_reference lhs, ScalarType rhs) noexcept + { + return lhs == basic_json(rhs); + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) == rhs; + } + + /*! + @brief comparison: not equal + + Compares two JSON values for inequality by calculating `not (lhs == rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are not equal + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__notequal} + + @since version 1.0.0 + */ + friend bool operator!=(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs == rhs); + } + + /*! 
+ @brief comparison: not equal + @copydoc operator!=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator!=(const_reference lhs, ScalarType rhs) noexcept + { + return lhs != basic_json(rhs); + } + + /*! + @brief comparison: not equal + @copydoc operator!=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator!=(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) != rhs; + } + + /*! + @brief comparison: less than + + Compares whether one JSON value @a lhs is less than another JSON value @a + rhs according to the following rules: + - If @a lhs and @a rhs have the same type, the values are compared using + the default `<` operator. + - Integer and floating-point numbers are automatically converted before + comparison + - In case @a lhs and @a rhs have different types, the values are ignored + and the order of the types is considered, see + @ref operator<(const value_t, const value_t). + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is less than @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__less} + + @since version 1.0.0 + */ + friend bool operator<(const_reference lhs, const_reference rhs) noexcept + { + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + // note parentheses are necessary, see + // https://github.com/nlohmann/json/issues/1530 + return (*lhs.m_value.array) < (*rhs.m_value.array); + + case value_t::object: + return (*lhs.m_value.object) < (*rhs.m_value.object); + + case value_t::null: + return false; + + case value_t::string: + return (*lhs.m_value.string) < (*rhs.m_value.string); + + case value_t::boolean: + return (lhs.m_value.boolean) < (rhs.m_value.boolean); + + case value_t::number_integer: + return (lhs.m_value.number_integer) < (rhs.m_value.number_integer); + + case value_t::number_unsigned: + return (lhs.m_value.number_unsigned) < (rhs.m_value.number_unsigned); + + case value_t::number_float: + return (lhs.m_value.number_float) < (rhs.m_value.number_float); + + case value_t::binary: + return (*lhs.m_value.binary) < (*rhs.m_value.binary); + + case value_t::discarded: + default: + return false; + } + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_integer) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_integer < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) + { + return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_integer; + } + + // We only reach this line 
if we cannot compare values. In that case, + // we compare types. Note we have to call the operator explicitly, + // because MSVC has problems otherwise. + return operator<(lhs_type, rhs_type); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(const_reference lhs, ScalarType rhs) noexcept + { + return lhs < basic_json(rhs); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) < rhs; + } + + /*! + @brief comparison: less than or equal + + Compares whether one JSON value @a lhs is less than or equal to another + JSON value by calculating `not (rhs < lhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is less than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__greater} + + @since version 1.0.0 + */ + friend bool operator<=(const_reference lhs, const_reference rhs) noexcept + { + return !(rhs < lhs); + } + + /*! + @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(const_reference lhs, ScalarType rhs) noexcept + { + return lhs <= basic_json(rhs); + } + + /*! + @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) <= rhs; + } + + /*! + @brief comparison: greater than + + Compares whether one JSON value @a lhs is greater than another + JSON value by calculating `not (lhs <= rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__lessequal} + + @since version 1.0.0 + */ + friend bool operator>(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs <= rhs); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(const_reference lhs, ScalarType rhs) noexcept + { + return lhs > basic_json(rhs); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) > rhs; + } + + /*! + @brief comparison: greater than or equal + + Compares whether one JSON value @a lhs is greater than or equal to another + JSON value by calculating `not (lhs < rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. 
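+
+    The relational operators order JSON values of any type, so they can be
+    used with standard algorithms; a small sketch (assuming the common `json`
+    alias, `<vector>`, and `<algorithm>`):
+    @code {.cpp}
+    std::vector<json> values = {3, 1.5, 2u};
+    std::sort(values.begin(), values.end());   // uses operator< on json values
+    // values == {1.5, 2, 3}
+
+    bool ge = (json(2) >= json(1.5));          // true: numbers are converted
+    @endcode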
+ + @liveexample{The example demonstrates comparing several JSON + types.,operator__greaterequal} + + @since version 1.0.0 + */ + friend bool operator>=(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs < rhs); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(const_reference lhs, ScalarType rhs) noexcept + { + return lhs >= basic_json(rhs); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) >= rhs; + } + + /// @} + + /////////////////// + // serialization // + /////////////////// + + /// @name serialization + /// @{ +#ifndef JSON_NO_IO + /*! + @brief serialize to stream + + Serialize the given JSON value @a j to the output stream @a o. The JSON + value will be serialized using the @ref dump member function. + + - The indentation of the output can be controlled with the member variable + `width` of the output stream @a o. For instance, using the manipulator + `std::setw(4)` on @a o sets the indentation level to `4` and the + serialization result is the same as calling `dump(4)`. + + - The indentation character can be controlled with the member variable + `fill` of the output stream @a o. For instance, the manipulator + `std::setfill('\\t')` sets indentation to use a tab character rather than + the default space character. + + @param[in,out] o stream to serialize to + @param[in] j JSON value to serialize + + @return the stream @a o + + @throw type_error.316 if a string stored inside the JSON value is not + UTF-8 encoded + + @complexity Linear. + + @liveexample{The example below shows the serialization with different + parameters to `width` to adjust the indentation level.,operator_serialize} + + @since version 1.0.0; indentation character added in version 3.0.0 + */ + friend std::ostream& operator<<(std::ostream& o, const basic_json& j) + { + // read width member and use it as indentation parameter if nonzero + const bool pretty_print = o.width() > 0; + const auto indentation = pretty_print ? o.width() : 0; + + // reset width to 0 for subsequent calls to this stream + o.width(0); + + // do the actual serialization + serializer s(detail::output_adapter(o), o.fill()); + s.dump(j, pretty_print, false, static_cast(indentation)); + return o; + } + + /*! + @brief serialize to stream + @deprecated This stream operator is deprecated and will be removed in + future 4.0.0 of the library. Please use + @ref operator<<(std::ostream&, const basic_json&) + instead; that is, replace calls like `j >> o;` with `o << j;`. + @since version 1.0.0; deprecated since version 3.0.0 + */ + JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator<<(std::ostream&, const basic_json&)) + friend std::ostream& operator>>(const basic_json& j, std::ostream& o) + { + return o << j; + } +#endif // JSON_NO_IO + /// @} + + + ///////////////////// + // deserialization // + ///////////////////// + + /// @name deserialization + /// @{ + + /*! + @brief deserialize from a compatible input + + @tparam InputType A compatible input, for instance + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. 
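+
+    A brief sketch of typical calls (assuming the common `json` alias; the
+    input strings are made up for illustration):
+    @code {.cpp}
+    // parse from a string literal
+    json j1 = json::parse(R"({"pi": 3.141, "happy": true})");
+
+    // with ignore_comments == true, comments are treated as whitespace
+    json j2 = json::parse("/* the answer */ 42", nullptr, true, true);
+    @endcode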
+ + @param[in] i input to read from + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the parser callback function + @a cb or reading from the input @a i has a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `parse()` function reading + from an array.,parse__array__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__string__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__istream__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function reading + from a contiguous container.,parse__contiguouscontainer__parser_callback_t} + + @since version 2.0.3 (contiguous containers); version 3.9.0 allowed to + ignore comments. + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json parse(InputType&& i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(detail::input_adapter(std::forward(i)), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + /*! + @brief deserialize from a pair of character iterators + + The value_type of the iterator must be a integral type with size of 1, 2 or + 4 bytes, which will be interpreted respectively as UTF-8, UTF-16 and UTF-32. + + @param[in] first iterator to start of character range + @param[in] last iterator to end of character range + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. 
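+
+    A small sketch of the iterator overload and of disabling exceptions
+    (assuming the common `json` alias and `<vector>`):
+    @code {.cpp}
+    std::vector<char> input = {'[', '1', ',', '2', ']'};
+    json ok = json::parse(input.begin(), input.end());            // [1, 2]
+
+    // with allow_exceptions == false, a parse error yields a discarded value
+    json bad = json::parse(input.begin(), input.begin() + 3, nullptr, false);
+    // bad.is_discarded() == true
+    @endcode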
+ + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json parse(IteratorType first, + IteratorType last, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(detail::input_adapter(std::move(first), std::move(last)), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, parse(ptr, ptr + len)) + static basic_json parse(detail::span_input_adapter&& i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(i.get(), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + /*! + @brief check if the input is valid JSON + + Unlike the @ref parse(InputType&&, const parser_callback_t,const bool) + function, this function neither throws an exception in case of invalid JSON + input (i.e., a parse error) nor creates diagnostic information. + + @tparam InputType A compatible input, for instance + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. + + @param[in] i input to read from + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return Whether the input read from @a i is valid JSON. + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `accept()` function reading + from a string.,accept__string} + */ + template + static bool accept(InputType&& i, + const bool ignore_comments = false) + { + return parser(detail::input_adapter(std::forward(i)), nullptr, false, ignore_comments).accept(true); + } + + template + static bool accept(IteratorType first, IteratorType last, + const bool ignore_comments = false) + { + return parser(detail::input_adapter(std::move(first), std::move(last)), nullptr, false, ignore_comments).accept(true); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, accept(ptr, ptr + len)) + static bool accept(detail::span_input_adapter&& i, + const bool ignore_comments = false) + { + return parser(i.get(), nullptr, false, ignore_comments).accept(true); + } + + /*! + @brief generate SAX events + + The SAX event lister must follow the interface of @ref json_sax. + + This function reads from a compatible input. Examples are: + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. 
+ + @param[in] i input to read from + @param[in,out] sax SAX event listener + @param[in] format the format to parse (JSON, CBOR, MessagePack, or UBJSON) + @param[in] strict whether the input has to be consumed completely + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default); only applies to the JSON file format. + + @return return value of the last processed SAX event + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the SAX consumer @a sax has + a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `sax_parse()` function + reading from string and processing the events with a user-defined SAX + event consumer.,sax_parse} + + @since version 3.2.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + static bool sax_parse(InputType&& i, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = detail::input_adapter(std::forward(i)); + return format == input_format_t::json + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } + + template + JSON_HEDLEY_NON_NULL(3) + static bool sax_parse(IteratorType first, IteratorType last, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = detail::input_adapter(std::move(first), std::move(last)); + return format == input_format_t::json + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } + + template + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, sax_parse(ptr, ptr + len, ...)) + JSON_HEDLEY_NON_NULL(2) + static bool sax_parse(detail::span_input_adapter&& i, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = i.get(); + return format == input_format_t::json + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } +#ifndef JSON_NO_IO + /*! + @brief deserialize from stream + @deprecated This stream operator is deprecated and will be removed in + version 4.0.0 of the library. Please use + @ref operator>>(std::istream&, basic_json&) + instead; that is, replace calls like `j << i;` with `i >> j;`. + @since version 1.0.0; deprecated since version 3.0.0 + */ + JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator>>(std::istream&, basic_json&)) + friend std::istream& operator<<(basic_json& j, std::istream& i) + { + return operator>>(i, j); + } + + /*! + @brief deserialize from stream + + Deserializes an input stream to a JSON value. 
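+
+    A minimal sketch (assuming the common `json` alias and `<sstream>`); since
+    parsing is non-strict here, characters after the first complete JSON value
+    remain in the stream:
+    @code {.cpp}
+    std::istringstream input(R"({"value": 1})");
+    json j;
+    input >> j;    // j == {"value": 1}; malformed input would throw parse_error
+    @endcode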
+ + @param[in,out] i input stream to read a serialized JSON value from + @param[in,out] j JSON value to write the deserialized input to + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below shows how a JSON value is constructed by + reading a serialization from a stream.,operator_deserialize} + + @sa parse(std::istream&, const parser_callback_t) for a variant with a + parser callback function to filter values while parsing + + @since version 1.0.0 + */ + friend std::istream& operator>>(std::istream& i, basic_json& j) + { + parser(detail::input_adapter(i)).parse(false, j); + return i; + } +#endif // JSON_NO_IO + /// @} + + /////////////////////////// + // convenience functions // + /////////////////////////// + + /*! + @brief return the type as string + + Returns the type name as string to be used in error messages - usually to + indicate that a function was called on a wrong JSON type. + + @return a string representation of a the @a m_type member: + Value type | return value + ----------- | ------------- + null | `"null"` + boolean | `"boolean"` + string | `"string"` + number | `"number"` (for all number types) + object | `"object"` + array | `"array"` + binary | `"binary"` + discarded | `"discarded"` + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Constant. + + @liveexample{The following code exemplifies `type_name()` for all JSON + types.,type_name} + + @sa see @ref type() -- return the type of the JSON value + @sa see @ref operator value_t() -- return the type of the JSON value (implicit) + + @since version 1.0.0, public since 2.1.0, `const char*` and `noexcept` + since 3.0.0 + */ + JSON_HEDLEY_RETURNS_NON_NULL + const char* type_name() const noexcept + { + { + switch (m_type) + { + case value_t::null: + return "null"; + case value_t::object: + return "object"; + case value_t::array: + return "array"; + case value_t::string: + return "string"; + case value_t::boolean: + return "boolean"; + case value_t::binary: + return "binary"; + case value_t::discarded: + return "discarded"; + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + default: + return "number"; + } + } + } + + + JSON_PRIVATE_UNLESS_TESTED: + ////////////////////// + // member variables // + ////////////////////// + + /// the type of the current element + value_t m_type = value_t::null; + + /// the value of the current element + json_value m_value = {}; + +#if JSON_DIAGNOSTICS + /// a pointer to a parent value (for debugging purposes) + basic_json* m_parent = nullptr; +#endif + + ////////////////////////////////////////// + // binary serialization/deserialization // + ////////////////////////////////////////// + + /// @name binary serialization/deserialization support + /// @{ + + public: + /*! + @brief create a CBOR serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the CBOR (Concise + Binary Object Representation) serialization format. CBOR is a binary + serialization format which aims to be more compact than JSON itself, yet + more efficient to parse. 
+ + The library uses the following mapping from JSON values types to + CBOR types according to the CBOR specification (RFC 7049): + + JSON value type | value/range | CBOR type | first byte + --------------- | ------------------------------------------ | ---------------------------------- | --------------- + null | `null` | Null | 0xF6 + boolean | `true` | True | 0xF5 + boolean | `false` | False | 0xF4 + number_integer | -9223372036854775808..-2147483649 | Negative integer (8 bytes follow) | 0x3B + number_integer | -2147483648..-32769 | Negative integer (4 bytes follow) | 0x3A + number_integer | -32768..-129 | Negative integer (2 bytes follow) | 0x39 + number_integer | -128..-25 | Negative integer (1 byte follow) | 0x38 + number_integer | -24..-1 | Negative integer | 0x20..0x37 + number_integer | 0..23 | Integer | 0x00..0x17 + number_integer | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_integer | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_integer | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A + number_integer | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B + number_unsigned | 0..23 | Integer | 0x00..0x17 + number_unsigned | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_unsigned | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_unsigned | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A + number_unsigned | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B + number_float | *any value representable by a float* | Single-Precision Float | 0xFA + number_float | *any value NOT representable by a float* | Double-Precision Float | 0xFB + string | *length*: 0..23 | UTF-8 string | 0x60..0x77 + string | *length*: 23..255 | UTF-8 string (1 byte follow) | 0x78 + string | *length*: 256..65535 | UTF-8 string (2 bytes follow) | 0x79 + string | *length*: 65536..4294967295 | UTF-8 string (4 bytes follow) | 0x7A + string | *length*: 4294967296..18446744073709551615 | UTF-8 string (8 bytes follow) | 0x7B + array | *size*: 0..23 | array | 0x80..0x97 + array | *size*: 23..255 | array (1 byte follow) | 0x98 + array | *size*: 256..65535 | array (2 bytes follow) | 0x99 + array | *size*: 65536..4294967295 | array (4 bytes follow) | 0x9A + array | *size*: 4294967296..18446744073709551615 | array (8 bytes follow) | 0x9B + object | *size*: 0..23 | map | 0xA0..0xB7 + object | *size*: 23..255 | map (1 byte follow) | 0xB8 + object | *size*: 256..65535 | map (2 bytes follow) | 0xB9 + object | *size*: 65536..4294967295 | map (4 bytes follow) | 0xBA + object | *size*: 4294967296..18446744073709551615 | map (8 bytes follow) | 0xBB + binary | *size*: 0..23 | byte string | 0x40..0x57 + binary | *size*: 23..255 | byte string (1 byte follow) | 0x58 + binary | *size*: 256..65535 | byte string (2 bytes follow) | 0x59 + binary | *size*: 65536..4294967295 | byte string (4 bytes follow) | 0x5A + binary | *size*: 4294967296..18446744073709551615 | byte string (8 bytes follow) | 0x5B + + Binary values with subtype are mapped to tagged values (0xD8..0xDB) + depending on the subtype, followed by a byte string, see "binary" cells + in the table above. + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a CBOR value. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. 
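+
+    A round-trip sketch (assuming the common `json` alias and made-up content):
+    @code {.cpp}
+    json j = {{"compact", true}, {"schema", 0}};
+    std::vector<std::uint8_t> v = json::to_cbor(j);   // CBOR byte vector
+    json back = json::from_cbor(v);
+    // back == j
+    @endcode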
+ + @note The following CBOR types are not used in the conversion: + - UTF-8 strings terminated by "break" (0x7F) + - arrays terminated by "break" (0x9F) + - maps terminated by "break" (0xBF) + - byte strings terminated by "break" (0x5F) + - date/time (0xC0..0xC1) + - bignum (0xC2..0xC3) + - decimal fraction (0xC4) + - bigfloat (0xC5) + - expected conversions (0xD5..0xD7) + - simple values (0xE0..0xF3, 0xF8) + - undefined (0xF7) + - half-precision floats (0xF9) + - break (0xFF) + + @param[in] j JSON value to serialize + @return CBOR serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in CBOR format.,to_cbor} + + @sa http://cbor.io + @sa see @ref from_cbor(InputType&&, const bool, const bool, const cbor_tag_handler_t) for the + analogous deserialization + @sa see @ref to_msgpack(const basic_json&) for the related MessagePack format + @sa see @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9; compact representation of floating-point numbers + since version 3.8.0 + */ + static std::vector to_cbor(const basic_json& j) + { + std::vector result; + to_cbor(j, result); + return result; + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + /*! + @brief create a MessagePack serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the MessagePack + serialization format. MessagePack is a binary serialization format which + aims to be more compact than JSON itself, yet more efficient to parse. 
+ + The library uses the following mapping from JSON values types to + MessagePack types according to the MessagePack specification: + + JSON value type | value/range | MessagePack type | first byte + --------------- | --------------------------------- | ---------------- | ---------- + null | `null` | nil | 0xC0 + boolean | `true` | true | 0xC3 + boolean | `false` | false | 0xC2 + number_integer | -9223372036854775808..-2147483649 | int64 | 0xD3 + number_integer | -2147483648..-32769 | int32 | 0xD2 + number_integer | -32768..-129 | int16 | 0xD1 + number_integer | -128..-33 | int8 | 0xD0 + number_integer | -32..-1 | negative fixint | 0xE0..0xFF + number_integer | 0..127 | positive fixint | 0x00..0x7F + number_integer | 128..255 | uint 8 | 0xCC + number_integer | 256..65535 | uint 16 | 0xCD + number_integer | 65536..4294967295 | uint 32 | 0xCE + number_integer | 4294967296..18446744073709551615 | uint 64 | 0xCF + number_unsigned | 0..127 | positive fixint | 0x00..0x7F + number_unsigned | 128..255 | uint 8 | 0xCC + number_unsigned | 256..65535 | uint 16 | 0xCD + number_unsigned | 65536..4294967295 | uint 32 | 0xCE + number_unsigned | 4294967296..18446744073709551615 | uint 64 | 0xCF + number_float | *any value representable by a float* | float 32 | 0xCA + number_float | *any value NOT representable by a float* | float 64 | 0xCB + string | *length*: 0..31 | fixstr | 0xA0..0xBF + string | *length*: 32..255 | str 8 | 0xD9 + string | *length*: 256..65535 | str 16 | 0xDA + string | *length*: 65536..4294967295 | str 32 | 0xDB + array | *size*: 0..15 | fixarray | 0x90..0x9F + array | *size*: 16..65535 | array 16 | 0xDC + array | *size*: 65536..4294967295 | array 32 | 0xDD + object | *size*: 0..15 | fix map | 0x80..0x8F + object | *size*: 16..65535 | map 16 | 0xDE + object | *size*: 65536..4294967295 | map 32 | 0xDF + binary | *size*: 0..255 | bin 8 | 0xC4 + binary | *size*: 256..65535 | bin 16 | 0xC5 + binary | *size*: 65536..4294967295 | bin 32 | 0xC6 + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a MessagePack value. + + @note The following values can **not** be converted to a MessagePack value: + - strings with more than 4294967295 bytes + - byte strings with more than 4294967295 bytes + - arrays with more than 4294967295 elements + - objects with more than 4294967295 elements + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @param[in] j JSON value to serialize + @return MessagePack serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. 
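+
+    A minimal sketch (assuming the common `json` alias and made-up content):
+    @code {.cpp}
+    json j = {{"compact", true}, {"schema", 0}};
+    std::vector<std::uint8_t> v = json::to_msgpack(j);
+    // v.size() is typically smaller than j.dump().size()
+    json back = json::from_msgpack(v);                 // round trip
+    @endcode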
+ + @liveexample{The example shows the serialization of a JSON value to a byte + vector in MessagePack format.,to_msgpack} + + @sa http://msgpack.org + @sa see @ref from_msgpack for the analogous deserialization + @sa see @ref to_cbor(const basic_json& for the related CBOR format + @sa see @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9 + */ + static std::vector to_msgpack(const basic_json& j) + { + std::vector result; + to_msgpack(j, result); + return result; + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + /*! + @brief create a UBJSON serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the UBJSON + (Universal Binary JSON) serialization format. UBJSON aims to be more compact + than JSON itself, yet more efficient to parse. + + The library uses the following mapping from JSON values types to + UBJSON types according to the UBJSON specification: + + JSON value type | value/range | UBJSON type | marker + --------------- | --------------------------------- | ----------- | ------ + null | `null` | null | `Z` + boolean | `true` | true | `T` + boolean | `false` | false | `F` + number_integer | -9223372036854775808..-2147483649 | int64 | `L` + number_integer | -2147483648..-32769 | int32 | `l` + number_integer | -32768..-129 | int16 | `I` + number_integer | -128..127 | int8 | `i` + number_integer | 128..255 | uint8 | `U` + number_integer | 256..32767 | int16 | `I` + number_integer | 32768..2147483647 | int32 | `l` + number_integer | 2147483648..9223372036854775807 | int64 | `L` + number_unsigned | 0..127 | int8 | `i` + number_unsigned | 128..255 | uint8 | `U` + number_unsigned | 256..32767 | int16 | `I` + number_unsigned | 32768..2147483647 | int32 | `l` + number_unsigned | 2147483648..9223372036854775807 | int64 | `L` + number_unsigned | 2147483649..18446744073709551615 | high-precision | `H` + number_float | *any value* | float64 | `D` + string | *with shortest length indicator* | string | `S` + array | *see notes on optimized format* | array | `[` + object | *see notes on optimized format* | map | `{` + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a UBJSON value. + + @note The following values can **not** be converted to a UBJSON value: + - strings with more than 9223372036854775807 bytes (theoretical) + + @note The following markers are not used in the conversion: + - `Z`: no-op values are not created. + - `C`: single-byte strings are serialized with `S` markers. + + @note Any UBJSON output created @ref to_ubjson can be successfully parsed + by @ref from_ubjson. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @note The optimized formats for containers are supported: Parameter + @a use_size adds size information to the beginning of a container and + removes the closing marker. Parameter @a use_type further checks + whether all elements of a container have the same type and adds the + type marker to the beginning of the container. The @a use_type + parameter must only be used together with @a use_size = true. 
Note + that @a use_size = true alone may result in larger representations - + the benefit of this parameter is that the receiving side is + immediately informed on the number of elements of the container. + + @note If the JSON data contains the binary type, the value stored is a list + of integers, as suggested by the UBJSON documentation. In particular, + this means that serialization and the deserialization of a JSON + containing binary values into UBJSON and back will result in a + different JSON object. + + @param[in] j JSON value to serialize + @param[in] use_size whether to add size annotations to container types + @param[in] use_type whether to add type annotations to container types + (must be combined with @a use_size = true) + @return UBJSON serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in UBJSON format.,to_ubjson} + + @sa http://ubjson.org + @sa see @ref from_ubjson(InputType&&, const bool, const bool) for the + analogous deserialization + @sa see @ref to_cbor(const basic_json& for the related CBOR format + @sa see @ref to_msgpack(const basic_json&) for the related MessagePack format + + @since version 3.1.0 + */ + static std::vector to_ubjson(const basic_json& j, + const bool use_size = false, + const bool use_type = false) + { + std::vector result; + to_ubjson(j, result, use_size, use_type); + return result; + } + + static void to_ubjson(const basic_json& j, detail::output_adapter o, + const bool use_size = false, const bool use_type = false) + { + binary_writer(o).write_ubjson(j, use_size, use_type); + } + + static void to_ubjson(const basic_json& j, detail::output_adapter o, + const bool use_size = false, const bool use_type = false) + { + binary_writer(o).write_ubjson(j, use_size, use_type); + } + + + /*! + @brief Serializes the given JSON object `j` to BSON and returns a vector + containing the corresponding BSON-representation. + + BSON (Binary JSON) is a binary format in which zero or more ordered key/value pairs are + stored as a single entity (a so-called document). + + The library uses the following mapping from JSON values types to BSON types: + + JSON value type | value/range | BSON type | marker + --------------- | --------------------------------- | ----------- | ------ + null | `null` | null | 0x0A + boolean | `true`, `false` | boolean | 0x08 + number_integer | -9223372036854775808..-2147483649 | int64 | 0x12 + number_integer | -2147483648..2147483647 | int32 | 0x10 + number_integer | 2147483648..9223372036854775807 | int64 | 0x12 + number_unsigned | 0..2147483647 | int32 | 0x10 + number_unsigned | 2147483648..9223372036854775807 | int64 | 0x12 + number_unsigned | 9223372036854775808..18446744073709551615| -- | -- + number_float | *any value* | double | 0x01 + string | *any value* | string | 0x02 + array | *any value* | document | 0x04 + object | *any value* | document | 0x03 + binary | *any value* | binary | 0x05 + + @warning The mapping is **incomplete**, since only JSON-objects (and things + contained therein) can be serialized to BSON. + Also, integers larger than 9223372036854775807 cannot be serialized to BSON, + and the keys may not contain U+0000, since they are serialized a + zero-terminated c-strings. 
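+
+    A minimal sketch (assuming the common `json` alias); note that only
+    objects can be serialized at the top level:
+    @code {.cpp}
+    json j = {{"entry", 1}};
+    std::vector<std::uint8_t> v = json::to_bson(j);
+
+    // json::to_bson(json(42)) would throw type_error.317, because the
+    // top-level value is not an object
+    @endcode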
+ + @throw out_of_range.407 if `j.is_number_unsigned() && j.get() > 9223372036854775807` + @throw out_of_range.409 if a key in `j` contains a NULL (U+0000) + @throw type_error.317 if `!j.is_object()` + + @pre The input `j` is required to be an object: `j.is_object() == true`. + + @note Any BSON output created via @ref to_bson can be successfully parsed + by @ref from_bson. + + @param[in] j JSON value to serialize + @return BSON serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in BSON format.,to_bson} + + @sa http://bsonspec.org/spec.html + @sa see @ref from_bson(detail::input_adapter&&, const bool strict) for the + analogous deserialization + @sa see @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + @sa see @ref to_cbor(const basic_json&) for the related CBOR format + @sa see @ref to_msgpack(const basic_json&) for the related MessagePack format + */ + static std::vector to_bson(const basic_json& j) + { + std::vector result; + to_bson(j, result); + return result; + } + + /*! + @brief Serializes the given JSON object `j` to BSON and forwards the + corresponding BSON-representation to the given output_adapter `o`. + @param j The JSON object to convert to BSON. + @param o The output adapter that receives the binary BSON representation. + @pre The input `j` shall be an object: `j.is_object() == true` + @sa see @ref to_bson(const basic_json&) + */ + static void to_bson(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_bson(j); + } + + /*! + @copydoc to_bson(const basic_json&, detail::output_adapter) + */ + static void to_bson(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_bson(j); + } + + + /*! + @brief create a JSON value from an input in CBOR format + + Deserializes a given input @a i to a JSON value using the CBOR (Concise + Binary Object Representation) serialization format. 
+ + The library maps CBOR types to JSON value types as follows: + + CBOR type | JSON value type | first byte + ---------------------- | --------------- | ---------- + Integer | number_unsigned | 0x00..0x17 + Unsigned integer | number_unsigned | 0x18 + Unsigned integer | number_unsigned | 0x19 + Unsigned integer | number_unsigned | 0x1A + Unsigned integer | number_unsigned | 0x1B + Negative integer | number_integer | 0x20..0x37 + Negative integer | number_integer | 0x38 + Negative integer | number_integer | 0x39 + Negative integer | number_integer | 0x3A + Negative integer | number_integer | 0x3B + Byte string | binary | 0x40..0x57 + Byte string | binary | 0x58 + Byte string | binary | 0x59 + Byte string | binary | 0x5A + Byte string | binary | 0x5B + UTF-8 string | string | 0x60..0x77 + UTF-8 string | string | 0x78 + UTF-8 string | string | 0x79 + UTF-8 string | string | 0x7A + UTF-8 string | string | 0x7B + UTF-8 string | string | 0x7F + array | array | 0x80..0x97 + array | array | 0x98 + array | array | 0x99 + array | array | 0x9A + array | array | 0x9B + array | array | 0x9F + map | object | 0xA0..0xB7 + map | object | 0xB8 + map | object | 0xB9 + map | object | 0xBA + map | object | 0xBB + map | object | 0xBF + False | `false` | 0xF4 + True | `true` | 0xF5 + Null | `null` | 0xF6 + Half-Precision Float | number_float | 0xF9 + Single-Precision Float | number_float | 0xFA + Double-Precision Float | number_float | 0xFB + + @warning The mapping is **incomplete** in the sense that not all CBOR + types can be converted to a JSON value. The following CBOR types + are not supported and will yield parse errors (parse_error.112): + - date/time (0xC0..0xC1) + - bignum (0xC2..0xC3) + - decimal fraction (0xC4) + - bigfloat (0xC5) + - expected conversions (0xD5..0xD7) + - simple values (0xE0..0xF3, 0xF8) + - undefined (0xF7) + + @warning CBOR allows map keys of any type, whereas JSON only allows + strings as keys in object values. Therefore, CBOR maps with keys + other than UTF-8 strings are rejected (parse_error.113). + + @note Any CBOR output created @ref to_cbor can be successfully parsed by + @ref from_cbor. + + @param[in] i an input in CBOR format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] tag_handler how to treat CBOR tags (optional, error by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from CBOR were + used in the given input @a v or if the input is not valid CBOR + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i. 
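+
+    A minimal sketch (assuming the common `json` alias; the byte values are
+    spelled out for illustration):
+    @code {.cpp}
+    // CBOR for {"foo": 1}: map(1), text(3) "foo", unsigned(1)
+    std::vector<std::uint8_t> v = {0xA1, 0x63, 0x66, 0x6F, 0x6F, 0x01};
+    json j = json::from_cbor(v);   // j == {"foo": 1}
+
+    // a truncated input with allow_exceptions == false yields a discarded value
+    json bad = json::from_cbor(std::vector<std::uint8_t> {0xA1}, true, false);
+    // bad.is_discarded() == true
+    @endcode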
+ + @liveexample{The example shows the deserialization of a byte vector in CBOR + format to a JSON value.,from_cbor} + + @sa http://cbor.io + @sa see @ref to_cbor(const basic_json&) for the analogous serialization + @sa see @ref from_msgpack(InputType&&, const bool, const bool) for the + related MessagePack format + @sa see @ref from_ubjson(InputType&&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0; added @a allow_exceptions parameter + since 3.2.0; added @a tag_handler parameter since 3.9.0. + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_cbor(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_cbor(InputType&&, const bool, const bool, const cbor_tag_handler_t) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_cbor(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) + static basic_json from_cbor(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + return from_cbor(ptr, ptr + len, strict, allow_exceptions, tag_handler); + } + + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) + static basic_json from_cbor(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @brief create a JSON value from an input in MessagePack format + + Deserializes a given input @a i to a JSON value using the MessagePack + serialization format. 
+ + The library maps MessagePack types to JSON value types as follows: + + MessagePack type | JSON value type | first byte + ---------------- | --------------- | ---------- + positive fixint | number_unsigned | 0x00..0x7F + fixmap | object | 0x80..0x8F + fixarray | array | 0x90..0x9F + fixstr | string | 0xA0..0xBF + nil | `null` | 0xC0 + false | `false` | 0xC2 + true | `true` | 0xC3 + float 32 | number_float | 0xCA + float 64 | number_float | 0xCB + uint 8 | number_unsigned | 0xCC + uint 16 | number_unsigned | 0xCD + uint 32 | number_unsigned | 0xCE + uint 64 | number_unsigned | 0xCF + int 8 | number_integer | 0xD0 + int 16 | number_integer | 0xD1 + int 32 | number_integer | 0xD2 + int 64 | number_integer | 0xD3 + str 8 | string | 0xD9 + str 16 | string | 0xDA + str 32 | string | 0xDB + array 16 | array | 0xDC + array 32 | array | 0xDD + map 16 | object | 0xDE + map 32 | object | 0xDF + bin 8 | binary | 0xC4 + bin 16 | binary | 0xC5 + bin 32 | binary | 0xC6 + ext 8 | binary | 0xC7 + ext 16 | binary | 0xC8 + ext 32 | binary | 0xC9 + fixext 1 | binary | 0xD4 + fixext 2 | binary | 0xD5 + fixext 4 | binary | 0xD6 + fixext 8 | binary | 0xD7 + fixext 16 | binary | 0xD8 + negative fixint | number_integer | 0xE0-0xFF + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @param[in] i an input in MessagePack format convertible to an input + adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from MessagePack were + used in the given input @a i or if the input is not valid MessagePack + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in + MessagePack format to a JSON value.,from_msgpack} + + @sa http://msgpack.org + @sa see @ref to_msgpack(const basic_json&) for the analogous serialization + @sa see @ref from_cbor(InputType&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa see @ref from_ubjson(InputType&&, const bool, const bool) for + the related UBJSON format + @sa see @ref from_bson(InputType&&, const bool, const bool) for + the related BSON format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0; added @a allow_exceptions parameter + since 3.2.0 + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_msgpack(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! 
+ @copydoc from_msgpack(InputType&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_msgpack(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) + static basic_json from_msgpack(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_msgpack(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) + static basic_json from_msgpack(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + /*! + @brief create a JSON value from an input in UBJSON format + + Deserializes a given input @a i to a JSON value using the UBJSON (Universal + Binary JSON) serialization format. + + The library maps UBJSON types to JSON value types as follows: + + UBJSON type | JSON value type | marker + ----------- | --------------------------------------- | ------ + no-op | *no value, next value is read* | `N` + null | `null` | `Z` + false | `false` | `F` + true | `true` | `T` + float32 | number_float | `d` + float64 | number_float | `D` + uint8 | number_unsigned | `U` + int8 | number_integer | `i` + int16 | number_integer | `I` + int32 | number_integer | `l` + int64 | number_integer | `L` + high-precision number | number_integer, number_unsigned, or number_float - depends on number string | 'H' + string | string | `S` + char | string | `C` + array | array (optimized values are supported) | `[` + object | object (optimized values are supported) | `{` + + @note The mapping is **complete** in the sense that any UBJSON value can + be converted to a JSON value. + + @param[in] i an input in UBJSON format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if a parse error occurs + @throw parse_error.113 if a string could not be parsed successfully + + @complexity Linear in the size of the input @a i. 
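[Editor's sketch, not part of the upstream header] A minimal MessagePack round trip using the overloads above; it mirrors the CBOR example and relies on the default strict/allow_exceptions arguments.

#include <nlohmann/json.hpp>

#include <cstdint>
#include <vector>

int main()
{
    const nlohmann::json j = {{"name", "msgpack"}, {"values", {1, 2, 3}}};

    // to_msgpack/from_msgpack are the MessagePack counterparts of to_cbor/from_cbor
    const std::vector<std::uint8_t> bytes = nlohmann::json::to_msgpack(j);
    const nlohmann::json back = nlohmann::json::from_msgpack(bytes);

    return back == j ? 0 : 1;
}
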
+ + @liveexample{The example shows the deserialization of a byte vector in + UBJSON format to a JSON value.,from_ubjson} + + @sa http://ubjson.org + @sa see @ref to_ubjson(const basic_json&, const bool, const bool) for the + analogous serialization + @sa see @ref from_cbor(InputType&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa see @ref from_msgpack(InputType&&, const bool, const bool) for + the related MessagePack format + @sa see @ref from_bson(InputType&&, const bool, const bool) for + the related BSON format + + @since version 3.1.0; added @a allow_exceptions parameter since 3.2.0 + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_ubjson(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_ubjson(InputType&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_ubjson(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) + static basic_json from_ubjson(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_ubjson(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) + static basic_json from_ubjson(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + /*! + @brief Create a JSON value from an input in BSON format + + Deserializes a given input @a i to a JSON value using the BSON (Binary JSON) + serialization format. + + The library maps BSON record types to JSON value types as follows: + + BSON type | BSON marker byte | JSON value type + --------------- | ---------------- | --------------------------- + double | 0x01 | number_float + string | 0x02 | string + document | 0x03 | object + array | 0x04 | array + binary | 0x05 | binary + undefined | 0x06 | still unsupported + ObjectId | 0x07 | still unsupported + boolean | 0x08 | boolean + UTC Date-Time | 0x09 | still unsupported + null | 0x0A | null + Regular Expr. 
| 0x0B | still unsupported + DB Pointer | 0x0C | still unsupported + JavaScript Code | 0x0D | still unsupported + Symbol | 0x0E | still unsupported + JavaScript Code | 0x0F | still unsupported + int32 | 0x10 | number_integer + Timestamp | 0x11 | still unsupported + 128-bit decimal float | 0x13 | still unsupported + Max Key | 0x7F | still unsupported + Min Key | 0xFF | still unsupported + + @warning The mapping is **incomplete**. The unsupported mappings + are indicated in the table above. + + @param[in] i an input in BSON format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.114 if an unsupported BSON record type is encountered + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in + BSON format to a JSON value.,from_bson} + + @sa http://bsonspec.org/spec.html + @sa see @ref to_bson(const basic_json&) for the analogous serialization + @sa see @ref from_cbor(InputType&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa see @ref from_msgpack(InputType&&, const bool, const bool) for + the related MessagePack format + @sa see @ref from_ubjson(InputType&&, const bool, const bool) for the + related UBJSON format + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_bson(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_bson(InputType&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_bson(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) + static basic_json from_bson(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_bson(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) + static basic_json from_bson(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg) + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? 
result : basic_json(value_t::discarded); + } + /// @} + + ////////////////////////// + // JSON Pointer support // + ////////////////////////// + + /// @name JSON Pointer functions + /// @{ + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. Similar to @ref operator[](const typename + object_t::key_type&), `null` values are created in arrays and objects if + necessary. + + In particular: + - If the JSON pointer points to an object key that does not exist, it + is created an filled with a `null` value before a reference to it + is returned. + - If the JSON pointer points to an array index that does not exist, it + is created an filled with a `null` value before a reference to it + is returned. All indices between the current maximum and the given + index are also filled with `null`. + - The special value `-` is treated as a synonym for the index past the + end. + + @param[in] ptr a JSON pointer + + @return reference to the element pointed to by @a ptr + + @complexity Constant. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer} + + @since version 2.0.0 + */ + reference operator[](const json_pointer& ptr) + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. The function does not change the JSON + value; no `null` values are created. In particular, the special value + `-` yields an exception. + + @param[in] ptr JSON pointer to the desired element + + @return const reference to the element pointed to by @a ptr + + @complexity Constant. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer_const} + + @since version 2.0.0 + */ + const_reference operator[](const json_pointer& ptr) const + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a reference to the element at with specified JSON pointer @a ptr, + with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. + + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.403 if the JSON pointer describes a key of an object + which cannot be found. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. 
+ + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer} + */ + reference at(const json_pointer& ptr) + { + return ptr.get_checked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a const reference to the element at with specified JSON pointer @a + ptr, with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. + + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.403 if the JSON pointer describes a key of an object + which cannot be found. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer_const} + */ + const_reference at(const json_pointer& ptr) const + { + return ptr.get_checked(this); + } + + /*! + @brief return flattened JSON value + + The function creates a JSON object whose keys are JSON pointers (see [RFC + 6901](https://tools.ietf.org/html/rfc6901)) and whose values are all + primitive. The original JSON value can be restored using the @ref + unflatten() function. + + @return an object that maps JSON pointers to primitive values + + @note Empty objects and arrays are flattened to `null` and will not be + reconstructed correctly by the @ref unflatten() function. + + @complexity Linear in the size the JSON value. + + @liveexample{The following code shows how a JSON object is flattened to an + object whose keys consist of JSON pointers.,flatten} + + @sa see @ref unflatten() for the reverse function + + @since version 2.0.0 + */ + basic_json flatten() const + { + basic_json result(value_t::object); + json_pointer::flatten("", *this, result); + return result; + } + + /*! + @brief unflatten a previously flattened JSON value + + The function restores the arbitrary nesting of a JSON value that has been + flattened before using the @ref flatten() function. The JSON value must + meet certain constraints: + 1. The value must be an object. + 2. The keys must be JSON pointers (see + [RFC 6901](https://tools.ietf.org/html/rfc6901)) + 3. The mapped values must be primitive JSON types. + + @return the original JSON from a flattened version + + @note Empty objects and arrays are flattened by @ref flatten() to `null` + values and can not unflattened to their original type. Apart from + this example, for a JSON value `j`, the following is always true: + `j == j.flatten().unflatten()`. + + @complexity Linear in the size the JSON value. 
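[Editor's sketch, not part of the upstream header] The difference between the unchecked operator[](const json_pointer&) and the checked at(const json_pointer&) shown above, in a few lines; the pointer strings are arbitrary examples.

#include <nlohmann/json.hpp>

#include <iostream>

int main()
{
    nlohmann::json j = {{"numbers", {1, 2, 3}}, {"name", "demo"}};

    // unchecked access: missing keys/indices would be created and filled with null
    j[nlohmann::json::json_pointer("/numbers/1")] = 42;

    // checked access: throws out_of_range.403/404 if the pointer cannot be resolved
    std::cout << j.at(nlohmann::json::json_pointer("/name")) << '\n';

    return 0;
}
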
+ + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + + @liveexample{The following code shows how a flattened JSON object is + unflattened into the original nested JSON object.,unflatten} + + @sa see @ref flatten() for the reverse function + + @since version 2.0.0 + */ + basic_json unflatten() const + { + return json_pointer::unflatten(*this); + } + + /// @} + + ////////////////////////// + // JSON Patch functions // + ////////////////////////// + + /// @name JSON Patch functions + /// @{ + + /*! + @brief applies a JSON patch + + [JSON Patch](http://jsonpatch.com) defines a JSON document structure for + expressing a sequence of operations to apply to a JSON) document. With + this function, a JSON Patch is applied to the current JSON value by + executing all operations from the patch. + + @param[in] json_patch JSON patch document + @return patched document + + @note The application of a patch is atomic: Either all operations succeed + and the patched document is returned or an exception is thrown. In + any case, the original value is not changed: the patch is applied + to a copy of the value. + + @throw parse_error.104 if the JSON patch does not consist of an array of + objects + + @throw parse_error.105 if the JSON patch is malformed (e.g., mandatory + attributes are missing); example: `"operation add must have member path"` + + @throw out_of_range.401 if an array index is out of range. + + @throw out_of_range.403 if a JSON pointer inside the patch could not be + resolved successfully in the current JSON value; example: `"key baz not + found"` + + @throw out_of_range.405 if JSON pointer has no parent ("add", "remove", + "move") + + @throw other_error.501 if "test" operation was unsuccessful + + @complexity Linear in the size of the JSON value and the length of the + JSON patch. As usually only a fraction of the JSON value is affected by + the patch, the complexity can usually be neglected. 
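[Editor's sketch, not part of the upstream header] flatten()/unflatten() as documented above; note the caveat that empty objects and arrays do not survive the round trip.

#include <nlohmann/json.hpp>

#include <iostream>

int main()
{
    const nlohmann::json j = {{"a", {{"b", 1}}}, {"c", {true, false}}};

    // keys of the flattened object are JSON Pointers to primitive values,
    // e.g. {"/a/b": 1, "/c/0": true, "/c/1": false}
    const nlohmann::json flat = j.flatten();
    std::cout << flat.dump(2) << '\n';

    // the round trip holds here because no empty objects/arrays were flattened
    return flat.unflatten() == j ? 0 : 1;
}
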
+ + @liveexample{The following code shows how a JSON patch is applied to a + value.,patch} + + @sa see @ref diff -- create a JSON patch by comparing two JSON values + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + @sa [RFC 6901 (JSON Pointer)](https://tools.ietf.org/html/rfc6901) + + @since version 2.0.0 + */ + basic_json patch(const basic_json& json_patch) const + { + // make a working copy to apply the patch to + basic_json result = *this; + + // the valid JSON Patch operations + enum class patch_operations {add, remove, replace, move, copy, test, invalid}; + + const auto get_op = [](const std::string & op) + { + if (op == "add") + { + return patch_operations::add; + } + if (op == "remove") + { + return patch_operations::remove; + } + if (op == "replace") + { + return patch_operations::replace; + } + if (op == "move") + { + return patch_operations::move; + } + if (op == "copy") + { + return patch_operations::copy; + } + if (op == "test") + { + return patch_operations::test; + } + + return patch_operations::invalid; + }; + + // wrapper for "add" operation; add value at ptr + const auto operation_add = [&result](json_pointer & ptr, basic_json val) + { + // adding to the root of the target document means replacing it + if (ptr.empty()) + { + result = val; + return; + } + + // make sure the top element of the pointer exists + json_pointer top_pointer = ptr.top(); + if (top_pointer != ptr) + { + result.at(top_pointer); + } + + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.back(); + ptr.pop_back(); + basic_json& parent = result[ptr]; + + switch (parent.m_type) + { + case value_t::null: + case value_t::object: + { + // use operator[] to add value + parent[last_path] = val; + break; + } + + case value_t::array: + { + if (last_path == "-") + { + // special case: append to back + parent.push_back(val); + } + else + { + const auto idx = json_pointer::array_index(last_path); + if (JSON_HEDLEY_UNLIKELY(idx > parent.size())) + { + // avoid undefined behavior + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range", parent)); + } + + // default case: insert add offset + parent.insert(parent.begin() + static_cast(idx), val); + } + break; + } + + // if there exists a parent it cannot be primitive + case value_t::string: // LCOV_EXCL_LINE + case value_t::boolean: // LCOV_EXCL_LINE + case value_t::number_integer: // LCOV_EXCL_LINE + case value_t::number_unsigned: // LCOV_EXCL_LINE + case value_t::number_float: // LCOV_EXCL_LINE + case value_t::binary: // LCOV_EXCL_LINE + case value_t::discarded: // LCOV_EXCL_LINE + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE + } + }; + + // wrapper for "remove" operation; remove value at ptr + const auto operation_remove = [this, &result](json_pointer & ptr) + { + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.back(); + ptr.pop_back(); + basic_json& parent = result.at(ptr); + + // remove child + if (parent.is_object()) + { + // perform range check + auto it = parent.find(last_path); + if (JSON_HEDLEY_LIKELY(it != parent.end())) + { + parent.erase(it); + } + else + { + JSON_THROW(out_of_range::create(403, "key '" + last_path + "' not found", *this)); + } + } + else if (parent.is_array()) + { + // note erase performs range check + parent.erase(json_pointer::array_index(last_path)); + } + }; + + // type check: top level value must be an array + if 
(JSON_HEDLEY_UNLIKELY(!json_patch.is_array())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects", json_patch)); + } + + // iterate and apply the operations + for (const auto& val : json_patch) + { + // wrapper to get a value for an operation + const auto get_value = [&val](const std::string & op, + const std::string & member, + bool string_type) -> basic_json & + { + // find value + auto it = val.m_value.object->find(member); + + // context-sensitive error message + const auto error_msg = (op == "op") ? "operation" : "operation '" + op + "'"; + + // check if desired value is present + if (JSON_HEDLEY_UNLIKELY(it == val.m_value.object->end())) + { + // NOLINTNEXTLINE(performance-inefficient-string-concatenation) + JSON_THROW(parse_error::create(105, 0, error_msg + " must have member '" + member + "'", val)); + } + + // check if result is of type string + if (JSON_HEDLEY_UNLIKELY(string_type && !it->second.is_string())) + { + // NOLINTNEXTLINE(performance-inefficient-string-concatenation) + JSON_THROW(parse_error::create(105, 0, error_msg + " must have string member '" + member + "'", val)); + } + + // no error: return value + return it->second; + }; + + // type check: every element of the array must be an object + if (JSON_HEDLEY_UNLIKELY(!val.is_object())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects", val)); + } + + // collect mandatory members + const auto op = get_value("op", "op", true).template get(); + const auto path = get_value(op, "path", true).template get(); + json_pointer ptr(path); + + switch (get_op(op)) + { + case patch_operations::add: + { + operation_add(ptr, get_value("add", "value", false)); + break; + } + + case patch_operations::remove: + { + operation_remove(ptr); + break; + } + + case patch_operations::replace: + { + // the "path" location must exist - use at() + result.at(ptr) = get_value("replace", "value", false); + break; + } + + case patch_operations::move: + { + const auto from_path = get_value("move", "from", true).template get(); + json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + basic_json v = result.at(from_ptr); + + // The move operation is functionally identical to a + // "remove" operation on the "from" location, followed + // immediately by an "add" operation at the target + // location with the value that was just removed. + operation_remove(from_ptr); + operation_add(ptr, v); + break; + } + + case patch_operations::copy: + { + const auto from_path = get_value("copy", "from", true).template get(); + const json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + basic_json v = result.at(from_ptr); + + // The copy is functionally identical to an "add" + // operation at the target location using the value + // specified in the "from" member. 
+ operation_add(ptr, v); + break; + } + + case patch_operations::test: + { + bool success = false; + JSON_TRY + { + // check if "value" matches the one at "path" + // the "path" location must exist - use at() + success = (result.at(ptr) == get_value("test", "value", false)); + } + JSON_INTERNAL_CATCH (out_of_range&) + { + // ignore out of range errors: success remains false + } + + // throw an exception if test fails + if (JSON_HEDLEY_UNLIKELY(!success)) + { + JSON_THROW(other_error::create(501, "unsuccessful: " + val.dump(), val)); + } + + break; + } + + case patch_operations::invalid: + default: + { + // op must be "add", "remove", "replace", "move", "copy", or + // "test" + JSON_THROW(parse_error::create(105, 0, "operation value '" + op + "' is invalid", val)); + } + } + } + + return result; + } + + /*! + @brief creates a diff as a JSON patch + + Creates a [JSON Patch](http://jsonpatch.com) so that value @a source can + be changed into the value @a target by calling @ref patch function. + + @invariant For two JSON values @a source and @a target, the following code + yields always `true`: + @code {.cpp} + source.patch(diff(source, target)) == target; + @endcode + + @note Currently, only `remove`, `add`, and `replace` operations are + generated. + + @param[in] source JSON value to compare from + @param[in] target JSON value to compare against + @param[in] path helper value to create JSON pointers + + @return a JSON patch to convert the @a source to @a target + + @complexity Linear in the lengths of @a source and @a target. + + @liveexample{The following code shows how a JSON patch is created as a + diff for two JSON values.,diff} + + @sa see @ref patch -- apply a JSON patch + @sa see @ref merge_patch -- apply a JSON Merge Patch + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + + @since version 2.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json diff(const basic_json& source, const basic_json& target, + const std::string& path = "") + { + // the patch + basic_json result(value_t::array); + + // if the values are the same, return empty patch + if (source == target) + { + return result; + } + + if (source.type() != target.type()) + { + // different types: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + return result; + } + + switch (source.type()) + { + case value_t::array: + { + // first pass: traverse common elements + std::size_t i = 0; + while (i < source.size() && i < target.size()) + { + // recursive call to compare array values at index i + auto temp_diff = diff(source[i], target[i], path + "/" + std::to_string(i)); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + ++i; + } + + // i now reached the end of at least one array + // in a second pass, traverse the remaining elements + + // remove my remaining elements + const auto end_index = static_cast(result.size()); + while (i < source.size()) + { + // add operations in reverse order to avoid invalid + // indices + result.insert(result.begin() + end_index, object( + { + {"op", "remove"}, + {"path", path + "/" + std::to_string(i)} + })); + ++i; + } + + // add other remaining elements + while (i < target.size()) + { + result.push_back( + { + {"op", "add"}, + {"path", path + "/-"}, + {"value", target[i]} + }); + ++i; + } + + break; + } + + case value_t::object: + { + // first pass: traverse this object's elements + for (auto it = source.cbegin(); it != source.cend(); ++it) + { + // escape the key name to be used in a JSON patch + const 
auto path_key = path + "/" + detail::escape(it.key()); + + if (target.find(it.key()) != target.end()) + { + // recursive call to compare object values at key it + auto temp_diff = diff(it.value(), target[it.key()], path_key); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + } + else + { + // found a key that is not in o -> remove it + result.push_back(object( + { + {"op", "remove"}, {"path", path_key} + })); + } + } + + // second pass: traverse other object's elements + for (auto it = target.cbegin(); it != target.cend(); ++it) + { + if (source.find(it.key()) == source.end()) + { + // found a key that is not in this -> add it + const auto path_key = path + "/" + detail::escape(it.key()); + result.push_back( + { + {"op", "add"}, {"path", path_key}, + {"value", it.value()} + }); + } + } + + break; + } + + case value_t::null: + case value_t::string: + case value_t::boolean: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::number_float: + case value_t::binary: + case value_t::discarded: + default: + { + // both primitive type: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + break; + } + } + + return result; + } + + /// @} + + //////////////////////////////// + // JSON Merge Patch functions // + //////////////////////////////// + + /// @name JSON Merge Patch functions + /// @{ + + /*! + @brief applies a JSON Merge Patch + + The merge patch format is primarily intended for use with the HTTP PATCH + method as a means of describing a set of modifications to a target + resource's content. This function applies a merge patch to the current + JSON value. + + The function implements the following algorithm from Section 2 of + [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396): + + ``` + define MergePatch(Target, Patch): + if Patch is an Object: + if Target is not an Object: + Target = {} // Ignore the contents and set it to an empty Object + for each Name/Value pair in Patch: + if Value is null: + if Name exists in Target: + remove the Name/Value pair from Target + else: + Target[Name] = MergePatch(Target[Name], Value) + return Target + else: + return Patch + ``` + + Thereby, `Target` is the current object; that is, the patch is applied to + the current value. + + @param[in] apply_patch the patch to apply + + @complexity Linear in the lengths of @a patch. + + @liveexample{The following code shows how a JSON Merge Patch is applied to + a JSON document.,merge_patch} + + @sa see @ref patch -- apply a JSON patch + @sa [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396) + + @since version 3.0.0 + */ + void merge_patch(const basic_json& apply_patch) + { + if (apply_patch.is_object()) + { + if (!is_object()) + { + *this = object(); + } + for (auto it = apply_patch.begin(); it != apply_patch.end(); ++it) + { + if (it.value().is_null()) + { + erase(it.key()); + } + else + { + operator[](it.key()).merge_patch(it.value()); + } + } + } + else + { + *this = apply_patch; + } + } + + /// @} +}; + +/*! +@brief user-defined to_string function for JSON values + +This function implements a user-defined to_string for JSON objects. 
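[Editor's sketch, not part of the upstream header] The patch machinery above in one place: diff() creates an RFC 6902 patch that patch() applies to a copy, while merge_patch() applies an RFC 7396 merge patch in place, with null values removing keys. The sample documents are arbitrary.

#include <nlohmann/json.hpp>

int main()
{
    const nlohmann::json source = {{"name", "alice"}, {"age", 30}};
    const nlohmann::json target = {{"name", "bob"}, {"age", 30}, {"admin", true}};

    // diff() yields a JSON Patch such that source.patch(...) == target
    const nlohmann::json p = nlohmann::json::diff(source, target);
    const bool ok = (source.patch(p) == target);

    // merge_patch() modifies the value in place; null removes the key "age"
    nlohmann::json doc = source;
    doc.merge_patch({{"age", nullptr}, {"admin", false}});
    // doc is now {"admin": false, "name": "alice"}

    return ok ? 0 : 1;
}
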
+ +@param[in] j a JSON object +@return a std::string object +*/ + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +std::string to_string(const NLOHMANN_BASIC_JSON_TPL& j) +{ + return j.dump(); +} +} // namespace nlohmann + +/////////////////////// +// nonmember support // +/////////////////////// + +// specialization of std::swap, and std::hash +namespace std +{ + +/// hash value for JSON objects +template<> +struct hash +{ + /*! + @brief return a hash value for a JSON object + + @since version 1.0.0 + */ + std::size_t operator()(const nlohmann::json& j) const + { + return nlohmann::detail::hash(j); + } +}; + +/// specialization for std::less +/// @note: do not remove the space after '<', +/// see https://github.com/nlohmann/json/pull/679 +template<> +struct less<::nlohmann::detail::value_t> +{ + /*! + @brief compare two value_t enum values + @since version 3.0.0 + */ + bool operator()(nlohmann::detail::value_t lhs, + nlohmann::detail::value_t rhs) const noexcept + { + return nlohmann::detail::operator<(lhs, rhs); + } +}; + +// C++20 prohibit function specialization in the std namespace. +#ifndef JSON_HAS_CPP_20 + +/*! +@brief exchanges the values of two JSON objects + +@since version 1.0.0 +*/ +template<> +inline void swap(nlohmann::json& j1, nlohmann::json& j2) noexcept( // NOLINT(readability-inconsistent-declaration-parameter-name) + is_nothrow_move_constructible::value&& // NOLINT(misc-redundant-expression) + is_nothrow_move_assignable::value + ) +{ + j1.swap(j2); +} + +#endif + +} // namespace std + +/*! +@brief user-defined string literal for JSON values + +This operator implements a user-defined string literal for JSON objects. It +can be used by adding `"_json"` to a string literal and returns a JSON object +if no parse error occurred. + +@param[in] s a string representation of a JSON object +@param[in] n the length of string @a s +@return a JSON object + +@since version 1.0.0 +*/ +JSON_HEDLEY_NON_NULL(1) +inline nlohmann::json operator "" _json(const char* s, std::size_t n) +{ + return nlohmann::json::parse(s, s + n); +} + +/*! +@brief user-defined string literal for JSON pointer + +This operator implements a user-defined string literal for JSON Pointers. It +can be used by adding `"_json_pointer"` to a string literal and returns a JSON pointer +object if no parse error occurred. 
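[Editor's sketch, not part of the upstream header] The std::hash specialization and the "_json" user-defined literal defined above make the following possible; the raw-string JSON content is an arbitrary example.

#include <nlohmann/json.hpp>

#include <unordered_set>

int main()
{
    // "_json" parses at the call site and throws parse_error on invalid input
    const nlohmann::json j = R"({"pi": 3.141, "happy": true})"_json;

    // the hash specialization allows json values as keys in unordered containers
    std::unordered_set<nlohmann::json> seen;
    seen.insert(j);

    return seen.count(j) == 1 ? 0 : 1;
}
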
+ +@param[in] s a string representation of a JSON Pointer +@param[in] n the length of string @a s +@return a JSON pointer object + +@since version 2.0.0 +*/ +JSON_HEDLEY_NON_NULL(1) +inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +{ + return nlohmann::json::json_pointer(std::string(s, n)); +} + +// #include + + +// restore clang diagnostic settings +#if defined(__clang__) + #pragma clang diagnostic pop +#endif + +// clean up +#undef JSON_ASSERT +#undef JSON_INTERNAL_CATCH +#undef JSON_CATCH +#undef JSON_THROW +#undef JSON_TRY +#undef JSON_PRIVATE_UNLESS_TESTED +#undef JSON_HAS_CPP_11 +#undef JSON_HAS_CPP_14 +#undef JSON_HAS_CPP_17 +#undef JSON_HAS_CPP_20 +#undef NLOHMANN_BASIC_JSON_TPL_DECLARATION +#undef NLOHMANN_BASIC_JSON_TPL +#undef JSON_EXPLICIT +#undef NLOHMANN_CAN_CALL_STD_FUNC_IMPL + +// #include + + +#undef JSON_HEDLEY_ALWAYS_INLINE +#undef JSON_HEDLEY_ARM_VERSION +#undef JSON_HEDLEY_ARM_VERSION_CHECK +#undef JSON_HEDLEY_ARRAY_PARAM +#undef JSON_HEDLEY_ASSUME +#undef JSON_HEDLEY_BEGIN_C_DECLS +#undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_BUILTIN +#undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_EXTENSION +#undef JSON_HEDLEY_CLANG_HAS_FEATURE +#undef JSON_HEDLEY_CLANG_HAS_WARNING +#undef JSON_HEDLEY_COMPCERT_VERSION +#undef JSON_HEDLEY_COMPCERT_VERSION_CHECK +#undef JSON_HEDLEY_CONCAT +#undef JSON_HEDLEY_CONCAT3 +#undef JSON_HEDLEY_CONCAT3_EX +#undef JSON_HEDLEY_CONCAT_EX +#undef JSON_HEDLEY_CONST +#undef JSON_HEDLEY_CONSTEXPR +#undef JSON_HEDLEY_CONST_CAST +#undef JSON_HEDLEY_CPP_CAST +#undef JSON_HEDLEY_CRAY_VERSION +#undef JSON_HEDLEY_CRAY_VERSION_CHECK +#undef JSON_HEDLEY_C_DECL +#undef JSON_HEDLEY_DEPRECATED +#undef JSON_HEDLEY_DEPRECATED_FOR +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION +#undef JSON_HEDLEY_DIAGNOSTIC_POP +#undef JSON_HEDLEY_DIAGNOSTIC_PUSH +#undef JSON_HEDLEY_DMC_VERSION +#undef JSON_HEDLEY_DMC_VERSION_CHECK +#undef JSON_HEDLEY_EMPTY_BASES +#undef JSON_HEDLEY_EMSCRIPTEN_VERSION +#undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK +#undef JSON_HEDLEY_END_C_DECLS +#undef JSON_HEDLEY_FLAGS +#undef JSON_HEDLEY_FLAGS_CAST +#undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_BUILTIN +#undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_EXTENSION +#undef JSON_HEDLEY_GCC_HAS_FEATURE +#undef JSON_HEDLEY_GCC_HAS_WARNING +#undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK +#undef JSON_HEDLEY_GCC_VERSION +#undef JSON_HEDLEY_GCC_VERSION_CHECK +#undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_BUILTIN +#undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_EXTENSION +#undef JSON_HEDLEY_GNUC_HAS_FEATURE +#undef JSON_HEDLEY_GNUC_HAS_WARNING +#undef JSON_HEDLEY_GNUC_VERSION +#undef JSON_HEDLEY_GNUC_VERSION_CHECK +#undef JSON_HEDLEY_HAS_ATTRIBUTE +#undef JSON_HEDLEY_HAS_BUILTIN +#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS +#undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_HAS_EXTENSION +#undef JSON_HEDLEY_HAS_FEATURE +#undef JSON_HEDLEY_HAS_WARNING +#undef 
JSON_HEDLEY_IAR_VERSION +#undef JSON_HEDLEY_IAR_VERSION_CHECK +#undef JSON_HEDLEY_IBM_VERSION +#undef JSON_HEDLEY_IBM_VERSION_CHECK +#undef JSON_HEDLEY_IMPORT +#undef JSON_HEDLEY_INLINE +#undef JSON_HEDLEY_INTEL_CL_VERSION +#undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK +#undef JSON_HEDLEY_INTEL_VERSION +#undef JSON_HEDLEY_INTEL_VERSION_CHECK +#undef JSON_HEDLEY_IS_CONSTANT +#undef JSON_HEDLEY_IS_CONSTEXPR_ +#undef JSON_HEDLEY_LIKELY +#undef JSON_HEDLEY_MALLOC +#undef JSON_HEDLEY_MCST_LCC_VERSION +#undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK +#undef JSON_HEDLEY_MESSAGE +#undef JSON_HEDLEY_MSVC_VERSION +#undef JSON_HEDLEY_MSVC_VERSION_CHECK +#undef JSON_HEDLEY_NEVER_INLINE +#undef JSON_HEDLEY_NON_NULL +#undef JSON_HEDLEY_NO_ESCAPE +#undef JSON_HEDLEY_NO_RETURN +#undef JSON_HEDLEY_NO_THROW +#undef JSON_HEDLEY_NULL +#undef JSON_HEDLEY_PELLES_VERSION +#undef JSON_HEDLEY_PELLES_VERSION_CHECK +#undef JSON_HEDLEY_PGI_VERSION +#undef JSON_HEDLEY_PGI_VERSION_CHECK +#undef JSON_HEDLEY_PREDICT +#undef JSON_HEDLEY_PRINTF_FORMAT +#undef JSON_HEDLEY_PRIVATE +#undef JSON_HEDLEY_PUBLIC +#undef JSON_HEDLEY_PURE +#undef JSON_HEDLEY_REINTERPRET_CAST +#undef JSON_HEDLEY_REQUIRE +#undef JSON_HEDLEY_REQUIRE_CONSTEXPR +#undef JSON_HEDLEY_REQUIRE_MSG +#undef JSON_HEDLEY_RESTRICT +#undef JSON_HEDLEY_RETURNS_NON_NULL +#undef JSON_HEDLEY_SENTINEL +#undef JSON_HEDLEY_STATIC_ASSERT +#undef JSON_HEDLEY_STATIC_CAST +#undef JSON_HEDLEY_STRINGIFY +#undef JSON_HEDLEY_STRINGIFY_EX +#undef JSON_HEDLEY_SUNPRO_VERSION +#undef JSON_HEDLEY_SUNPRO_VERSION_CHECK +#undef JSON_HEDLEY_TINYC_VERSION +#undef JSON_HEDLEY_TINYC_VERSION_CHECK +#undef JSON_HEDLEY_TI_ARMCL_VERSION +#undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL2000_VERSION +#undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL430_VERSION +#undef JSON_HEDLEY_TI_CL430_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL6X_VERSION +#undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL7X_VERSION +#undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK +#undef JSON_HEDLEY_TI_CLPRU_VERSION +#undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK +#undef JSON_HEDLEY_TI_VERSION +#undef JSON_HEDLEY_TI_VERSION_CHECK +#undef JSON_HEDLEY_UNAVAILABLE +#undef JSON_HEDLEY_UNLIKELY +#undef JSON_HEDLEY_UNPREDICTABLE +#undef JSON_HEDLEY_UNREACHABLE +#undef JSON_HEDLEY_UNREACHABLE_RETURN +#undef JSON_HEDLEY_VERSION +#undef JSON_HEDLEY_VERSION_DECODE_MAJOR +#undef JSON_HEDLEY_VERSION_DECODE_MINOR +#undef JSON_HEDLEY_VERSION_DECODE_REVISION +#undef JSON_HEDLEY_VERSION_ENCODE +#undef JSON_HEDLEY_WARNING +#undef JSON_HEDLEY_WARN_UNUSED_RESULT +#undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG +#undef JSON_HEDLEY_FALL_THROUGH + + + +#endif // INCLUDE_NLOHMANN_JSON_HPP_ diff --git a/gui/dependencies/pcg32/pcg32.h b/gui/dependencies/pcg32/pcg32.h new file mode 100644 index 0000000000000000000000000000000000000000..9ef404f5ff4308794814907ab408d324d93ed8ae --- /dev/null +++ b/gui/dependencies/pcg32/pcg32.h @@ -0,0 +1,201 @@ +/* + * Tiny self-contained version of the PCG Random Number Generation for C++ + * put together from pieces of the much larger C/C++ codebase. + * Wenzel Jakob, February 2015 + * + * The PCG random number generator was developed by Melissa O'Neill + * + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * For additional information about the PCG random number generation scheme, + * including its license and other licensing options, visit + * + * http://www.pcg-random.org + * + * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors. + */ + +#pragma once + +#include + +#define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL +#define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL +#define PCG32_MULT 0x5851f42d4c957f2dULL + +namespace tcnn { + +/// PCG32 Pseudorandom number generator +struct pcg32 { + /// Initialize the pseudorandom number generator with default seed + TCNN_HOST_DEVICE pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {} + + /// Initialize the pseudorandom number generator with the \ref seed() function + TCNN_HOST_DEVICE pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); } + + /** + * \brief Seed the pseudorandom number generator + * + * Specified in two parts: a state initializer and a sequence selection + * constant (a.k.a. stream id) + */ + TCNN_HOST_DEVICE void seed(uint64_t initstate, uint64_t initseq = 1) { + state = 0U; + inc = (initseq << 1u) | 1u; + next_uint(); + state += initstate; + next_uint(); + } + + /// Generate a uniformly distributed unsigned 32-bit random number + TCNN_HOST_DEVICE uint32_t next_uint() { + uint64_t oldstate = state; + state = oldstate * PCG32_MULT + inc; + uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u); + uint32_t rot = (uint32_t) (oldstate >> 59u); + return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31)); + } + + /// Generate a uniformly distributed number, r, where 0 <= r < bound + TCNN_HOST_DEVICE uint32_t next_uint(uint32_t bound) { + // To avoid bias, we need to make the range of the RNG a multiple of + // bound, which we do by dropping output less than a threshold. + // A naive scheme to calculate the threshold would be to do + // + // uint32_t threshold = 0x100000000ull % bound; + // + // but 64-bit div/mod is slower than 32-bit div/mod (especially on + // 32-bit platforms). In essence, we do + // + // uint32_t threshold = (0x100000000ull-bound) % bound; + // + // because this version will calculate the same modulus, but the LHS + // value is less than 2^32. + + uint32_t threshold = (~bound+1u) % bound; + + // Uniformity guarantees that this loop will terminate. In practice, it + // should usually terminate quickly; on average (assuming all bounds are + // equally likely), 82.25% of the time, we can expect it to require just + // one iteration. In the worst case, someone passes a bound of 2^31 + 1 + // (i.e., 2147483649), which invalidates almost 50% of the range. In + // practice, bounds are typically small and only a tiny amount of the range + // is eliminated. + for (;;) { + uint32_t r = next_uint(); + if (r >= threshold) + return r % bound; + } + } + + /// Generate a single precision floating point value on the interval [0, 1) + TCNN_HOST_DEVICE float next_float() { + /* Trick from MTGP: generate an uniformly distributed + single precision number in [1,2) and subtract 1. 
*/ + union { + uint32_t u; + float f; + } x; + x.u = (next_uint() >> 9) | 0x3f800000u; + return x.f - 1.0f; + } + + /** + * \brief Generate a double precision floating point value on the interval [0, 1) + * + * \remark Since the underlying random number generator produces 32 bit output, + * only the first 32 mantissa bits will be filled (however, the resolution is still + * finer than in \ref next_float(), which only uses 23 mantissa bits) + */ + TCNN_HOST_DEVICE double next_double() { + /* Trick from MTGP: generate an uniformly distributed + double precision number in [1,2) and subtract 1. */ + union { + uint64_t u; + double d; + } x; + x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL; + return x.d - 1.0; + } + + /** + * \brief Multi-step advance function (jump-ahead, jump-back) + * + * The method used here is based on Brown, "Random Number Generation + * with Arbitrary Stride", Transactions of the American Nuclear + * Society (Nov. 1994). The algorithm is very similar to fast + * exponentiation. + * + * The default value of 2^32 ensures that the PRNG is advanced + * sufficiently far that there is (likely) no overlap with + * previously drawn random numbers, even if small advancements. + * are made inbetween. + */ + TCNN_HOST_DEVICE void advance(int64_t delta_ = (1ll<<32)) { + uint64_t + cur_mult = PCG32_MULT, + cur_plus = inc, + acc_mult = 1u, + acc_plus = 0u; + + /* Even though delta is an unsigned integer, we can pass a signed + integer to go backwards, it just goes "the long way round". */ + uint64_t delta = (uint64_t) delta_; + + while (delta > 0) { + if (delta & 1) { + acc_mult *= cur_mult; + acc_plus = acc_plus * cur_mult + cur_plus; + } + cur_plus = (cur_mult + 1) * cur_plus; + cur_mult *= cur_mult; + delta /= 2; + } + state = acc_mult * state + acc_plus; + } + + /// Compute the distance between two PCG32 pseudorandom number generators + TCNN_HOST_DEVICE int64_t operator-(const pcg32 &other) const { + uint64_t + cur_mult = PCG32_MULT, + cur_plus = inc, + cur_state = other.state, + the_bit = 1u, + distance = 0u; + + while (state != cur_state) { + if ((state & the_bit) != (cur_state & the_bit)) { + cur_state = cur_state * cur_mult + cur_plus; + distance |= the_bit; + } + + the_bit <<= 1; + cur_plus = (cur_mult + 1ULL) * cur_plus; + cur_mult *= cur_mult; + } + + return (int64_t) distance; + } + + /// Equality operator + TCNN_HOST_DEVICE bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; } + + /// Inequality operator + TCNN_HOST_DEVICE bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; } + + uint64_t state; // RNG state. All values are possible. + uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd. 
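	// -----------------------------------------------------------------------
	// Editor's note, not part of the upstream pcg32 sources: typical usage of
	// this struct looks roughly like the sketch below (values are illustrative).
	//
	//     tcnn::pcg32 rng{/*initstate=*/42u, /*initseq=*/54u};
	//     uint32_t r = rng.next_uint();     // full 32-bit draw
	//     uint32_t k = rng.next_uint(10);   // unbiased draw in [0, 10) via rejection
	//     float    u = rng.next_float();    // uniform single precision in [0, 1)
	//     rng.advance(-3);                  // step the stream back by three draws
	// -----------------------------------------------------------------------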
+}; + +} diff --git a/gui/dependencies/playne-equivalence/playne_equivalence.cuh b/gui/dependencies/playne-equivalence/playne_equivalence.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5f3c1453e3a04565dcea5433704c501e54898d45 --- /dev/null +++ b/gui/dependencies/playne-equivalence/playne_equivalence.cuh @@ -0,0 +1,159 @@ +// MIT License + +// Copyright (c) 2019 - Daniel Peter Playne + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Taken from https://github.com/DanielPlayne/playne-equivalence-algorithm + +#pragma once + +// ---------- Find the root of a chain ---------- +__device__ __inline__ unsigned int find_root(unsigned int* labels, unsigned int label) { + // Resolve Label + unsigned int next = labels[label]; + + // Follow chain + while (label != next) { + // Move to next + label = next; + next = labels[label]; + } + + // Return label + return label; +} + +// ---------- Label Reduction ---------- +__device__ __inline__ unsigned int reduction(unsigned int* g_labels, unsigned int label1, unsigned int label2) { + // Get next labels + unsigned int next1 = (label1 != label2) ? g_labels[label1] : 0; + unsigned int next2 = (label1 != label2) ? g_labels[label2] : 0; + + // Find label1 + while ((label1 != label2) && (label1 != next1)) { + // Adopt label + label1 = next1; + + // Fetch next label + next1 = g_labels[label1]; + } + + // Find label2 + while ((label1 != label2) && (label2 != next2)) { + // Adopt label + label2 = next2; + + // Fetch next label + next2 = g_labels[label2]; + } + + unsigned int label3; + // While Labels are different + while (label1 != label2) { + // Label 2 should be smallest + if (label1 < label2) { + // Swap Labels + label1 = label1 ^ label2; + label2 = label1 ^ label2; + label1 = label1 ^ label2; + } + + // AtomicMin label1 to label2 + label3 = atomicMin(&g_labels[label1], label2); + label1 = (label1 == label3) ? 
label2 : label3; + } + + // Return label1 + return label1; +} + +// ---------------------------------------- +// Device Kernels +// ---------------------------------------- + +// Initialise Kernel +__global__ void init_labels(const uint32_t cX, const uint32_t cY, const uint32_t cXY, unsigned int* g_labels, const unsigned char* g_image) { + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + + // Check Thread Range + if ((ix < cX) && (iy < cY)) { + // Fetch three image values + const unsigned char pyx = g_image[iy * cX + ix]; + + // Neighbour Connections + const bool nym1x = (iy > 0) ? (pyx == g_image[(iy - 1) * cX + ix]) : false; + const bool nyxm1 = (ix > 0) ? (pyx == g_image[iy * cX + ix - 1]) : false; + + // Label + unsigned int label; + + // Initialise Label + label = (nyxm1) ? iy * cX + ix - 1 : iy * cX + ix; + label = (nym1x) ? (iy - 1) * cX + ix : label; + + // Write to Global Memory + g_labels[iy * cX + ix] = label; + } +} + +// Resolve Kernel +__global__ void resolve_labels(const uint32_t cX, const uint32_t cY, const uint32_t cXY, unsigned int* g_labels) { + // Calculate index + const unsigned int id = ((blockIdx.y * blockDim.y) + threadIdx.y) * cX + ((blockIdx.x * blockDim.x) + threadIdx.x); + + // Check Thread Range + if (id < cXY) { + // Resolve Label + g_labels[id] = find_root(g_labels, g_labels[id]); + } +} + +// Label Reduction +__global__ void label_reduction(const uint32_t cX, const uint32_t cY, const uint32_t cXY, unsigned int* g_labels, unsigned char* g_image) { + // Calculate index + const unsigned int iy = ((blockIdx.y * blockDim.y) + threadIdx.y); + const unsigned int ix = ((blockIdx.x * blockDim.x) + threadIdx.x); + + // Check Thread Range + if ((ix < cX) && (iy < cY)) { + // Compare Image Values + const unsigned char pyx = g_image[iy * cX + ix]; + const bool nyxm1 = (ix > 0) ? (pyx == g_image[iy * cX + ix - 1]) : false; + + // If connected to neighbour + if (nyxm1) { + // Neighbouring values + const bool nym1xm1 = ((iy > 0) && (ix > 0)) ? (pyx == g_image[(iy - 1) * cX + ix - 1]) : false; + const bool nym1x = (iy > 0) ? (pyx == g_image[(iy - 1) * cX + ix]) : false; + + // Check Critical + if (nym1x && !nym1xm1) { + // Get labels + unsigned int label1 = g_labels[iy * cX + ix]; + unsigned int label2 = g_labels[iy * cX + ix - 1]; + + // Reduction + reduction(g_labels, label1, label2); + } + } + } +} diff --git a/gui/dependencies/pybind11_json/LICENSE b/gui/dependencies/pybind11_json/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0568c99d45ed0f51efc89cb64d0f358f81bffe0a --- /dev/null +++ b/gui/dependencies/pybind11_json/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019-2022, Martin Renou and pybind11_json contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/gui/dependencies/pybind11_json/pybind11_json.hpp b/gui/dependencies/pybind11_json/pybind11_json.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b51d1431cab26f6939029e5d026fbef0738c753e --- /dev/null +++ b/gui/dependencies/pybind11_json/pybind11_json.hpp @@ -0,0 +1,193 @@ +/*************************************************************************** +* Copyright (c) 2019, Martin Renou * +* * +* Distributed under the terms of the BSD 3-Clause License. * +* * +* The full license is in the file LICENSE, distributed with this software. * +****************************************************************************/ + +#ifndef PYBIND11_JSON_HPP +#define PYBIND11_JSON_HPP + +#include +#include + +namespace py = pybind11; +namespace nl = nlohmann; + +namespace pyjson +{ + inline py::object from_json(const nl::json& j) + { + if (j.is_null()) + { + return py::none(); + } + else if (j.is_boolean()) + { + return py::bool_(j.get()); + } + else if (j.is_number_integer()) + { + return py::int_(j.get()); + } + else if (j.is_number_float()) + { + return py::float_(j.get()); + } + else if (j.is_string()) + { + return py::str(j.get()); + } + else if (j.is_array()) + { + py::list obj; + for (const auto& el : j) + { + obj.append(from_json(el)); + } + return std::move(obj); + } + else // Object + { + py::dict obj; + for (nl::json::const_iterator it = j.cbegin(); it != j.cend(); ++it) + { + obj[py::str(it.key())] = from_json(it.value()); + } + return std::move(obj); + } + } + + inline nl::json to_json(const py::handle& obj) + { + if (obj.ptr() == nullptr || obj.is_none()) + { + return nullptr; + } + if (py::isinstance(obj)) + { + return obj.cast(); + } + if (py::isinstance(obj)) + { + return obj.cast(); + } + if (py::isinstance(obj)) + { + return obj.cast(); + } + if (py::isinstance(obj)) + { + py::module base64 = py::module::import("base64"); + return base64.attr("b64encode")(obj).attr("decode")("utf-8").cast(); + } + if (py::isinstance(obj)) + { + return obj.cast(); + } + if (py::isinstance(obj) || py::isinstance(obj)) + { + auto out = nl::json::array(); + for (const py::handle value : obj) + { + out.push_back(to_json(value)); + } + return out; + } + if (py::isinstance(obj)) + { + auto out = nl::json::object(); + for (const py::handle key : obj) + { + out[py::str(key).cast()] = to_json(obj[key]); + } + return out; + } + throw std::runtime_error("to_json not implemented for this type of object: " + py::repr(obj).cast()); + } +} + +// nlohmann_json serializers +namespace nlohmann +{ + #define MAKE_NLJSON_SERIALIZER_DESERIALIZER(T) \ + template <> \ + struct adl_serializer \ + { \ + inline static void to_json(json& j, const T& obj) \ + { \ + j = pyjson::to_json(obj); \ + } \ + \ + inline static T from_json(const json& j) \ 
+// nlohmann_json serializers
+namespace nlohmann
+{
+    #define MAKE_NLJSON_SERIALIZER_DESERIALIZER(T)         \
+    template <>                                            \
+    struct adl_serializer<T>                               \
+    {                                                      \
+        inline static void to_json(json& j, const T& obj)  \
+        {                                                  \
+            j = pyjson::to_json(obj);                      \
+        }                                                  \
+                                                           \
+        inline static T from_json(const json& j)           \
+        {                                                  \
+            return pyjson::from_json(j);                   \
+        }                                                  \
+    };
+
+    #define MAKE_NLJSON_SERIALIZER_ONLY(T)                 \
+    template <>                                            \
+    struct adl_serializer<T>                               \
+    {                                                      \
+        inline static void to_json(json& j, const T& obj)  \
+        {                                                  \
+            j = pyjson::to_json(obj);                      \
+        }                                                  \
+    };
+
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::object);
+
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::bool_);
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::int_);
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::float_);
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::str);
+
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::list);
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::tuple);
+    MAKE_NLJSON_SERIALIZER_DESERIALIZER(py::dict);
+
+    MAKE_NLJSON_SERIALIZER_ONLY(py::handle);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::item_accessor);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::list_accessor);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::tuple_accessor);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::sequence_accessor);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::str_attr_accessor);
+    MAKE_NLJSON_SERIALIZER_ONLY(py::detail::obj_attr_accessor);
+
+    #undef MAKE_NLJSON_SERIALIZER
+    #undef MAKE_NLJSON_SERIALIZER_ONLY
+}
+
+// pybind11 caster
+namespace pybind11
+{
+    namespace detail
+    {
+        template <> struct type_caster<nl::json>
+        {
+        public:
+            PYBIND11_TYPE_CASTER(nl::json, _("json"));
+
+            bool load(handle src, bool)
+            {
+                try {
+                    value = pyjson::to_json(src);
+                    return true;
+                }
+                catch (...)
+                {
+                    return false;
+                }
+            }
+
+            static handle cast(nl::json src, return_value_policy /* policy */, handle /* parent */)
+            {
+                object obj = pyjson::from_json(src);
+                return obj.release();
+            }
+        };
+    }
+}
+
+#endif
diff --git a/gui/dependencies/stb_image/LICENSE b/gui/dependencies/stb_image/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a77ae91f3ec0abb0028c9d9bd8d34d880b454d38
--- /dev/null
+++ b/gui/dependencies/stb_image/LICENSE
@@ -0,0 +1,37 @@
+This software is available under 2 licenses -- choose whichever you prefer.
+------------------------------------------------------------------------------
+ALTERNATIVE A - MIT License
+Copyright (c) 2017 Sean Barrett
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+------------------------------------------------------------------------------
+ALTERNATIVE B - Public Domain (www.unlicense.org)
+This is free and unencumbered software released into the public domain.
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/gui/dependencies/stb_image/stb_image.h b/gui/dependencies/stb_image/stb_image.h new file mode 100644 index 0000000000000000000000000000000000000000..5e807a0a6e7cdbfbbf48dff5f5d3f3693c2bc851 --- /dev/null +++ b/gui/dependencies/stb_image/stb_image.h @@ -0,0 +1,7987 @@ +/* stb_image - v2.28 - public domain image loader - http://nothings.org/stb + no warranty implied; use at your own risk + + Do this: + #define STB_IMAGE_IMPLEMENTATION + before you include this file in *one* C or C++ file to create the implementation. + + // i.e. it should look like this: + #include ... + #include ... + #include ... + #define STB_IMAGE_IMPLEMENTATION + #include "stb_image.h" + + You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + + + QUICK NOTES: + Primarily of interest to game developers and other people who can + avoid problematic images and only need the trivial interface + + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) + PNG 1/2/4/8/16-bit-per-channel + + TGA (not sure what subset, if a subset) + BMP non-1bpp, non-RLE + PSD (composited view only, no extra channels, 8/16 bit-per-channel) + + GIF (*comp always reports as 4-channel) + HDR (radiance rgbE format) + PIC (Softimage PIC) + PNM (PPM and PGM binary only) + + Animated GIF still needs a proper API, but here's one way to do it: + http://gist.github.com/urraka/685d9a6340b26b830d49 + + - decode from memory or through FILE (define STBI_NO_STDIO to remove code) + - decode from arbitrary I/O callbacks + - SIMD acceleration on x86/x64 (SSE2) and ARM (NEON) + + Full documentation under "DOCUMENTATION" below. + + +LICENSE + + See end of file for license information. 
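The single-file pattern described at the top of stb_image.h above (define STB_IMAGE_IMPLEMENTATION in exactly one C or C++ file, plain #include everywhere else) looks like this in practice. A minimal sketch, not part of the vendored header; "input.png" is an illustrative path:

    // image_demo.cpp -- the one translation unit that owns the implementation
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"

    #include <cstdio>

    int main()
    {
        int w = 0, h = 0, n = 0;
        // force 4 components (RGBA); n still reports what the file actually contained
        unsigned char *pixels = stbi_load("input.png", &w, &h, &n, 4);
        if (!pixels)
        {
            std::printf("load failed: %s\n", stbi_failure_reason());
            return 1;
        }
        std::printf("%d x %d, %d channels in file\n", w, h, n);
        stbi_image_free(pixels);   // just free(), per the header
        return 0;
    }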
+ +RECENT REVISION HISTORY: + + 2.28 (2023-01-29) many error fixes, security errors, just tons of stuff + 2.27 (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes + 2.26 (2020-07-13) many minor fixes + 2.25 (2020-02-02) fix warnings + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically + 2.23 (2019-08-11) fix clang static analysis warning + 2.22 (2019-03-04) gif fixes, fix warnings + 2.21 (2019-02-25) fix typo in comment + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes + 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 + RGB-format JPEG; remove white matting in PSD; + allocate large structures on the stack; + correct channel count for PNG & BMP + 2.10 (2016-01-22) avoid warning introduced in 2.09 + 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + + See end of file for full revision history. + + + ============================ Contributors ========================= + + Image formats Extensions, features + Sean Barrett (jpeg, png, bmp) Jetro Lauha (stbi_info) + Nicolas Schulz (hdr, psd) Martin "SpartanJ" Golini (stbi_info) + Jonathan Dummer (tga) James "moose2000" Brown (iPhone PNG) + Jean-Marc Lienher (gif) Ben "Disch" Wenger (io callbacks) + Tom Seddon (pic) Omar Cornut (1/2/4-bit PNG) + Thatcher Ulrich (psd) Nicolas Guillemot (vertical flip) + Ken Miller (pgm, ppm) Richard Mitton (16-bit PSD) + github:urraka (animated gif) Junggon Kim (PNM comments) + Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) + socks-the-fox (16-bit PNG) + Jeremy Sawicki (handle all ImageNet JPGs) + Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) + Arseny Kapoulkine Simon Breuss (16-bit PNM) + John-Mark Allen + Carmelo J Fdez-Aguera + + Bug & warning fixes + Marc LeBlanc David Woo Guillaume George Martins Mozeiko + Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski + Phil Jordan Dave Moore Roy Eltham + Hayaki Saito Nathan Reed Won Chun + Luke Graham Johan Duparc Nick Verigakis the Horde3D community + Thomas Ruf Ronny Chevalier github:rlyeh + Janez Zemva John Bartholomew Michal Cichon github:romigrou + Jonathan Blow Ken Hamada Tero Hanninen github:svdijk + Eugene Golushkov Laurent Gomila Cort Stratton github:snagar + Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex + Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Neil Bickford Matthew Gregan github:poppolopoppo + Julian Raschke Gregory Mullen Christian Floisand github:darealshinji + Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 + Brad Weinberger Matvey Cherevko github:mosra + Luca Sas Alexander Veselov Zack Middleton [reserved] + Ryan C. 
Gordon [reserved] [reserved] + DO NOT ADD YOUR NAME HERE + + Jacko Dirks + + To add your name to the credits, pick a random blank space in the middle and fill it. + 80% of merge conflicts on stb PRs are due to people adding their name at the end + of the credits. +*/ + +#ifndef STBI_INCLUDE_STB_IMAGE_H +#define STBI_INCLUDE_STB_IMAGE_H + +// DOCUMENTATION +// +// Limitations: +// - no 12-bit-per-channel JPEG +// - no JPEGs with arithmetic coding +// - GIF always returns *comp=4 +// +// Basic usage (see HDR discussion below for HDR usage): +// int x,y,n; +// unsigned char *data = stbi_load(filename, &x, &y, &n, 0); +// // ... process data if not NULL ... +// // ... x = width, y = height, n = # 8-bit components per pixel ... +// // ... replace '0' with '1'..'4' to force that many components per pixel +// // ... but 'n' will always be the number that it would have been if you said 0 +// stbi_image_free(data); +// +// Standard parameters: +// int *x -- outputs image width in pixels +// int *y -- outputs image height in pixels +// int *channels_in_file -- outputs # of image components in image file +// int desired_channels -- if non-zero, # of image components requested in result +// +// The return value from an image loader is an 'unsigned char *' which points +// to the pixel data, or NULL on an allocation failure or if the image is +// corrupt or invalid. The pixel data consists of *y scanlines of *x pixels, +// with each pixel consisting of N interleaved 8-bit components; the first +// pixel pointed to is top-left-most in the image. There is no padding between +// image scanlines or between pixels, regardless of format. The number of +// components N is 'desired_channels' if desired_channels is non-zero, or +// *channels_in_file otherwise. If desired_channels is non-zero, +// *channels_in_file has the number of components that _would_ have been +// output otherwise. E.g. if you set desired_channels to 4, you will always +// get RGBA output, but you can check *channels_in_file to see if it's trivially +// opaque because e.g. there were only 3 channels in the source image. +// +// An output image with N components has the following components interleaved +// in this order in each pixel: +// +// N=#comp components +// 1 grey +// 2 grey, alpha +// 3 red, green, blue +// 4 red, green, blue, alpha +// +// If image loading fails for any reason, the return value will be NULL, +// and *x, *y, *channels_in_file will be unchanged. The function +// stbi_failure_reason() can be queried for an extremely brief, end-user +// unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly +// more user-friendly ones. +// +// Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. +// +// To query the width, height and component count of an image without having to +// decode the full file, you can use the stbi_info family of functions: +// +// int x,y,n,ok; +// ok = stbi_info(filename, &x, &y, &n); +// // returns ok=1 and sets x, y, n if image is a supported format, +// // 0 otherwise. +// +// Note that stb_image pervasively uses ints in its public API for sizes, +// including sizes of memory buffers. This is now part of the API and thus +// hard to change without causing breakage. As a result, the various image +// loaders all have certain limits on image size; these differ somewhat +// by format but generally boil down to either just under 2GB or just under +// 1GB. 
When the decoded image would be larger than this, stb_image decoding +// will fail. +// +// Additionally, stb_image will reject image files that have any of their +// dimensions set to a larger value than the configurable STBI_MAX_DIMENSIONS, +// which defaults to 2**24 = 16777216 pixels. Due to the above memory limit, +// the only way to have an image with such dimensions load correctly +// is for it to have a rather extreme aspect ratio. Either way, the +// assumption here is that such larger images are likely to be malformed +// or malicious. If you do need to load an image with individual dimensions +// larger than that, and it still fits in the overall size limit, you can +// #define STBI_MAX_DIMENSIONS on your own to be something larger. +// +// =========================================================================== +// +// UNICODE: +// +// If compiling for Windows and you wish to use Unicode filenames, compile +// with +// #define STBI_WINDOWS_UTF8 +// and pass utf8-encoded filenames. Call stbi_convert_wchar_to_utf8 to convert +// Windows wchar_t filenames to utf8. +// +// =========================================================================== +// +// Philosophy +// +// stb libraries are designed with the following priorities: +// +// 1. easy to use +// 2. easy to maintain +// 3. good performance +// +// Sometimes I let "good performance" creep up in priority over "easy to maintain", +// and for best performance I may provide less-easy-to-use APIs that give higher +// performance, in addition to the easy-to-use ones. Nevertheless, it's important +// to keep in mind that from the standpoint of you, a client of this library, +// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// +// Some secondary priorities arise directly from the first two, some of which +// provide more explicit reasons why performance can't be emphasized. +// +// - Portable ("ease of use") +// - Small source code footprint ("easy to maintain") +// - No dependencies ("ease of use") +// +// =========================================================================== +// +// I/O callbacks +// +// I/O callbacks allow you to read from arbitrary sources, like packaged +// files or some other source. Data read from callbacks are processed +// through a small internal buffer (currently 128 bytes) to try to reduce +// overhead. +// +// The three functions you must define are "read" (reads some bytes of data), +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// +// =========================================================================== +// +// SIMD support +// +// The JPEG decoder will try to automatically use SIMD kernels on x86 when +// supported by the compiler. For ARM Neon support, you must explicitly +// request it. +// +// (The old do-it-yourself SIMD API is no longer supported in the current +// code.) +// +// On x86, SSE2 will automatically be used when available based on a run-time +// test; if not, the generic C versions are used as a fall-back. On ARM targets, +// the typical path is to have separate builds for NEON and non-NEON devices +// (at least this is true for iOS and Android). Therefore, the NEON support is +// toggled by a build flag: define STBI_NEON to get NEON loops. +// +// If for some reason you do not want to use any of SIMD code, or if +// you have issues compiling it, you can disable it entirely by +// defining STBI_NO_SIMD. 
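The I/O-callback interface described a few paragraphs above needs exactly three functions, read, skip, and eof, matching the stbi_io_callbacks struct declared later in this header. A minimal sketch reading from a caller-owned memory buffer; MemReader and the mem_* names are illustrative, not part of stb_image:

    #include <cstring>

    struct MemReader { const unsigned char *data; int size; int pos; };

    static int mem_read(void *user, char *out, int n)
    {
        MemReader *m = static_cast<MemReader *>(user);
        if (n > m->size - m->pos) n = m->size - m->pos;  // clamp at end of buffer
        std::memcpy(out, m->data + m->pos, n);
        m->pos += n;
        return n;                                        // bytes actually read
    }

    static void mem_skip(void *user, int n) { static_cast<MemReader *>(user)->pos += n; }

    static int mem_eof(void *user)
    {
        MemReader *m = static_cast<MemReader *>(user);
        return m->pos >= m->size;
    }

    // usage: MemReader r{buf, len, 0};
    //        stbi_io_callbacks cb{mem_read, mem_skip, mem_eof};
    //        unsigned char *px = stbi_load_from_callbacks(&cb, &r, &w, &h, &n, 0);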
+// +// =========================================================================== +// +// HDR image support (disable by defining STBI_NO_HDR) +// +// stb_image supports loading HDR images in general, and currently the Radiance +// .HDR file format specifically. You can still load any file through the existing +// interface; if you attempt to load an HDR file, it will be automatically remapped +// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; +// both of these constants can be reconfigured through this interface: +// +// stbi_hdr_to_ldr_gamma(2.2f); +// stbi_hdr_to_ldr_scale(1.0f); +// +// (note, do not use _inverse_ constants; stbi_image will invert them +// appropriately). +// +// Additionally, there is a new, parallel interface for loading files as +// (linear) floats to preserve the full dynamic range: +// +// float *data = stbi_loadf(filename, &x, &y, &n, 0); +// +// If you load LDR images through this interface, those images will +// be promoted to floating point values, run through the inverse of +// constants corresponding to the above: +// +// stbi_ldr_to_hdr_scale(1.0f); +// stbi_ldr_to_hdr_gamma(2.2f); +// +// Finally, given a filename (or an open file or memory block--see header +// file for details) containing image data, you can query for the "most +// appropriate" interface to use (that is, whether the image is HDR or +// not), using: +// +// stbi_is_hdr(char *filename); +// +// =========================================================================== +// +// iPhone PNG support: +// +// We optionally support converting iPhone-formatted PNGs (which store +// premultiplied BGRA) back to RGB, even though they're internally encoded +// differently. To enable this conversion, call +// stbi_convert_iphone_png_to_rgb(1). +// +// Call stbi_set_unpremultiply_on_load(1) as well to force a divide per +// pixel to remove any premultiplied alpha *only* if the image file explicitly +// says there's premultiplied data (currently only happens in iPhone images, +// and only if iPhone convert-to-rgb processing is on). +// +// =========================================================================== +// +// ADDITIONAL CONFIGURATION +// +// - You can suppress implementation of any of the decoders to reduce +// your code footprint by #defining one or more of the following +// symbols before creating the implementation. +// +// STBI_NO_JPEG +// STBI_NO_PNG +// STBI_NO_BMP +// STBI_NO_PSD +// STBI_NO_TGA +// STBI_NO_GIF +// STBI_NO_HDR +// STBI_NO_PIC +// STBI_NO_PNM (.ppm and .pgm) +// +// - You can request *only* certain decoders and suppress all other ones +// (this will be more forward-compatible, as addition of new decoders +// doesn't require you to disable them explicitly): +// +// STBI_ONLY_JPEG +// STBI_ONLY_PNG +// STBI_ONLY_BMP +// STBI_ONLY_PSD +// STBI_ONLY_TGA +// STBI_ONLY_GIF +// STBI_ONLY_HDR +// STBI_ONLY_PIC +// STBI_ONLY_PNM (.ppm and .pgm) +// +// - If you use STBI_NO_PNG (or _ONLY_ without PNG), and you still +// want the zlib decoder to be available, #define STBI_SUPPORT_ZLIB +// +// - If you define STBI_MAX_DIMENSIONS, stb_image will reject images greater +// than that size (in either width or height) without further processing. +// This is to let programs in the wild set an upper bound to prevent +// denial-of-service attacks on untrusted data, as one could generate a +// valid image of gigantic dimensions and force stb_image to allocate a +// huge block of memory and spend disproportionate time decoding it. 
By +// default this is set to (1 << 24), which is 16777216, but that's still +// very big. + +#ifndef STBI_NO_STDIO +#include +#endif // STBI_NO_STDIO + +#define STBI_VERSION 1 + +enum +{ + STBI_default = 0, // only used for desired_channels + + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 +}; + +#include +typedef unsigned char stbi_uc; +typedef unsigned short stbi_us; + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef STBIDEF +#ifdef STB_IMAGE_STATIC +#define STBIDEF static +#else +#define STBIDEF extern +#endif +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// PRIMARY API - works on images of any type +// + +// +// load image by filename, open file, or memory buffer +// + +typedef struct +{ + int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read + void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof) (void *user); // returns nonzero if we are at end of file/data +} stbi_io_callbacks; + +//////////////////////////////////// +// +// 8-bits-per-channel interface +// + +STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); + +#ifndef STBI_NO_STDIO +STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after image +#endif + +#ifndef STBI_NO_GIF +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +#endif + +#ifdef STBI_WINDOWS_UTF8 +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +#endif + +//////////////////////////////////// +// +// 16-bits-per-channel interface +// + +STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); + +#ifndef STBI_NO_STDIO +STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +#endif + +//////////////////////////////////// +// +// float-per-channel interface +// +#ifndef STBI_NO_LINEAR + STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); + + #ifndef STBI_NO_STDIO + STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); + #endif +#endif + +#ifndef STBI_NO_HDR + STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); + STBIDEF void stbi_hdr_to_ldr_scale(float scale); +#endif // STBI_NO_HDR + 
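A short sketch tying the linear-float interface declared above to the HDR notes earlier in the header; load_linear_rgb and "render.hdr" are illustrative names, and stbi_is_hdr() always exists but reports false when STBI_NO_HDR is defined:

    #include "stb_image.h"

    // Returns w*h*3 linear floats (free with stbi_image_free), or nullptr if the
    // file is not genuinely HDR, in which case the caller can fall back to stbi_load().
    static float *load_linear_rgb(const char *path, int *w, int *h)
    {
        int n = 0;
        if (!stbi_is_hdr(path))
            return nullptr;
        return stbi_loadf(path, w, h, &n, 3);
    }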
+#ifndef STBI_NO_LINEAR + STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); + STBIDEF void stbi_ldr_to_hdr_scale(float scale); +#endif // STBI_NO_LINEAR + +// stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr (char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); +#endif // STBI_NO_STDIO + + +// get a VERY brief reason for failure +// on most compilers (and ALL modern mainstream compilers) this is threadsafe +STBIDEF const char *stbi_failure_reason (void); + +// free the loaded image -- this is just free() +STBIDEF void stbi_image_free (void *retval_from_stbi_load); + +// get image dimensions & components without fully decoding +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit (char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); +#endif + + + +// for image formats that explicitly notate that they have premultiplied alpha, +// we just return the colors as stored in the file. set this flag to force +// unpremultiplication. results are undefined if the unpremultiply overflow. +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); + +// indicate whether we should process iphone images back to canonical format, +// or just pass them through "as-is" +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); + +// flip the image vertically, so the first pixel in the output array is the bottom left +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); + +// as above, but only applies to images loaded on the thread that calls the function +// this function is only available if your compiler supports thread-local variables; +// calling it will fail to link if your compiler doesn't +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply); +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert); +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); + +// ZLIB client - used by PNG, available for other purposes + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + + +#ifdef __cplusplus +} +#endif + +// +// +//// end header file ///////////////////////////////////////////////////// +#endif // STBI_INCLUDE_STB_IMAGE_H + +#ifdef 
STB_IMAGE_IMPLEMENTATION + +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ + || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ + || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ + || defined(STBI_ONLY_ZLIB) + #ifndef STBI_ONLY_JPEG + #define STBI_NO_JPEG + #endif + #ifndef STBI_ONLY_PNG + #define STBI_NO_PNG + #endif + #ifndef STBI_ONLY_BMP + #define STBI_NO_BMP + #endif + #ifndef STBI_ONLY_PSD + #define STBI_NO_PSD + #endif + #ifndef STBI_ONLY_TGA + #define STBI_NO_TGA + #endif + #ifndef STBI_ONLY_GIF + #define STBI_NO_GIF + #endif + #ifndef STBI_ONLY_HDR + #define STBI_NO_HDR + #endif + #ifndef STBI_ONLY_PIC + #define STBI_NO_PIC + #endif + #ifndef STBI_ONLY_PNM + #define STBI_NO_PNM + #endif +#endif + +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif + + +#include +#include // ptrdiff_t on osx +#include +#include +#include + +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) +#include // ldexp, pow +#endif + +#ifndef STBI_NO_STDIO +#include +#endif + +#ifndef STBI_ASSERT +#include +#define STBI_ASSERT(x) assert(x) +#endif + +#ifdef __cplusplus +#define STBI_EXTERN extern "C" +#else +#define STBI_EXTERN extern +#endif + + +#ifndef _MSC_VER + #ifdef __cplusplus + #define stbi_inline inline + #else + #define stbi_inline + #endif +#else + #define stbi_inline __forceinline +#endif + +#ifndef STBI_NO_THREAD_LOCALS + #if defined(__cplusplus) && __cplusplus >= 201103L + #define STBI_THREAD_LOCAL thread_local + #elif defined(__GNUC__) && __GNUC__ < 5 + #define STBI_THREAD_LOCAL __thread + #elif defined(_MSC_VER) + #define STBI_THREAD_LOCAL __declspec(thread) + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) + #define STBI_THREAD_LOCAL _Thread_local + #endif + + #ifndef STBI_THREAD_LOCAL + #if defined(__GNUC__) + #define STBI_THREAD_LOCAL __thread + #endif + #endif +#endif + +#if defined(_MSC_VER) || defined(__SYMBIAN32__) +typedef unsigned short stbi__uint16; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; +#else +#include +typedef uint16_t stbi__uint16; +typedef int16_t stbi__int16; +typedef uint32_t stbi__uint32; +typedef int32_t stbi__int32; +#endif + +// should produce compiler error if size is wrong +typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; + +#ifdef _MSC_VER +#define STBI_NOTUSED(v) (void)(v) +#else +#define STBI_NOTUSED(v) (void)sizeof(v) +#endif + +#ifdef _MSC_VER +#define STBI_HAS_LROTL +#endif + +#ifdef STBI_HAS_LROTL + #define stbi_lrot(x,y) _lrotl(x,y) +#else + #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (-(y) & 31))) +#endif + +#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +// ok +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +// ok +#else +#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." 
+#endif + +#ifndef STBI_MALLOC +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p,newsz) realloc(p,newsz) +#define STBI_FREE(p) free(p) +#endif + +#ifndef STBI_REALLOC_SIZED +#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#endif + +// x86/x64 detection +#if defined(__x86_64__) || defined(_M_X64) +#define STBI__X64_TARGET +#elif defined(__i386) || defined(_M_IX86) +#define STBI__X86_TARGET +#endif + +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +// gcc doesn't support sse2 intrinsics unless you compile with -msse2, +// which in turn means it gets to use SSE2 everywhere. This is unfortunate, +// but previous attempts to provide the SSE2 functions with runtime +// detection caused numerous issues. The way architecture extensions are +// exposed in GCC/Clang is, sadly, not really suited for one-file libs. +// New behavior: if compiled with -msse2, we use SSE2 without any +// detection; if not, we don't use it at all. +#define STBI_NO_SIMD +#endif + +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +// +// 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the +// Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. +// As a result, enabling SSE2 on 32-bit MinGW is dangerous when not +// simultaneously enabling "-mstackrealign". +// +// See https://github.com/nothings/stb/issues/81 for more information. +// +// So default to no SSE2 on 32-bit MinGW. If you've read this far and added +// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +#define STBI_NO_SIMD +#endif + +#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#define STBI_SSE2 +#include + +#ifdef _MSC_VER + +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) +{ + int info[4]; + __cpuid(info,1); + return info[3]; +} +#else +static int stbi__cpuid3(void) +{ + int res; + __asm { + mov eax,1 + cpuid + mov res,edx + } + return res; +} +#endif + +#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name + +#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) +static int stbi__sse2_available(void) +{ + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; +} +#endif + +#else // assume GCC-style if not VC++ +#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) + +#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) +static int stbi__sse2_available(void) +{ + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. 
+ return 1; +} +#endif + +#endif +#endif + +// ARM NEON +#if defined(STBI_NO_SIMD) && defined(STBI_NEON) +#undef STBI_NEON +#endif + +#ifdef STBI_NEON +#include +#ifdef _MSC_VER +#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name +#else +#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) +#endif +#endif + +#ifndef STBI_SIMD_ALIGN +#define STBI_SIMD_ALIGN(type, name) type name +#endif + +#ifndef STBI_MAX_DIMENSIONS +#define STBI_MAX_DIMENSIONS (1 << 24) +#endif + +/////////////////////////////////////////////// +// +// stbi__context struct and start_xxx functions + +// stbi__context structure is our basic context used by all images, so it +// contains all the IO context, plus some basic image information +typedef struct +{ + stbi__uint32 img_x, img_y; + int img_n, img_out_n; + + stbi_io_callbacks io; + void *io_user_data; + + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; + + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; +} stbi__context; + + +static void stbi__refill_buffer(stbi__context *s); + +// initialize a memory-decode context +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) +{ + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +} + +// initialize a callback-based context +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) +{ + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; +} + +#ifndef STBI_NO_STDIO + +static int stbi__stdio_read(void *user, char *data, int size) +{ + return (int) fread(data,1,size,(FILE*) user); +} + +static void stbi__stdio_skip(void *user, int n) +{ + int ch; + fseek((FILE*) user, n, SEEK_CUR); + ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. 
*/ + } +} + +static int stbi__stdio_eof(void *user) +{ + return feof((FILE*) user) || ferror((FILE *) user); +} + +static stbi_io_callbacks stbi__stdio_callbacks = +{ + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, +}; + +static void stbi__start_file(stbi__context *s, FILE *f) +{ + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +} + +//static void stop_file(stbi__context *s) { } + +#endif // !STBI_NO_STDIO + +static void stbi__rewind(stbi__context *s) +{ + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; +} + +enum +{ + STBI_ORDER_RGB, + STBI_ORDER_BGR +}; + +typedef struct +{ + int bits_per_channel; + int num_channels; + int channel_order; +} stbi__result_info; + +#ifndef STBI_NO_JPEG +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PNG +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); +#endif + +#ifndef STBI_NO_BMP +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_TGA +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); +#endif + +#ifndef STBI_NO_HDR +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_GIF +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PNM +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int 
stbi__pnm_is16(stbi__context *s); +#endif + +static +#ifdef STBI_THREAD_LOCAL +STBI_THREAD_LOCAL +#endif +const char *stbi__g_failure_reason; + +STBIDEF const char *stbi_failure_reason(void) +{ + return stbi__g_failure_reason; +} + +#ifndef STBI_NO_FAILURE_STRINGS +static int stbi__err(const char *str) +{ + stbi__g_failure_reason = str; + return 0; +} +#endif + +static void *stbi__malloc(size_t size) +{ + return STBI_MALLOC(size); +} + +// stb_image uses ints pervasively, including for offset calculations. +// therefore the largest decoded image size we can support with the +// current code, even on 64-bit targets, is INT_MAX. this is not a +// significant limitation for the intended use case. +// +// we do, however, need to make sure our size calculations don't +// overflow. hence a few helper functions for size calculations that +// multiply integers together, making sure that they're non-negative +// and no overflow occurs. + +// return 1 if the sum is valid, 0 on overflow. +// negative terms are considered invalid. +static int stbi__addsizes_valid(int a, int b) +{ + if (b < 0) return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; +} + +// returns 1 if the product is valid, 0 on overflow. +// negative factors are considered invalid. +static int stbi__mul2sizes_valid(int a, int b) +{ + if (a < 0 || b < 0) return 0; + if (b == 0) return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX/b; +} + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow +static int stbi__mad2sizes_valid(int a, int b, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +} +#endif + +// returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow +static int stbi__mad3sizes_valid(int a, int b, int c, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__addsizes_valid(a*b*c, add); +} + +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +} +#endif + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// mallocs with size overflow checking +static void *stbi__malloc_mad2(int a, int b, int add) +{ + if (!stbi__mad2sizes_valid(a, b, add)) return NULL; + return stbi__malloc(a*b + add); +} +#endif + +static void *stbi__malloc_mad3(int a, int b, int c, int add) +{ + if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; + return stbi__malloc(a*b*c + add); +} + +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) +{ + if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; + return stbi__malloc(a*b*c*d + add); +} +#endif + +// returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow. 
+static int stbi__addints_valid(int a, int b) +{ + if ((a >= 0) != (b >= 0)) return 1; // a and b have different signs, so no overflow + if (a < 0 && b < 0) return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a <= INT_MAX - b; +} + +// returns 1 if the product of two signed shorts is valid, 0 on overflow. +static int stbi__mul2shorts_valid(short a, short b) +{ + if (b == 0 || b == -1) return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + if ((a >= 0) == (b >= 0)) return a <= SHRT_MAX/b; // product is positive, so similar to mul2sizes_valid + if (b < 0) return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a >= SHRT_MIN / b; +} + +// stbi__err - error +// stbi__errpf - error returning pointer to float +// stbi__errpuc - error returning pointer to unsigned char + +#ifdef STBI_NO_FAILURE_STRINGS + #define stbi__err(x,y) 0 +#elif defined(STBI_FAILURE_USERMSG) + #define stbi__err(x,y) stbi__err(y) +#else + #define stbi__err(x,y) stbi__err(x) +#endif + +#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) + +STBIDEF void stbi_image_free(void *retval_from_stbi_load) +{ + STBI_FREE(retval_from_stbi_load); +} + +#ifndef STBI_NO_LINEAR +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +#endif + +#ifndef STBI_NO_HDR +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +#endif + +static int stbi__vertically_flip_on_load_global = 0; + +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#else +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; + +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; +} + +#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ + ? 
stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) +#endif // STBI_THREAD_LOCAL + +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order + ri->num_channels = 0; + + // test the formats with a very explicit header first (at least a FOURCC + // or distinctive magic number first) + #ifndef STBI_NO_PNG + if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_GIF + if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PSD + if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); + #else + STBI_NOTUSED(bpc); + #endif + #ifndef STBI_NO_PIC + if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); + #endif + + // then the formats that can end up attempting to load with just 1 or 2 + // bytes matching expectations; these are prone to false positives, so + // try them later + #ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); + #endif + + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } + #endif + + #ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s,x,y,comp,req_comp, ri); + #endif + + return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *) stbi__malloc(img_len); + if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); + if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) +{ + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h>>1); row++) { + stbi_uc *row0 = bytes + row*bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = (bytes_left < sizeof(temp)) ? 
bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } +} + +#ifndef STBI_NO_GIF +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) +{ + int slice; + int slice_size = w * h * bytes_per_pixel; + + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } +} +#endif + +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } + + // @TODO: move stbi__convert_format to here + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } + + return (unsigned char *) result; +} + +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } + + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } + + return (stbi__uint16 *) result; +} + +#if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) +{ + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? 
req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } +} +#endif + +#ifndef STBI_NO_STDIO + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +#endif + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +} +#endif + +static FILE *stbi__fopen(char const *filename, char const *mode) +{ + FILE *f; +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; + + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; +#else + f = _wfopen(wFilename, wMode); +#endif + +#elif defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != fopen_s(&f, filename, mode)) + f=0; +#else + f = fopen(filename, mode); +#endif + return f; +} + + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f,x,y,comp,req_comp); + fclose(f); + return result; +} + + +#endif //!STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return 
stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); +} + +#ifndef STBI_NO_GIF +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_mem(&s,buffer,len); + + result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); + } + + return result; +} +#endif + +#ifndef STBI_NO_LINEAR +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *data; + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data,x,y,comp,req_comp); + return hdr_data; + } + #endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); +} + +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} + +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} + +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_file(&s,f); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} +#endif // !STBI_NO_STDIO + +#endif // !STBI_NO_LINEAR + +// these is-hdr-or-not is defined independent of whether STBI_NO_LINEAR is +// defined, for API simplicity; if STBI_NO_LINEAR is defined, it always +// reports false! 
+ +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; + #endif +} + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr (char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result=0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; +} + +STBIDEF int stbi_is_hdr_from_file(FILE *f) +{ + #ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s,f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; + #else + STBI_NOTUSED(f); + return 0; + #endif +} +#endif // !STBI_NO_STDIO + +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; + #endif +} + +#ifndef STBI_NO_LINEAR +static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; + +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +#endif + +static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; + +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } + + +////////////////////////////////////////////////////////////////////////////// +// +// Common code used by all image loaders +// + +enum +{ + STBI__SCAN_load=0, + STBI__SCAN_type, + STBI__SCAN_header +}; + +static void stbi__refill_buffer(stbi__context *s) +{ + int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); + s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start+1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) +{ + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +stbi_inline static int stbi__at_eof(stbi__context *s) +{ + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) return 1; + } + + return s->img_buffer >= s->img_buffer_end; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +// nothing +#else +static void stbi__skip(stbi__context *s, int n) +{ + if (n == 0) return; // already there! 
+ if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +// nothing +#else +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) +{ + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; + + memcpy(buffer, s->img_buffer, blen); + + count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); + res = (count == (n-blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } + + if (s->img_buffer+n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static int stbi__get16be(stbi__context *s) +{ + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static stbi__uint32 stbi__get32be(stbi__context *s) +{ + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); +} +#endif + +#if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) +// nothing +#else +static int stbi__get16le(stbi__context *s) +{ + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); +} +#endif + +#ifndef STBI_NO_BMP +static stbi__uint32 stbi__get32le(stbi__context *s) +{ + stbi__uint32 z = stbi__get16le(s); + z += (stbi__uint32)stbi__get16le(s) << 16; + return z; +} +#endif + +#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +////////////////////////////////////////////////////////////////////////////// +// +// generic converter from built-in img_n to req_comp +// individual types do this automatically as much as possible (e.g. jpeg +// does all cases internally since it needs to colorspace convert anyway, +// and it never has alpha, so very few cases ). 
png can automatically +// interleave an alpha=255 channel, but falls back to this for other cases +// +// assume data buffer is malloced, so malloc a new one and free that one +// only failure mode is malloc failing + +static stbi_uc stbi__compute_y(int r, int g, int b) +{ + return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + unsigned char *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + unsigned char *src = data + j * x * img_n ; + unsigned char *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) +{ + return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + stbi__uint16 *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + stbi__uint16 *src = data + j * x * img_n ; + stbi__uint16 *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, 
so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#ifndef STBI_NO_LINEAR +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) +{ + int i,k,n; + float *output; + if (!data) return NULL; + output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); + } + } + if (n < comp) { + for (i=0; i < x*y; ++i) { + output[i*comp + n] = data[i*comp + n]/255.0f; + } + } + STBI_FREE(data); + return output; +} +#endif + +#ifndef STBI_NO_HDR +#define stbi__float2int(x) ((int) (x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) +{ + int i,k,n; + stbi_uc *output; + if (!data) return NULL; + output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + if (k < comp) { + float z = data[i*comp+k] * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + } + STBI_FREE(data); + return output; +} +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// "baseline" JPEG/JFIF decoder +// +// simple implementation +// - doesn't support delayed output of y-dimension +// - simple interface (only one output format: 8-bit interleaved RGB) +// - doesn't try to recover corrupt jpegs +// - doesn't allow partial loading, loading multiple at once +// - still fast on x86 (copying globals into locals doesn't help x86) +// - allocates lots of intermediate memory (full size of all components) +// - non-interleaved case requires this anyway +// - allows good upsampling (see next) +// high-quality +// - upsampled channels are bilinearly interpolated, even across blocks +// - quality integer IDCT derived from IJG's 
'slow' +// performance +// - fast huffman; reasonable integer IDCT +// - some SIMD kernels for common paths on targets with SSE2/NEON +// - uses a lot of intermediate memory, could cache poorly + +#ifndef STBI_NO_JPEG + +// huffman decoding acceleration +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct +{ + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' +} stbi__huffman; + +typedef struct +{ + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + +// sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + +// definition of jpeg image component + struct + { + int id; + int h,v; + int tq; + int hd,ha; + int dc_pred; + + int x,y,w2,h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + +// kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +} stbi__jpeg; + +static int stbi__build_huffman(stbi__huffman *h, int *count) +{ + int i,j,k=0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i=0; i < 16; ++i) { + for (j=0; j < count[i]; ++j) { + h->size[k++] = (stbi_uc) (i+1); + if(k >= 257) return stbi__err("bad size list","Corrupt JPEG"); + } + } + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for(j=1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16) (code++); + if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16-j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i=0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS-s); + int m = 1 << (FAST_BITS-s); + for (j=0; j < m; ++j) { + h->fast[c+j] = (stbi_uc) i; + } + } + } + return 1; +} + +// build a table that decodes both magnitude and value of small ACs in +// one go. 
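+// Each nonzero fast_ac[] entry packs three fields, mirroring the expression
+// used below: (value * 256) + (run * 16) + (huffman length + magnitude bits).
+//    bits 0..3   total bits consumed from the bit buffer
+//    bits 4..7   run of zero coefficients preceding this one
+//    bits 8..15  the decoded coefficient value as a signed 8-bit number
+// A zero entry means the combination is not accelerated and the decoder falls
+// back to stbi__jpeg_huff_decode + stbi__extend_receive.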
+static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) +{ + int i; + for (i=0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); + } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) +{ + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char) c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); +} + +// (1 << n) - 1 +static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; + +// decode a jpeg huffman value from the bitstream +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) +{ + unsigned int temp; + int c,k; + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k=FAST_BITS+1 ; ; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + if(c < 0 || c >= 256) // symbol id out of bounds! 
+      return -1;
+   STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]);
+
+   // convert the id to a symbol
+   j->code_bits -= k;
+   j->code_buffer <<= k;
+   return h->values[c];
+}
+
+// bias[n] = (-1<<n) + 1
+static const int stbi__jbias[16] = {0,-1,-3,-7,-15,-31,-63,-127,-255,-511,-1023,-2047,-4095,-8191,-16383,-32767};
+
+// combined JPEG 'receive' and JPEG 'extend', since baseline
+// always extends everything it receives.
+stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n)
+{
+   unsigned int k;
+   int sgn;
+   if (j->code_bits < n) stbi__grow_buffer_unsafe(j);
+   if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s instead of continuing
+
+   sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative)
+   k = stbi_lrot(j->code_buffer, n);
+   j->code_buffer = k & ~stbi__bmask[n];
+   k &= stbi__bmask[n];
+   j->code_bits -= n;
+   return k + (stbi__jbias[n] & (sgn - 1));
+}
+
+// get some unsigned bits
+stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n)
+{
+   unsigned int k;
+   if (j->code_bits < n) stbi__grow_buffer_unsafe(j);
+   if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s instead of continuing
+   k = stbi_lrot(j->code_buffer, n);
+   j->code_buffer = k & ~stbi__bmask[n];
+   k &= stbi__bmask[n];
+   j->code_bits -= n;
+   return k;
+}
+
+stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j)
+{
+   unsigned int k;
+   if (j->code_bits < 1) stbi__grow_buffer_unsafe(j);
+   if (j->code_bits < 1) return 0; // ran out of bits from stream, return 0s instead of continuing
+   k = j->code_buffer;
+   j->code_buffer <<= 1;
+   --j->code_bits;
+   return k & 0x80000000;
+}
+
+// given a value that's at position X in the zigzag stream,
+// where does it appear in the 8x8 matrix coded as row-major?
+static const stbi_uc stbi__jpeg_dezigzag[64+15] =
+{
+    0,  1,  8, 16,  9,  2,  3, 10,
+   17, 24, 32, 25, 18, 11,  4,  5,
+   12, 19, 26, 33, 40, 48, 41, 34,
+   27, 20, 13,  6,  7, 14, 21, 28,
+   35, 42, 49, 56, 57, 50, 43, 36,
+   29, 22, 15, 23, 30, 37, 44, 51,
+   58, 59, 52, 45, 38, 31, 39, 46,
+   53, 60, 61, 54, 47, 55, 62, 63,
+   // let corrupt input sample past end
+   63, 63, 63, 63, 63, 63, 63, 63,
+   63, 63, 63, 63, 63, 63, 63
+};
+
+// decode one 64-entry block--
+static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant)
+{
+   int diff,dc,k;
+   int t;
+
+   if (j->code_bits < 16) stbi__grow_buffer_unsafe(j);
+   t = stbi__jpeg_huff_decode(j, hdc);
+   if (t < 0 || t > 15) return stbi__err("bad huffman code","Corrupt JPEG");
+
+   // 0 all the ac values now so we can do it 32-bits at a time
+   memset(data,0,64*sizeof(data[0]));
+
+   diff = t ?
stbi__extend_receive(j, t) : 0; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta","Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, dequant[0])) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) break; // end block + k += 16; + } else { + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); + } + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) +{ + int diff,dc; + int t; + if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data,0,64*sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? 
stbi__extend_receive(j, t) : 0; + + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * (1 << j->succ_low)); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short) (1 << j->succ_low); + } + return 1; +} + +// @OPTIMIZE: store non-zigzagged during the decode passes, +// and only de-zigzag when dequantizing +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) +{ + int k; + if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } + + k = j->spec_start; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * (1 << shift)); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * (1 << shift)); + } + } + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short) (1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { + k = j->spec_start; + do { + int r,s; + int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } + + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short) s; + break; + } + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; +} + +// take a -128..127 value and stbi__clamp it and convert to 0..255 +stbi_inline static stbi_uc stbi__clamp(int x) +{ + // trick to use a single test to catch both cases + if ((unsigned int) x > 255) { 
+ if (x < 0) return 0; + if (x > 255) return 255; + } + return (stbi_uc) x; +} + +#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) + +// derived from jidctint -- DCT_ISLOW +#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ + int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2+p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3*stbi__f2f(-1.847759065f); \ + t3 = p1 + p2*stbi__f2f( 0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2+p3); \ + t1 = stbi__fsh(p2-p3); \ + x0 = t0+t3; \ + x3 = t0-t3; \ + x1 = t1+t2; \ + x2 = t1-t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0+t2; \ + p4 = t1+t3; \ + p1 = t0+t3; \ + p2 = t1+t2; \ + p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ + t0 = t0*stbi__f2f( 0.298631336f); \ + t1 = t1*stbi__f2f( 2.053119869f); \ + t2 = t2*stbi__f2f( 3.072711026f); \ + t3 = t3*stbi__f2f( 1.501321110f); \ + p1 = p5 + p1*stbi__f2f(-0.899976223f); \ + p2 = p5 + p2*stbi__f2f(-2.562915447f); \ + p3 = p3*stbi__f2f(-1.961570560f); \ + p4 = p4*stbi__f2f(-0.390180644f); \ + t3 += p1+p4; \ + t2 += p2+p3; \ + t1 += p2+p4; \ + t0 += p1+p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) +{ + int i,val[64],*v=val; + stbi_uc *o; + short *d = data; + + // columns + for (i=0; i < 8; ++i,++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 + && d[40]==0 && d[48]==0 && d[56]==0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0]*4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; x1 += 512; x2 += 512; x3 += 512; + v[ 0] = (x0+t3) >> 10; + v[56] = (x0-t3) >> 10; + v[ 8] = (x1+t2) >> 10; + v[48] = (x1-t2) >> 10; + v[16] = (x2+t1) >> 10; + v[40] = (x2-t1) >> 10; + v[24] = (x3+t0) >> 10; + v[32] = (x3-t0) >> 10; + } + } + + for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128<<17); + x1 += 65536 + (128<<17); + x2 += 65536 + (128<<17); + x3 += 65536 + (128<<17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0+t3) >> 17); + o[7] = stbi__clamp((x0-t3) >> 17); + o[1] = stbi__clamp((x1+t2) >> 17); + o[6] = stbi__clamp((x1-t2) >> 17); + o[2] = stbi__clamp((x2+t1) >> 17); + o[5] = stbi__clamp((x2-t1) >> 17); + o[3] = stbi__clamp((x3+t0) >> 17); + o[4] = stbi__clamp((x3-t0) >> 17); + } +} + +#ifdef STBI_SSE2 +// sse2 integer IDCT. not the fastest possible implementation but it +// produces bit-identical results to the generic C version so it's +// fully "transparent". 
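+//
+// Note on scaling (applies to the C path above and both SIMD paths): the
+// stbi__f2f(x) constants are x in 12-bit fixed point, i.e. round(x * 4096),
+// so for example stbi__f2f(0.5411961f) is 2217. The column pass shifts right
+// by 10 to keep two extra bits of precision, and the row pass shifts right by
+// 17, folding in the +128 offset that recenters the signed result into 0..255.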
+static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + + // dot product constant: even elems=x, odd elems=y + #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) + + // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) + // out(1) = c1[even]*x + c1[odd]*y + #define dct_rot(out0,out1, x,y,c0,c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + + // out = in << 12 (in 16-bit, out 32-bit) + #define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + + // wide add + #define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + + // wide sub + #define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + + // butterfly a/b, add bias, then shift by "s" and pack + #define dct_bfly32o(out0, out1, a,b,bias,s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + + // 8-bit interleave step (for transposes) + #define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + + // 16-bit interleave step (for transposes) + #define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + + #define dct_pass(bias,shift) \ + { \ + /* even part */ \ + dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ + dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0,row7, x0,x7,bias,shift); \ + dct_bfly32o(row1,row6, x1,x6,bias,shift); \ + dct_bfly32o(row2,row5, x2,x5,bias,shift); \ + dct_bfly32o(row3,row4, x3,x4,bias,shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); + __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), 
stbi__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); + + // load + row0 = _mm_load_si128((const __m128i *) (data + 0*8)); + row1 = _mm_load_si128((const __m128i *) (data + 1*8)); + row2 = _mm_load_si128((const __m128i *) (data + 2*8)); + row3 = _mm_load_si128((const __m128i *) (data + 3*8)); + row4 = _mm_load_si128((const __m128i *) (data + 4*8)); + row5 = _mm_load_si128((const __m128i *) (data + 5*8)); + row6 = _mm_load_si128((const __m128i *) (data + 6*8)); + row7 = _mm_load_si128((const __m128i *) (data + 7*8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *) out, p0); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p2); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p1); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p3); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); + } + +#undef dct_const +#undef dct_rot +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_interleave8 +#undef dct_interleave16 +#undef dct_pass +} + +#endif // STBI_SSE2 + +#ifdef STBI_NEON + +// NEON integer IDCT. should produce bit-identical +// results to the generic C version. 
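+//
+// Unlike the SSE2 path, rounding here comes from the rounding narrowing
+// shifts (vrshrn_n_s32 and vqrshrun_n_s16), and the +128 output offset is
+// effectively folded in by the "add DC bias" step below: adding 1024 to the
+// DC coefficient works out to +128 on every output pixel once the 10-bit and
+// 17-bit shifts of the two passes are applied.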
+static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) + +// wide add +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) + +// wide sub +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) + +// butterfly a/b, then shift using "shiftop" by "s" and pack +#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ + dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ + dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ + dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ + } + + // load + row0 = vld1q_s16(data + 0*8); + row1 = vld1q_s16(data + 1*8); + row2 = vld1q_s16(data + 2*8); + row3 = vld1q_s16(data + 3*8); + row4 = vld1q_s16(data + 4*8); + row5 = 
vld1q_s16(data + 5*8); + row6 = vld1q_s16(data + 6*8); + row7 = vld1q_s16(data + 7*8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { +// these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. +// whether compilers actually get this is another story, sadly. +#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } +#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); + +#undef dct_trn16 +#undef dct_trn32 +#undef dct_trn64 + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } +#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); out += out_stride; + vst1_u8(out, p1); out += out_stride; + vst1_u8(out, p2); out += out_stride; + vst1_u8(out, p3); out += out_stride; + vst1_u8(out, p4); out += out_stride; + vst1_u8(out, p5); out += out_stride; + vst1_u8(out, p6); out += out_stride; + vst1_u8(out, p7); + +#undef dct_trn8_8 +#undef dct_trn8_16 +#undef dct_trn8_32 + } + +#undef dct_long_mul +#undef dct_long_mac +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_pass +} + +#endif // STBI_NEON + +#define STBI__MARKER_none 0xff +// if there's a pending marker from the entropy stream, return that +// otherwise, fetch from the stream and get a marker. 
if there's no +// marker, return 0xff, which is never a valid marker value +static stbi_uc stbi__get_marker(stbi__jpeg *j) +{ + stbi_uc x; + if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } + x = stbi__get8(j->s); + if (x != 0xff) return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; +} + +// in each scan, we'll have scan_n components, and the order +// of the components is specified by order[] +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) + +// after a restart interval, stbi__jpeg_reset the entropy decoder and +// the dc prediction +static void stbi__jpeg_reset(stbi__jpeg *j) +{ + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) +{ + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i,j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } else { // interleaved + int i,j,k,x,y; + STBI_SIMD_ALIGN(short, data[64]); + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... 
process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x)*8; + int y2 = (j*z->img_comp[n].v + y)*8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } else { + if (z->scan_n == 1) { + int i,j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } else { // interleaved + int i,j,k,x,y; + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... 
process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x); + int y2 = (j*z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) +{ + int i; + for (i=0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) +{ + if (z->progressive) { + // dequantize and idct the data + int i,j,n; + for (n=0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + } + } + } + } +} + +static int stbi__process_marker(stbi__jpeg *z, int m) +{ + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker","Corrupt JPEG"); + + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; + + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s)-2; + while (L > 0) { + int q = stbi__get8(z->s); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15,i; + if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); + if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); + + for (i=0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L==0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s)-2; + while (L > 0) { + stbi_uc *v; + int sizes[16],i,n=0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); + for (i=0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + if(n > 256) return stbi__err("bad DHT header","Corrupt JPEG"); // Loop over i < n would write past end of values! 
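+ // 17 bytes consumed so far for this table: one class/id byte plus the
+ // 16 code-length counts; the n symbol values are subtracted below (L -= n)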
+ L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; + v = z->huff_dc[th].values; + } else { + if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; + v = z->huff_ac[th].values; + } + for (i=0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L==0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len","Corrupt JPEG"); + else + return stbi__err("bad APP len","Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J','F','I','F','\0'}; + int ok = 1; + int i; + for (i=0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; + int ok = 1; + int i; + for (i=0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; + } + } + + stbi__skip(z->s, L); + return 1; + } + + return stbi__err("unknown marker","Corrupt JPEG"); +} + +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) +{ + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); + if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); + for (i=0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) return 0; // no match + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; +} + +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) +{ + int i; + for (i=0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) +{ + stbi__context *s = z->s; + int Lf,p,i,q, h_max=1,v_max=1,c; + Lf = stbi__get16be(s); if (Lf < 11) 
return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG + p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); + s->img_n = c; + for (i=0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); + + z->rgb = 0; + for (i=0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = { 'R', 'G', 'B' }; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); + + for (i=0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; + } + + // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios + // and I've never seen a non-corrupted JPEG file actually use them + for (i=0; i < s->img_n; ++i) { + if (h_max % z->img_comp[i].h != 0) return stbi__err("bad H","Corrupt JPEG"); + if (v_max % z->img_comp[i].v != 0) return stbi__err("bad V","Corrupt JPEG"); + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; + + for (i=0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. 
a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); + } + } + + return 1; +} + +// use comparisons since in some cases we handle more than one case (e.g. SOF) +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) +{ + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); + if (scan == STBI__SCAN_type) return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z,m)) return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); + m = stbi__get_marker(z); + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) return 0; + return 1; +} + +static int stbi__skip_jpeg_junk_at_end(stbi__jpeg *j) +{ + // some JPEGs have junk at end, skip over it but if we find what looks + // like a valid marker, resume there + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + while (x == 255) { // might be a marker + if (stbi__at_eof(j->s)) return STBI__MARKER_none; + x = stbi__get8(j->s); + if (x != 0x00 && x != 0xff) { + // not a stuffed zero or lead-in to another marker, looks + // like an actual marker, return it + return x; + } + // stuffed zero has x=0 now which ends the loop, meaning we go + // back to regular scan loop. + // repeated 0xff keeps trying to read the next byte of the marker. 
+ } + } + return STBI__MARKER_none; +} + +// decode image to YCbCr format +static int stbi__decode_jpeg_image(stbi__jpeg *j) +{ + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) return 0; + if (!stbi__parse_entropy_coded_data(j)) return 0; + if (j->marker == STBI__MARKER_none ) { + j->marker = stbi__skip_jpeg_junk_at_end(j); + // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 + } + m = stbi__get_marker(j); + if (STBI__RESTART(m)) + m = stbi__get_marker(j); + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); + m = stbi__get_marker(j); + } else { + if (!stbi__process_marker(j, m)) return 1; + m = stbi__get_marker(j); + } + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; +} + +// static jfif-centered resampling (across block boundaries) + +typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, + int w, int hs); + +#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i=0; i < w; ++i) + out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0]*3 + input[1] + 2); + for (i=1; i < w-1; ++i) { + int n = 3*input[i]+2; + out[i*2+0] = stbi__div4(n+input[i-1]); + out[i*2+1] = stbi__div4(n+input[i+1]); + } + out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); + out[i*2+1] = input[w-1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i,t0,t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3*in_near[0] + in_far[0]; + out[0] = stbi__div4(t1+2); + for (i=1; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); + + STBI_NOTUSED(hs); + + return out; +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i=0,t0,t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 
2); + return out; + } + + t1 = 3*in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w-1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *) (out + i*2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. 
+ int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i*2, o); +#endif + + // "previous" value for next iter + t1 = 3*in_near[i+7] + in_far[i+7]; + } + + t0 = t1; + t1 = 3*in_near[i] + in_far[i]; + out[i*2] = stbi__div16(3*t1 + t0 + 8); + + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); + + STBI_NOTUSED(hs); + + return out; +} +#endif + +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // resample with nearest-neighbor + int i,j; + STBI_NOTUSED(in_far); + for (i=0; i < w; ++i) + for (j=0; j < hs; ++j) + out[i*hs+j] = in_near[i]; + return out; +} + +// this is a reduced-precision calculation of YCbCr-to-RGB introduced +// to make sure the code produces the same results in both SIMD and scalar +#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) +{ + int i; + for (i=0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) +{ + int i = 0; + +#ifdef STBI_SSE2 + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
+ __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); + __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); + __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); + __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); + __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i+7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *) (out + 0), o0); + _mm_storeu_si128((__m128i *) (out + 16), o1); + out += 32; + } + } +#endif + +#ifdef STBI_NEON + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
+ uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); + int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); + int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); + int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); + + for (; i+7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8*4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} +#endif + +// set up the kernels +static void stbi__setup_jpeg(stbi__jpeg *j) +{ + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; + +#ifdef STBI_SSE2 + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } +#endif + +#ifdef STBI_NEON + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; +#endif +} + +// clean up the temporary component buffers +static void stbi__cleanup_jpeg(stbi__jpeg *j) +{ + stbi__free_jpeg_components(j, j->s->img_n, 0); +} + +typedef struct +{ + resample_row_func resample; + stbi_uc *line0,*line1; + int hs,vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on +} stbi__resample; + +// fast 0..255 * 0..255 => 0..255 rounded multiplication +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) +{ + unsigned int t = x*y + 128; + return (stbi_uc) ((t + (t >>8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) +{ + int n, decode_n, is_rgb; + z->s->img_n = 0; // make 
stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // nothing to do if no components requested; check this now to avoid + // accessing uninitialized coutput[0] later + if (decode_n <= 0) { stbi__cleanup_jpeg(z); return NULL; } + + // resample and color-convert + { + int k; + unsigned int i,j; + stbi_uc *output; + stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; + + stbi__resample res_comp[4]; + + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs-1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; + else r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + // now go ahead and resample + for (j=0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, + y_bot ? r->line1 : r->line0, + y_bot ? 
r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; + } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else + for (i=0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i=0; i < z->s->img_x; ++i) + *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i=0; i < z->s->img_x; ++i, out += 2) { + out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; + else + for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } + } + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) *comp = z->s->img_n >= 3 ? 
3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + unsigned char* result; + stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__errpuc("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x,y,comp,req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) +{ + int r; + stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) +{ + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind( j->s ); + return 0; + } + if (x) *x = j->s->img_x; + if (y) *y = j->s->img_y; + if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) +{ + int result; + stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; +} +#endif + +// public domain zlib decode v0.2 Sean Barrett 2006-11-18 +// simple implementation +// - all input must be provided in an upfront buffer +// - all output is written to a single output buffer (can malloc/realloc) +// performance +// - fast huffman + +#ifndef STBI_NO_ZLIB + +// fast-way is faster to check than jpeg huffman, but slow way is slower +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet + +// zlib-style huffman encoding +// (jpegs packs from left, zlib from right, so can't share code) +typedef struct +{ + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[STBI__ZNSYMS]; + stbi__uint16 value[STBI__ZNSYMS]; +} stbi__zhuffman; + +stbi_inline static int stbi__bitreverse16(int n) +{ + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); + return n; +} + +stbi_inline static int stbi__bit_reverse(int v, int bits) +{ + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 
11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16-bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) +{ + int i,k=0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i=0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i=1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i=1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16) code; + z->firstsymbol[i] = (stbi__uint16) k; + code = (code + sizes[i]); + if (sizes[i]) + if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); + z->maxcode[i] = code << (16-i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i=0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); + z->size [c] = (stbi_uc ) s; + z->value[c] = (stbi__uint16) i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s],s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } + } + ++next_code[s]; + } + } + return 1; +} + +// zlib-from-memory implementation for PNG reading +// because PNG allows splitting the zlib stream arbitrarily, +// and it's annoying structurally to have PNG call ZLIB call PNG, +// we require PNG read all the IDATs and combine them into a single +// memory buffer + +typedef struct +{ + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; + + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; + + stbi__zhuffman z_length, z_distance; +} stbi__zbuf; + +stbi_inline static int stbi__zeof(stbi__zbuf *z) +{ + return (z->zbuffer >= z->zbuffer_end); +} + +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) +{ + return stbi__zeof(z) ? 0 : *z->zbuffer++; +} + +static void stbi__fill_bits(stbi__zbuf *z) +{ + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) +{ + unsigned int k; + if (z->num_bits < n) stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s,k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s=STBI__ZFAST_BITS+1; ; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) return -1; // invalid code! + // code size is s, so: + b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= STBI__ZNSYMS) return -1; // some data was corrupt somewhere! + if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. 
*/ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); + cur = (unsigned int) (z->zout - z->zout_start); + limit = old_limit = (unsigned) (z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; +} + +static const int stbi__zlength_base[31] = { + 3,4,5,6,7,8,9,10,11,13, + 15,17,19,23,27,31,35,43,51,59, + 67,83,99,115,131,163,195,227,258,0,0 }; + +static const int stbi__zlength_extra[31]= +{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; + +static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, +257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; + +static const int stbi__zdist_extra[32] = +{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) +{ + char *zout = a->zout; + for(;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) return 0; + zout = a->zout; + } + *zout++ = (char) z; + } else { + stbi_uc *p; + int len,dist; + if (z == 256) { + a->zout = zout; + return 1; + } + if (z >= 286) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0 || z >= 30) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) return 0; + zout = a->zout; + } + p = (stbi_uc *) (zout - dist); + if (dist == 1) { // run of one byte; common in images. 
+ stbi_uc v = *p; + if (len) { do *zout++ = v; while (--len); } + } else { + if (len) { do *zout++ = *p++; while (--len); } + } + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) +{ + static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286+32+137];//padding for maximum single op + stbi_uc codelength_sizes[19]; + int i,n; + + int hlit = stbi__zreceive(a,5) + 257; + int hdist = stbi__zreceive(a,5) + 1; + int hclen = stbi__zreceive(a,4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i=0; i < hclen; ++i) { + int s = stbi__zreceive(a,3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc) c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a,2)+3; + if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n-1]; + } else if (c == 17) { + c = stbi__zreceive(a,3)+3; + } else if (c == 18) { + c = stbi__zreceive(a,7)+11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); + } + if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes+n, fill, c); + n += c; + } + } + if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) +{ + stbi_uc header[4]; + int len,nlen,k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) +{ + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... 
but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = +{ + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 +}; +static const stbi_uc stbi__zdefault_distance[32] = +{ + 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 +}; +/* +Init algorithm: +{ + int i; // use <= to match clearly with spec + for (i=0; i <= 143; ++i) stbi__zdefault_length[i] = 8; + for ( ; i <= 255; ++i) stbi__zdefault_length[i] = 9; + for ( ; i <= 279; ++i) stbi__zdefault_length[i] = 7; + for ( ; i <= 287; ++i) stbi__zdefault_length[i] = 8; + + for (i=0; i <= 31; ++i) stbi__zdefault_distance[i] = 5; +} +*/ + +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) +{ + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a,1); + type = stbi__zreceive(a,2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , STBI__ZNSYMS)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; + } else { + if (!stbi__compute_huffman_codes(a)) return 0; + } + if (!stbi__parse_huffman_block(a)) return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) +{ + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) +{ + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + 
ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int) (a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(16384); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer+len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int) (a.zout - a.zout_start); + else + return -1; +} +#endif + +// public domain "baseline" PNG decoder v0.10 Sean Barrett 2006-11-18 +// simple implementation +// - only 8-bit samples +// - no CRC checking +// - allocates lots of intermediate memory +// - avoids problem of streaming data between subsystems +// - avoids explicit window management +// performance +// - uses stb_zlib, a PD zlib implementation with fast huffman decoding + +#ifndef STBI_NO_PNG +typedef struct +{ + stbi__uint32 length; + stbi__uint32 type; +} stbi__pngchunk; + +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) +{ + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; +} + +static int stbi__check_png_header(stbi__context *s) +{ + static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; + int i; + for (i=0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); + return 1; +} + +typedef struct +{ + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; +} stbi__png; + + +enum { + STBI__F_none=0, + STBI__F_sub=1, + STBI__F_up=2, + STBI__F_avg=3, + STBI__F_paeth=4, + // synthetic filters used for first scanline to avoid needing a dummy row of 0s + STBI__F_avg_first, + STBI__F_paeth_first +}; + +static stbi_uc first_row_filter[5] = +{ + STBI__F_none, + STBI__F_sub, + STBI__F_none, + STBI__F_avg_first, + STBI__F_paeth_first +}; + +static int stbi__paeth(int a, int b, int c) +{ + int p = a + b - c; + int pa = abs(p-a); + int pb = abs(p-b); + int pc = abs(p-c); + if (pa <= pb && pa <= pc) return a; + if (pb <= pc) return b; + return c; +} + +static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; + +// create the png data from post-deflated data +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) +{ + int bytes = (depth == 16? 
2 : 1); + stbi__context *s = a->s; + stbi__uint32 i,j,stride = x*out_n*bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n*bytes; + int filter_bytes = img_n*bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); + a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, + // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), + // so just check for raw_len < img_len always. + if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); + + for (j=0; j < y; ++j) { + stbi_uc *cur = a->out + stride*j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter","Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); + cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k=0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none : cur[k] = raw[k]; break; + case STBI__F_sub : cur[k] = raw[k]; break; + case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; + case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; + case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; + case STBI__F_avg_first : cur[k] = raw[k]; break; + case STBI__F_paeth_first: cur[k] = raw[k]; break; + } + } + + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes+1] = 255; // first pixel bottom byte + } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } + + // this is a little gross, so that we don't switch per-pixel or per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1)*filter_bytes; + #define STBI__CASE(f) \ + case f: \ + for (k=0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. 
+ case STBI__F_none: memcpy(cur, raw, nk); break; + STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; + STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; + STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; + STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; + STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; + STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; + } + #undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n+1 == out_n); + #define STBI__CASE(f) \ + case f: \ + for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ + for (k=0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; + STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; + STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; + STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; + STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; + STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; + STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; + } + #undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride*j; // start at the beginning of the row again + for (i=0; i < x; ++i,cur+=output_bytes) { + cur[filter_bytes+1] = 255; + } + } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j=0; j < y; ++j) { + stbi_uc *cur = a->out + stride*j; + stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit + // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than desired. + // we can allocate enough data that this never writes out of memory, but it + // could also overwrite the next scanline. can it overwrite non-empty data + // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. 
+ // so we need to explicitly clamp the final ones + + if (depth == 4) { + for (k=x*img_n; k >= 2; k-=2, ++in) { + *cur++ = scale * ((*in >> 4) ); + *cur++ = scale * ((*in ) & 0x0f); + } + if (k > 0) *cur++ = scale * ((*in >> 4) ); + } else if (depth == 2) { + for (k=x*img_n; k >= 4; k-=4, ++in) { + *cur++ = scale * ((*in >> 6) ); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in ) & 0x03); + } + if (k > 0) *cur++ = scale * ((*in >> 6) ); + if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k=x*img_n; k >= 8; k-=8, ++in) { + *cur++ = scale * ((*in >> 7) ); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in ) & 0x01); + } + if (k > 0) *cur++ = scale * ((*in >> 7) ); + if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride*j; + if (img_n == 1) { + for (q=x-1; q >= 0; --q) { + cur[q*2+1] = 255; + cur[q*2+0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q=x-1; q >= 0; --q) { + cur[q*4+3] = 255; + cur[q*4+2] = cur[q*3+2]; + cur[q*4+1] = cur[q*3+1]; + cur[q*4+0] = cur[q*3+0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16*)cur; + + for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) +{ + int bytes = (depth == 16 ? 
2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + if (!final) return stbi__err("outofmem", "Out of memory"); + for (p=0; p < 7; ++p) { + int xorig[] = { 0,4,0,2,0,1,0 }; + int yorig[] = { 0,0,4,0,2,0,1 }; + int xspc[] = { 8,8,4,4,2,2,1 }; + int yspc[] = { 8,8,8,4,4,2,2 }; + int i,j,x,y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j=0; j < y; ++j) { + for (i=0; i < x; ++i) { + int out_y = j*yspc[p]+yorig[p]; + int out_x = i*xspc[p]+xorig[p]; + memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, + a->out + (j*x+i)*out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; + + return 1; +} + +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i=0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i=0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16*) z->out; + + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 
0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) +{ + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) return stbi__err("outofmem", "Out of memory"); + + // between here and free(out) below, exitting would leak + temp_out = p; + + if (pal_img_n == 3) { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p += 3; + } + } else { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p[3] = palette[n+3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; +} + +static int stbi__unpremultiply_on_load_global = 0; +static int stbi__de_iphone_flag_global = 0; + +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_global = flag_true_if_should_convert; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global +#define stbi__de_iphone_flag stbi__de_iphone_flag_global +#else +static STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set; +static STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set; + +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; + stbi__unpremultiply_on_load_set = 1; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_local = flag_true_if_should_convert; + stbi__de_iphone_flag_set = 1; +} + +#define stbi__unpremultiply_on_load (stbi__unpremultiply_on_load_set \ + ? stbi__unpremultiply_on_load_local \ + : stbi__unpremultiply_on_load_global) +#define stbi__de_iphone_flag (stbi__de_iphone_flag_set \ + ? 
stbi__de_iphone_flag_local \ + : stbi__de_iphone_flag_global) +#endif // STBI_THREAD_LOCAL + +static void stbi__de_iphone(stbi__png *z) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i=0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = ( t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; + } + } else { + // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) +{ + stbi_uc palette[1024], pal_img_n=0; + stbi_uc has_trans=0, tc[3]={0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; + int first=1,k,interlace=0, color=0, is_iphone=0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) return 0; + + if (scan == STBI__SCAN_type) return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C','g','B','I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I','H','D','R'): { + int comp,filter; + if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); + first = 0; + if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); + comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); + filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); + interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); + if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); + } else { + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. 
+ s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); + } + // even with SCAN_header, have to scan to see if we have a tRNS + break; + } + + case STBI__PNG_TYPE('P','L','T','E'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); + for (i=0; i < pal_len; ++i) { + palette[i*4+0] = stbi__get8(s); + palette[i*4+1] = stbi__get8(s); + palette[i*4+2] = stbi__get8(s); + palette[i*4+3] = 255; + } + break; + } + + case STBI__PNG_TYPE('t','R','N','S'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } + if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); + if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); + pal_img_n = 4; + for (i=0; i < c.length; ++i) + palette[i*4+3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); + if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); + has_trans = 1; + // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. + if (scan == STBI__SCAN_header) { ++s->img_n; return 1; } + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + } + } + break; + } + + case STBI__PNG_TYPE('I','D','A','T'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); + if (scan == STBI__SCAN_header) { + // header scan definitely stops at first IDAT + if (pal_img_n) + s->img_n = pal_img_n; + return 1; + } + if (c.length > (1u << 30)) return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); + if ((int)(ioff + c.length) < (int)ioff) return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) idata_limit = c.length > 4096 ? 
c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); + ioff += c.length; + break; + } + + case STBI__PNG_TYPE('I','E','N','D'): { + stbi__uint32 raw_len, bpl; + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) return 1; + if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); + if (z->expanded == NULL) return 0; // zlib should set error + STBI_FREE(z->idata); z->idata = NULL; + if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) + s->img_out_n = s->img_n+1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } + + default: + // if critical, fail + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { + #ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); + #endif + return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) +{ + void *result=NULL; + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 
*) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) *n = p->s->img_n; + } + STBI_FREE(p->out); p->out = NULL; + STBI_FREE(p->expanded); p->expanded = NULL; + STBI_FREE(p->idata); p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi__png p; + p.s = s; + return stbi__do_png(&p, x,y,comp,req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) +{ + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) +{ + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind( p->s ); + return 0; + } + if (x) *x = p->s->img_x; + if (y) *y = p->s->img_y; + if (comp) *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) +{ + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; +} +#endif + +// Microsoft/Windows BMP image + +#ifndef STBI_NO_BMP +static int stbi__bmp_test_raw(stbi__context *s) +{ + int r; + int sz; + if (stbi__get8(s) != 'B') return 0; + if (stbi__get8(s) != 'M') return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) +{ + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; +} + + +// returns 0..31 for the highest set bit +static int stbi__high_bit(unsigned int z) +{ + int n=0; + if (z == 0) return -1; + if (z >= 0x10000) { n += 16; z >>= 16; } + if (z >= 0x00100) { n += 8; z >>= 8; } + if (z >= 0x00010) { n += 4; z >>= 4; } + if (z >= 0x00004) { n += 2; z >>= 2; } + if (z >= 0x00002) { n += 1;/* >>= 1;*/ } + return n; +} + +static int stbi__bitcount(unsigned int a) +{ + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; +} + +// extract an arbitrarily-aligned N-bit value (N=bits) +// from v, and then make it 8-bits long and fractionally +// extend it to full full range. 
+static int stbi__shiftsigned(unsigned int v, int shift, int bits) +{ + static unsigned int mul_table[9] = { + 0, + 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, + 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0,0,1,0,2,4,6,0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8-bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct +{ + int bpp, offset, hsz; + unsigned int mr,mg,mb,ma, all_a; + int extra_read; +} stbi__bmp_data; + +static int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress) +{ + // BI_BITFIELDS specifies masks explicitly, don't override + if (compress == 3) + return 1; + + if (compress == 0) { + if (info->bpp == 16) { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } else if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + } else { + // otherwise, use defaults, which is all-0 + info->mr = info->mg = info->mb = info->ma = 0; + } + return 1; + } + return 0; // error +} + +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) +{ + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + if (compress >= 4) return stbi__errpuc("BMP JPEG/PNG", "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + if (compress == 3 && info->bpp != 16 && info->bpp != 32) return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + stbi__bmp_set_mask_defaults(info, compress); + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? 
+ return stbi__errpuc("bad BMP", "bad BMP"); + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else { + // V4/V5 header + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + stbi__bmp_set_mask_defaults(info, compress); + stbi__get32le(s); // discard color space + for (i=0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved + } + } + } + return (void *) 1; +} + + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + unsigned int mr=0,mg=0,mb=0,ma=0, all_a; + stbi_uc pal[256][4]; + int psize=0,i,j,width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int) s->img_y) > 0; + s->img_y = abs((int) s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + // accept some number of extra bytes after the header, but if the offset points either to before + // the header ends or implies a large amount of extra data, reject the file as malformed + int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256*4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { + return stbi__errpuc("bad header", "Corrupt BMP"); + } + // we established that bytes_read_so_far is positive and sensible. + // the first half of this test rejects offsets that are either too small positives, or + // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn + // ensures the number computed in the second half of the test can't overflow. + if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + stbi__skip(s, info.offset - bytes_read_so_far); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 
4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z=0; + if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } + for (i=0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) width = s->img_x; + else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } + pad = (-width)&3; + if (info.bpp == 1) { + for (j=0; j < (int) s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i=0; i < (int) s->img_x; ++i) { + int color = (v>>bit_offset)&0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + if((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); + } + } else { + for (j=0; j < (int) s->img_y; ++j) { + for (i=0; i < (int) s->img_x; i += 2) { + int v=stbi__get8(s),v2=0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + } + stbi__skip(s, pad); + } + } + } else { + int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; + int z = 0; + int easy=0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) width = 3 * s->img_x; + else if (info.bpp == 16) width = 2*s->img_x; + else /* bpp = 32 and pad = 0 */ width=0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + } + for (j=0; j < (int) s->img_y; ++j) { + if (easy) { + for (i=0; i < (int) s->img_x; ++i) { + unsigned char a; + out[z+2] = stbi__get8(s); + out[z+1] = stbi__get8(s); + out[z+0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i=0; i < (int) s->img_x; ++i) { + stbi__uint32 v = (bpp == 16 ? 
(stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) out[z++] = STBI__BYTECAST(a); + } + } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j=0; j < (int) s->img_y>>1; ++j) { + stbi_uc *p1 = out + j *s->img_x*target; + stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; + for (i=0; i < (int) s->img_x*target; ++i) { + t = p1[i]; p1[i] = p2[i]; p2[i] = t; + } + } + } + + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + return out; +} +#endif + +// Targa Truevision - TGA +// by Jonathan Dummer +#ifndef STBI_NO_TGA +// returns STBI_rgb or whatever, 0 on error +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) +{ + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) *is_rgb16 = 0; + switch(bits_per_pixel) { + case 8: return STBI_grey; + case 16: if(is_grey) return STBI_grey_alpha; + // fallthrough + case 15: if(is_rgb16) *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: return bits_per_pixel/8; + default: return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) +{ + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if( tga_colormap_type > 1 ) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if ( tga_colormap_type == 1 ) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; + } + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { + stbi__rewind(s); + return 0; + } + stbi__skip(s,4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE + } + stbi__skip(s,9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if( tga_w < 1 ) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if( tga_h < 1 ) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense + stbi__rewind(s); + return 0; + } + tga_comp = 
stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + } + if(!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) *x = tga_w; + if (y) *y = tga_h; + if (comp) *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) +{ + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if ( tga_color_type == 1 ) { // colormapped (paletted) image + if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + stbi__skip(s,4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s,9); // skip colormap specification and image x/y origin + } + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + +errorEnd: + stbi__rewind(s); + return res; +} + +// read 16bit value and convert to 24bit RGB +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) +{ + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later + out[0] = (stbi_uc)((r * 255)/31); + out[1] = (stbi_uc)((g * 255)/31); + out[2] = (stbi_uc)((b * 255)/31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16=0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) 
+ // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // do a tiny bit of precessing + if ( tga_image_type >= 8 ) + { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); + + if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset ); + + if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { + for (i=0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height -i - 1 : i; + stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if ( tga_indexed) + { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start ); + // load the palette + tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); + } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i=0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + } + // load the data + for (i=0; i < tga_width * tga_height; ++i) + { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if ( tga_is_RLE ) + { + if ( RLE_count == 0 ) + { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if ( !RLE_repeating ) + { + read_next_pixel = 1; + } + } else + { + read_next_pixel = 1; + } + // OK, if I need to read a pixel, do it now + if ( read_next_pixel ) + { + // load however much data we did have + if ( tga_indexed ) + { + // read in index, then perform the lookup + int pal_idx = (tga_bits_per_pixel == 8) ? 
stbi__get8(s) : stbi__get16le(s); + if ( pal_idx >= tga_palette_len ) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx+j]; + } + } else if(tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel + + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i*tga_comp+j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if ( tga_inverted ) + { + for (j = 0; j*2 < tga_height; ++j) + { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) + { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } + } + } + // clear my palette, if I had one + if ( tga_palette != NULL ) + { + STBI_FREE( tga_palette ); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) + { + unsigned char* tga_pixel = tga_data; + for (i=0; i < tga_width * tga_height; ++i) + { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } + + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = + tga_x_origin = tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; +} +#endif + +// ************************************************************************************************* +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context *s) +{ + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) +{ + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; + } + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w,h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. 
+ if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6 ); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) + stbi__skip(s,stbi__get32be(s) ); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s) ); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s) ); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *) stbi__malloc(4 * w*h); + + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w*h; + + // Initialize the data to zero. + //memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. + // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. + // Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, + // which we're going to just skip. + stbi__skip(s, h * channelCount * 2 ); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out+channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } + + } else { + // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) + // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. + + // Read the data by channel. 
+ for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out+channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16) stbi__get16be(s); + } else { + stbi_uc *p = out+channel; + if (bitdepth == 16) { // input bpc + for (i = 0; i < pixelCount; i++, p += 4) + *p = (stbi_uc) (stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } + } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i=0; i < w*h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); + pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); + pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); + } + } + } else { + for (i=0; i < w*h; ++i) { + unsigned char *pixel = out + 4*i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); + pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); + pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); + } + } + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + if (comp) *comp = 4; + *y = h; + *x = w; + + return out; +} +#endif + +// ************************************************************************************************* +// Softimage PIC loader +// by Tom Seddon +// +// See http://softimage.wiki.softimage.com/index.php/INFO:_PIC_file_format +// See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ + +#ifndef STBI_NO_PIC +static int stbi__pic_is4(stbi__context *s,const char *str) +{ + int i; + for (i=0; i<4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; + + return 1; +} + +static int stbi__pic_test_core(stbi__context *s) +{ + int i; + + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) + return 0; + + for(i=0;i<84;++i) + stbi__get8(s); + + if (!stbi__pic_is4(s,"PICT")) + return 0; + + return 1; +} + +typedef struct +{ + stbi_uc size,type,channel; +} stbi__pic_packet; + +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) +{ + int mask=0x80, i; + + for (i=0; i<4; ++i, mask>>=1) { + if (channel & mask) { + if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); + dest[i]=stbi__get8(s); + } + } + + return dest; +} + +static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) +{ + int mask=0x80,i; + + for (i=0;i<4; ++i, mask>>=1) + if (channel&mask) + dest[i]=src[i]; +} + +static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) +{ + int 
act_comp=0,num_packets=0,y,chained; + stbi__pic_packet packets[10]; + + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; + + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return stbi__errpuc("bad format","too many packets"); + + packet = &packets[num_packets++]; + + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + + act_comp |= packet->channel; + + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); + if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); + } while (chained); + + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + + for(y=0; y<height; ++y) { + int packet_idx; + + for(packet_idx=0; packet_idx < num_packets; ++packet_idx) { + stbi__pic_packet *packet = &packets[packet_idx]; + stbi_uc *dest = result+y*width*4; + + switch (packet->type) { + default: + return stbi__errpuc("bad format","packet has bad compression type"); + + case 0: {//uncompressed + int x; + + for(x=0;x<width;++x, dest+=4) + if (!stbi__readval(s,packet->channel,dest)) + return 0; + break; + } + + case 1://Pure RLE + { + int left=width, i; + + while (left>0) { + stbi_uc count,value[4]; + + count=stbi__get8(s); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); + + if (count > left) + count = (stbi_uc) left; + + if (!stbi__readval(s,packet->channel,value)) return 0; + + for(i=0; i<count; ++i, dest+=4) + stbi__copyval(packet->channel,dest,value); + left -= count; + } + } + break; + + case 2: {//Mixed RLE + int left=width; + while (left>0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count==128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file","scanline overrun"); + + if (!stbi__readval(s,packet->channel,value)) + return 0; + + for(i=0;i<count;++i, dest+=4) + stbi__copyval(packet->channel,dest,value); + } else { // Raw + ++count; + if (count>left) return stbi__errpuc("bad file","scanline overrun"); + + for(i=0;i<count;++i, dest+=4) + if (!stbi__readval(s,packet->channel,dest)) + return 0; + } + left-=count; + } + break; + } + } + } + } + + return result; +} + +static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) +{ + stbi_uc *result; + int i, x,y, internal_comp; + STBI_NOTUSED(ri); + + if (!comp) comp = &internal_comp; + + for (i=0; i<92; ++i) + stbi__get8(s); + + x = stbi__get16be(s); + y = stbi__get16be(s); + + if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + + stbi__get32be(s); //skip `ratio' + stbi__get16be(s); //skip `fields' + stbi__get16be(s); //skip `pad' + + // intermediate buffer is RGBA + result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); + if (!result) return stbi__errpuc("outofmem", "Out of memory"); + memset(result, 0xff, x*y*4); + + if (!stbi__pic_load_core(s,x,y,comp, result)) { + STBI_FREE(result); + result=0; + } + *px = x; + *py = y; + if (req_comp == 0) req_comp = *comp; + result=stbi__convert_format(result,4,req_comp,x,y); + + return result; +} + +static int stbi__pic_test(stbi__context *s) +{ + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; +} +#endif + +// ************************************************************************************************* +// GIF loader -- public domain by
Jean-Marc Lienher -- simplified/shrunk by stb + +#ifndef STBI_NO_GIF +typedef struct +{ + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; +} stbi__gif_lzw; + +typedef struct +{ + int w,h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; +} stbi__gif; + +static int stbi__gif_test_raw(stbi__context *s) +{ + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') return 0; + if (stbi__get8(s) != 'a') return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) +{ + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) +{ + int i; + for (i=0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) +{ + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + + if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + + if (is_info) return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); + if (!g) return stbi__err("outofmem", "Out of memory"); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind( s ); + return 0; + } + if (x) *x = g->w; + if (y) *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) +{ + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << 
g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) +{ + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc) init_code; + g->codes[init_code].suffix = (stbi_uc) init_code; + } + + // support no starting clear code + avail = clear+2; + oldcode = -1; + + len = 0; + for(;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32) stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s,len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } + + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } + + p->prefix = (stbi__int16) oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + + stbi__out_gif_code(g, (stbi__uint16) code); + + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } + + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image doesn't support it +// two back is the image from two frames ago, used for a very specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) +{ + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *) stbi__malloc(4 * pcount); + g->background = (stbi_uc *) stbi__malloc(4 * pcount); + g->history = (stbi_uc *) stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); + + // image is treated as "transparent" at the start - ie, nothing overwrites the current background; + // background colour is only used for pixels that are not rendered first frame, after that "background" + // color refers to the color that was there the previous frame. 
+ memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the old background + } + + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } + + // background is what out is after the undoing of the previou frame; + memcpy( g->background, g->out, 4 * g->w * g->h ); + } + + // clear my history; + memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; + + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; + + g->lflags = stbi__get8(s); + + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; + } + + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *) g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *) g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); + + o = stbi__process_gif_raster(s, g); + if (!o) return NULL; + + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; + memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); + } + } + } + + return o; + } + + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. 
+ + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; + } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } + } + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); + } + break; + } + + case 0x3B: // gif stream termination code + return (stbi_uc *) s; // using '1' causes warning on some compilers + + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays) +{ + STBI_FREE(g->out); + STBI_FREE(g->history); + STBI_FREE(g->background); + + if (out) STBI_FREE(out); + if (delays && *delays) STBI_FREE(*delays); + return stbi__errpuc("outofmem", "Out of memory"); +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + + STBI_NOTUSED(out_size); + STBI_NOTUSED(delays_size); + + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } + + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); + if (!tmp) + return stbi__load_gif_main_outofmem(&g, out, delays); + else { + out = (stbi_uc*) tmp; + out_size = layers * stride; + } + + if (delays) { + int *new_delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); + if (!new_delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + *delays = new_delays; + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc*)stbi__malloc( layers * stride ); + if (!out) + return stbi__load_gif_main_outofmem(&g, out, delays); + out_size = layers * stride; + if (delays) { + *delays = (int*) stbi__malloc( layers * sizeof(int) ); + if (!*delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + delays_size = layers * sizeof(int); + } + } + memcpy( out + ((layers - 1) * stride), u, stride ); + if (layers >= 2) { + two_back = out - 2 * stride; + } + + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); + + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} + +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); + + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; + + // moved conversion to after successful load so that the same + // can be done for multiple frames. 
+ if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } + + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); + + return u; +} + +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) +{ + return stbi__gif_info_raw(s,x,y,comp); +} +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR loader +// originally by Nicolas Schulz +#ifndef STBI_NO_HDR +static int stbi__hdr_test_core(stbi__context *s, const char *signature) +{ + int i; + for (i=0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context* s) +{ + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if(!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) +{ + int len=0; + char c = '\0'; + + c = (char) stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN-1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char) stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) +{ + if ( input[3] != 0 ) { + float f1; + // Exponent + f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) output[1] = 1; + if (req_comp == 4) output[3] = 1; + } else { + switch (req_comp) { + case 4: output[3] = 1; /* fallthrough */ + case 3: output[0] = output[1] = output[2] = 0; + break; + case 2: output[1] = 1; /* fallthrough */ + case 1: output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1,c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s,buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! 
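+   // The resolution string of a Radiance file normally reads
+   //     -Y <height> +X <width>
+   // e.g. "-Y 480 +X 640"; only this top-down, left-to-right layout is
+   // accepted -- the strncmp checks below reject every other orientation.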
+ token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int) strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) *comp = 3; + if (req_comp == 0) req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if ( width < 8 || width >= 32768) { + // Read flat data + for (j=0; j < height; ++j) { + for (i=0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); + } + } + } else { + // Read RLE-encoded data + scanline = NULL; + + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a decoded + // pixel (note this can't be a valid pixel--one of RGB must be >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc) c1; + rgbe[1] = (stbi_uc) c2; + rgbe[2] = (stbi_uc) len; + rgbe[3] = (stbi_uc) stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense + } + len <<= 8; + len |= stbi__get8(s); + if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } + if (scanline == NULL) { + scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } + } + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } + } + for (i=0; i < width; ++i) + stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind( s ); + return 0; + } + + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, 
"FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) { + stbi__rewind( s ); + return 0; + } + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *y = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *x = (int) strtol(token, NULL, 10); + *comp = 3; + return 1; +} +#endif // STBI_NO_HDR + +#ifndef STBI_NO_BMP +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) +{ + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + if (p == NULL) { + stbi__rewind( s ); + return 0; + } + if (x) *x = s->img_x; + if (y) *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; +} +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) +{ + int channelCount, dummy, depth; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind( s ); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) +{ + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + STBI_NOTUSED(stbi__get32be(s)); + STBI_NOTUSED(stbi__get32be(s)); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind( s ); + return 0; + } + return 1; +} +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) +{ + int act_comp=0,num_packets=0,chained,dummy; + stbi__pic_packet packets[10]; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind( s); + return 0; + } + if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind( s ); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return 0; + + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; + + if (stbi__at_eof(s)) { + stbi__rewind( s ); + return 0; + } + if (packet->size != 8) { + stbi__rewind( s ); + return 0; + } + } while (chained); + + *comp = (act_comp & 0x10 ? 
4 : 3); + + return 1; +} +#endif + +// ************************************************************************************************* +// Portable Gray Map and Portable Pixel Map loader +// by Ken Miller +// +// PGM: http://netpbm.sourceforge.net/doc/pgm.html +// PPM: http://netpbm.sourceforge.net/doc/ppm.html +// +// Known limitations: +// Does not support comments in the header section +// Does not support ASCII image data (formats P2 and P3) + +#ifndef STBI_NO_PNM + +static int stbi__pnm_test(stbi__context *s) +{ + char p, t; + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind( s ); + return 0; + } + return 1; +} + +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + STBI_NOTUSED(ri); + + ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); + if (ri->bits_per_channel == 0) + return 0; + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + + if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) + return stbi__errpuc("too large", "PNM too large"); + + out = (stbi_uc *) stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { + STBI_FREE(out); + return stbi__errpuc("bad PNM", "PNM file truncated"); + } + + if (req_comp && req_comp != s->img_n) { + if (ri->bits_per_channel == 16) { + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, s->img_n, req_comp, s->img_x, s->img_y); + } else { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + } + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + return out; +} + +static int stbi__pnm_isspace(char c) +{ + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +} + +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) +{ + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char) stbi__get8(s); + + if (stbi__at_eof(s) || *c != '#') + break; + + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) + *c = (char) stbi__get8(s); + } +} + +static int stbi__pnm_isdigit(char c) +{ + return c >= '0' && c <= '9'; +} + +static int stbi__pnm_getinteger(stbi__context *s, char *c) +{ + int value = 0; + + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value*10 + (*c - '0'); + *c = (char) stbi__get8(s); + if((value > 214748364) || (value == 214748364 && *c > '7')) + return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); + } + + return value; +} + +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) +{ + int maxv, dummy; + char c, p, t; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + stbi__rewind(s); + + // Get identifier + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + + *comp = (t == '6') ? 
3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + + c = (char) stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); + + *x = stbi__pnm_getinteger(s, &c); // read width + if(*x == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + *y = stbi__pnm_getinteger(s, &c); // read height + if (*y == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + maxv = stbi__pnm_getinteger(s, &c); // read max value + if (maxv > 65535) + return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); + else if (maxv > 255) + return 16; + else + return 8; +} + +static int stbi__pnm_is16(stbi__context *s) +{ + if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) + return 1; + return 0; +} +#endif + +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) +{ + #ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) return 1; + #endif + + #ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) return 1; + #endif + + // test tga last because it's a crappy test! + #ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; + #endif + return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +} + +static int stbi__is_16_main(stbi__context *s) +{ + #ifndef STBI_NO_PNG + if (stbi__png_is16(s)) return 1; + #endif + + #ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) return 1; + #endif + + #ifndef STBI_NO_PNM + if (stbi__pnm_is16(s)) return 1; + #endif + return 0; +} + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s,x,y,comp); + fseek(f,pos,SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f,pos,SEEK_SET); + return r; +} +#endif // !STBI_NO_STDIO + +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__info_main(&s,x,y,comp); +} + +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__info_main(&s,x,y,comp); +} + +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const 
*buffer, int len) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__is_16_main(&s); +} + +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__is_16_main(&s); +} + +#endif // STB_IMAGE_IMPLEMENTATION + +/* + revision history: + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 1-bit BMP + *_is_16_bit api + avoid warnings + 2.16 (2017-07-23) all functions have 16-bit variants; + STBI_NO_STDIO works again; + compilation fixes; + fix rounding in unpremultiply; + optimize vertical flip; + disable raw_len validation; + documentation fixes + 2.15 (2017-03-18) fix png-1,2,4 bug; now all Imagenet JPGs decode; + warning fixes; disable run-time SSE detection on gcc; + uniform handling of optional "return" values; + thread-safe initialization of zlib tables + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) allocate large structures on the stack + remove white matting for transparent PSD + fix reported channel count for PNG & BMP + re-enable SSE2 in non-gcc 64-bit + support RGB-formatted JPEG + read 16-bit PNGs (only as 8-bit) + 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED + 2.09 (2016-01-16) allow comments in PNM files + 16-bit-per-pixel TGA (not bit-per-component) + info() for TGA could break due to .hdr handling + info() for BMP to shares code instead of sloppy parse + can use STBI_REALLOC_SIZED if allocator doesn't support realloc + code cleanup + 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA + 2.07 (2015-09-13) fix compiler warnings + partial animated GIF support + limited 16-bpc PSD support + #ifdef unused functions + bug with < 92 byte PIC,PNM,HDR,TGA + 2.06 (2015-04-19) fix bug where PSD returns wrong '*comp' value + 2.05 (2015-04-19) fix bug in progressive JPEG handling, fix warning + 2.04 (2015-04-15) try to re-enable SIMD on MinGW 64-bit + 2.03 (2015-04-12) extra corruption checking (mmozeiko) + stbi_set_flip_vertically_on_load (nguillemot) + fix NEON support; fix mingw support + 2.02 (2015-01-19) fix incorrect assert, fix warning + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 + 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG + 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) + progressive JPEG (stb) + PGM/PPM support (Ken Miller) + STBI_MALLOC,STBI_REALLOC,STBI_FREE + GIF bugfix -- seemingly never worked + STBI_NO_*, STBI_ONLY_* + 1.48 (2014-12-14) fix incorrectly-named assert() + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) + optimize PNG (ryg) + fix bug in interlaced PNG with user-specified channel count (stb) + 1.46 (2014-08-26) + fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG + 1.45 (2014-08-16) + fix MSVC-ARM internal compiler error by wrapping malloc + 1.44 (2014-08-07) + various warning fixes from Ronny Chevalier + 1.43 (2014-07-15) + fix MSVC-only compiler problem in code changed in 1.42 + 1.42 (2014-07-09) + don't define _CRT_SECURE_NO_WARNINGS (affects user code) + fixes to stbi__cleanup_jpeg path + added STBI_ASSERT 
to avoid requiring assert.h + 1.41 (2014-06-25) + fix search&replace from 1.36 that messed up comments/error messages + 1.40 (2014-06-22) + fix gcc struct-initialization warning + 1.39 (2014-06-15) + fix to TGA optimization when req_comp != number of components in TGA; + fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) + add support for BMP version 5 (more ignored fields) + 1.38 (2014-06-06) + suppress MSVC warnings on integer casts truncating values + fix accidental rename of 'skip' field of I/O + 1.37 (2014-06-04) + remove duplicate typedef + 1.36 (2014-06-03) + convert to header file single-file library + if de-iphone isn't set, load iphone images color-swapped instead of returning NULL + 1.35 (2014-05-27) + various warnings + fix broken STBI_SIMD path + fix bug where stbi_load_from_file no longer left file pointer in correct place + fix broken non-easy path for 32-bit BMP (possibly never used) + TGA optimization by Arseny Kapoulkine + 1.34 (unknown) + use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case + 1.33 (2011-07-14) + make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements + 1.32 (2011-07-13) + support for "info" function for all supported filetypes (SpartanJ) + 1.31 (2011-06-20) + a few more leak fixes, bug in PNG handling (SpartanJ) + 1.30 (2011-06-11) + added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + removed deprecated format-specific test/load functions + removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway + error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) + fix inefficiency in decoding 32-bit BMP (David Woo) + 1.29 (2010-08-16) + various warning fixes from Aurelien Pocheville + 1.28 (2010-08-01) + fix bug in GIF palette transparency (SpartanJ) + 1.27 (2010-08-01) + cast-to-stbi_uc to fix warnings + 1.26 (2010-07-24) + fix bug in file buffering for PNG reported by SpartanJ + 1.25 (2010-07-17) + refix trans_data warning (Won Chun) + 1.24 (2010-07-12) + perf improvements reading from files on platforms with lock-heavy fgetc() + minor perf improvements for jpeg + deprecated type-specific functions so we'll get feedback if they're needed + attempt to fix trans_data warning (Won Chun) + 1.23 fixed bug in iPhone support + 1.22 (2010-07-10) + removed image *writing* support + stbi_info support from Jetro Lauha + GIF support from Jean-Marc Lienher + iPhone PNG-extensions from James Brown + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) + 1.21 fix use of 'stbi_uc' in header (reported by jon blow) + 1.20 added support for Softimage PIC, by Tom Seddon + 1.19 bug in interlaced PNG corruption check (found by ryg) + 1.18 (2008-08-02) + fix a threading bug (local mutable static) + 1.17 support interlaced PNG + 1.16 major bugfix - stbi__convert_format converted one too many pixels + 1.15 initialize some fields for thread safety + 1.14 fix threadsafe conversion bug + header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + 1.13 threadsafe + 1.12 const qualifiers in the API + 1.11 Support installable IDCT, colorspace conversion routines + 1.10 Fixes for 64-bit (don't use "unsigned long") + optimized upsampling by Fabian "ryg" Giesen + 1.09 Fix format-conversion for PSD code (bad global variables!) 
+ 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz + 1.07 attempt to fix C++ warning/errors again + 1.06 attempt to fix C++ warning/errors again + 1.05 fix TGA loading to return correct *comp and use good luminance calc + 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free + 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR + 1.02 support for (subset of) HDR files, float interface for preferred access to them + 1.01 fix bug: possible bug in handling right-side up bmps... not sure + fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all + 1.00 interface to zlib that skips zlib header + 0.99 correct handling of alpha in palette + 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 0.97 jpeg errors on too large a file; also catch another malloc failure + 0.96 fix detection of invalid v value - particleman@mollyrocket forum + 0.95 during header scan, seek to markers in case of padding + 0.94 STBI_NO_STDIO to disable stdio usage; rename all #defines the same + 0.93 handle jpegtran output; verbose errors + 0.92 read 4,8,16,24,32-bit BMP files of several formats + 0.91 output 24-bit Windows 3.0 BMP files + 0.90 fix a few more warnings; bump version number to approach 1.0 + 0.61 bugfixes due to Marc LeBlanc, Christopher Lloyd + 0.60 fix compiling as c++ + 0.59 fix warnings: merge Dave Moore's -Wall fixes + 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available + 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.55 fix bug: restart_interval not initialized to 0 + 0.54 allow NULL for 'int *comp' + 0.53 fix bug in png 3->4; speedup png decoding + 0.52 png handles req_comp=3,4 directly; minor cleanup; jpeg comments + 0.51 obey req_comp requests, 1-component jpegs return as 1-component, + on 'test' only check type, not whether we support this variant + 0.50 (2006-11-19) + first released version +*/ + + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. 
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ diff --git a/gui/dependencies/stb_image/stb_image_write.h b/gui/dependencies/stb_image/stb_image_write.h new file mode 100644 index 0000000000000000000000000000000000000000..e4b32ed1bc32ef9c962acbf47a9d10af01939e08 --- /dev/null +++ b/gui/dependencies/stb_image/stb_image_write.h @@ -0,0 +1,1724 @@ +/* stb_image_write - v1.16 - public domain - http://nothings.org/stb + writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015 + no warranty implied; use at your own risk + + Before #including, + + #define STB_IMAGE_WRITE_IMPLEMENTATION + + in the file that you want to have the implementation. + + Will probably not work correctly with strict-aliasing optimizations. + +ABOUT: + + This header file is a library for writing images to C stdio or a callback. + + The PNG output is not optimal; it is 20-50% larger than the file + written by a decent optimizing implementation; though providing a custom + zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that. + This library is designed for source code compactness and simplicity, + not optimal image file size or run-time performance. + +BUILDING: + + You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h. + You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace + malloc,realloc,free. + You can #define STBIW_MEMMOVE() to replace memmove() + You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function + for PNG compression (instead of the builtin one), it must have the following signature: + unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality); + The returned data will be freed with STBIW_FREE() (free() by default), + so it must be heap allocated with STBIW_MALLOC() (malloc() by default), + +UNICODE: + + If compiling for Windows and you wish to use Unicode filenames, compile + with + #define STBIW_WINDOWS_UTF8 + and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert + Windows wchar_t filenames to utf8. 
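+   For example, a minimal illustrative sketch of that path: it assumes
+   STBIW_WINDOWS_UTF8 is defined, that L"out.png" is the desired wchar_t
+   filename, and that w, h and pixels are the caller's image size and RGBA
+   data (the write functions themselves are listed under USAGE below):
+
+      char path_utf8[1024];
+      stbiw_convert_wchar_to_utf8(path_utf8, sizeof(path_utf8), L"out.png");
+      stbi_write_png(path_utf8, w, h, 4, pixels, w * 4);
+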
+ +USAGE: + + There are five functions, one for each image file format: + + int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality); + int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); + + void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically + + There are also five equivalent functions that use an arbitrary write function. You are + expected to open/close your file-equivalent before and after calling these: + + int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); + int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); + + where the callback is: + void stbi_write_func(void *context, void *data, int size); + + You can configure it with these global variables: + int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE + int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression + int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode + + + You can define STBI_WRITE_NO_STDIO to disable the file variant of these + functions, so the library will not use stdio.h at all. However, this will + also disable HDR writing, because it requires stdio for formatted output. + + Each function returns 0 on failure and non-0 on success. + + The functions create an image file defined by the parameters. The image + is a rectangle of pixels stored from left-to-right, top-to-bottom. + Each pixel contains 'comp' channels of data stored interleaved with 8-bits + per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is + monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall. + The *data pointer points to the first byte of the top-left-most pixel. + For PNG, "stride_in_bytes" is the distance in bytes from the first byte of + a row of pixels to the first byte of the next row of pixels. + + PNG creates output files with the same number of components as the input. + The BMP format expands Y to RGB in the file format and does not + output alpha. + + PNG supports writing rectangles of data even when the bytes storing rows of + data are not consecutive in memory (e.g. sub-rectangles of a larger image), + by supplying the stride between the beginning of adjacent rows. The other + formats do not. (Thus you cannot write a native-format BMP through the BMP + writer, both because it is in BGR order and because it may have padding + at the end of the line.) + + PNG allows you to set the deflate compression level by setting the global + variable 'stbi_write_png_compression_level' (it defaults to 8). + + HDR expects linear float data. 
Since the format is always 32-bit rgb(e) + data, alpha (if provided) is discarded, and for monochrome data it is + replicated across all three channels. + + TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed + data, set the global variable 'stbi_write_tga_with_rle' to 0. + + JPEG does ignore alpha channels in input data; quality is between 1 and 100. + Higher quality looks better but results in a bigger image. + JPEG baseline (no JPEG progressive). + +CREDITS: + + + Sean Barrett - PNG/BMP/TGA + Baldur Karlsson - HDR + Jean-Sebastien Guay - TGA monochrome + Tim Kelsey - misc enhancements + Alan Hickman - TGA RLE + Emmanuel Julien - initial file IO callback implementation + Jon Olick - original jo_jpeg.cpp code + Daniel Gibson - integrate JPEG, allow external zlib + Aarni Koskela - allow choosing PNG filter + + bugfixes: + github:Chribba + Guillaume Chereau + github:jry2 + github:romigrou + Sergio Gonzalez + Jonas Karlsson + Filip Wasil + Thatcher Ulrich + github:poppolopoppo + Patrick Boettcher + github:xeekworx + Cap Petschulat + Simon Rodriguez + Ivan Tikhonov + github:ignotion + Adam Schackart + Andrew Kensler + +LICENSE + + See end of file for license information. + +*/ + +#ifndef INCLUDE_STB_IMAGE_WRITE_H +#define INCLUDE_STB_IMAGE_WRITE_H + +#include + +// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline' +#ifndef STBIWDEF +#ifdef STB_IMAGE_WRITE_STATIC +#define STBIWDEF static +#else +#ifdef __cplusplus +#define STBIWDEF extern "C" +#else +#define STBIWDEF extern +#endif +#endif +#endif + +#ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations +STBIWDEF int stbi_write_tga_with_rle; +STBIWDEF int stbi_write_png_compression_level; +STBIWDEF int stbi_write_force_png_filter; +#endif + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); +STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); + +#ifdef STBIW_WINDOWS_UTF8 +STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +#endif +#endif + +typedef void stbi_write_func(void *context, void *data, int size); + +STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); +STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); +STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); + +STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean); + +#endif//INCLUDE_STB_IMAGE_WRITE_H + +#ifdef STB_IMAGE_WRITE_IMPLEMENTATION + +#ifdef _WIN32 + #ifndef _CRT_SECURE_NO_WARNINGS + #define _CRT_SECURE_NO_WARNINGS + #endif + #ifndef _CRT_NONSTDC_NO_DEPRECATE + #define _CRT_NONSTDC_NO_DEPRECATE + #endif +#endif + +#ifndef STBI_WRITE_NO_STDIO +#include +#endif // 
STBI_WRITE_NO_STDIO + +#include +#include +#include +#include + +#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED)) +// ok +#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED) +// ok +#else +#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)." +#endif + +#ifndef STBIW_MALLOC +#define STBIW_MALLOC(sz) malloc(sz) +#define STBIW_REALLOC(p,newsz) realloc(p,newsz) +#define STBIW_FREE(p) free(p) +#endif + +#ifndef STBIW_REALLOC_SIZED +#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz) +#endif + + +#ifndef STBIW_MEMMOVE +#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz) +#endif + + +#ifndef STBIW_ASSERT +#include +#define STBIW_ASSERT(x) assert(x) +#endif + +#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff) + +#ifdef STB_IMAGE_WRITE_STATIC +static int stbi_write_png_compression_level = 8; +static int stbi_write_tga_with_rle = 1; +static int stbi_write_force_png_filter = -1; +#else +int stbi_write_png_compression_level = 8; +int stbi_write_tga_with_rle = 1; +int stbi_write_force_png_filter = -1; +#endif + +static int stbi__flip_vertically_on_write = 0; + +STBIWDEF void stbi_flip_vertically_on_write(int flag) +{ + stbi__flip_vertically_on_write = flag; +} + +typedef struct +{ + stbi_write_func *func; + void *context; + unsigned char buffer[64]; + int buf_used; +} stbi__write_context; + +// initialize a callback-based context +static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context) +{ + s->func = c; + s->context = context; +} + +#ifndef STBI_WRITE_NO_STDIO + +static void stbi__stdio_write(void *context, void *data, int size) +{ + fwrite(data,1,size,(FILE*) context); +} + +#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) +#ifdef __cplusplus +#define STBIW_EXTERN extern "C" +#else +#define STBIW_EXTERN extern +#endif +STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); + +STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +} +#endif + +static FILE *stbiw__fopen(char const *filename, char const *mode) +{ + FILE *f; +#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; + + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; +#else + f = _wfopen(wFilename, wMode); +#endif + +#elif defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != fopen_s(&f, filename, mode)) + f=0; +#else + f = fopen(filename, mode); +#endif + return f; +} + +static int stbi__start_write_file(stbi__write_context *s, const char *filename) +{ + FILE *f = stbiw__fopen(filename, "wb"); + stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f); + return f != NULL; +} + +static void stbi__end_write_file(stbi__write_context *s) +{ + 
fclose((FILE *)s->context); +} + +#endif // !STBI_WRITE_NO_STDIO + +typedef unsigned int stbiw_uint32; +typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1]; + +static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v) +{ + while (*fmt) { + switch (*fmt++) { + case ' ': break; + case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int)); + s->func(s->context,&x,1); + break; } + case '2': { int x = va_arg(v,int); + unsigned char b[2]; + b[0] = STBIW_UCHAR(x); + b[1] = STBIW_UCHAR(x>>8); + s->func(s->context,b,2); + break; } + case '4': { stbiw_uint32 x = va_arg(v,int); + unsigned char b[4]; + b[0]=STBIW_UCHAR(x); + b[1]=STBIW_UCHAR(x>>8); + b[2]=STBIW_UCHAR(x>>16); + b[3]=STBIW_UCHAR(x>>24); + s->func(s->context,b,4); + break; } + default: + STBIW_ASSERT(0); + return; + } + } +} + +static void stbiw__writef(stbi__write_context *s, const char *fmt, ...) +{ + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); +} + +static void stbiw__write_flush(stbi__write_context *s) +{ + if (s->buf_used) { + s->func(s->context, &s->buffer, s->buf_used); + s->buf_used = 0; + } +} + +static void stbiw__putc(stbi__write_context *s, unsigned char c) +{ + s->func(s->context, &c, 1); +} + +static void stbiw__write1(stbi__write_context *s, unsigned char a) +{ + if ((size_t)s->buf_used + 1 > sizeof(s->buffer)) + stbiw__write_flush(s); + s->buffer[s->buf_used++] = a; +} + +static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c) +{ + int n; + if ((size_t)s->buf_used + 3 > sizeof(s->buffer)) + stbiw__write_flush(s); + n = s->buf_used; + s->buf_used = n+3; + s->buffer[n+0] = a; + s->buffer[n+1] = b; + s->buffer[n+2] = c; +} + +static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d) +{ + unsigned char bg[3] = { 255, 0, 255}, px[3]; + int k; + + if (write_alpha < 0) + stbiw__write1(s, d[comp - 1]); + + switch (comp) { + case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case + case 1: + if (expand_mono) + stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp + else + stbiw__write1(s, d[0]); // monochrome TGA + break; + case 4: + if (!write_alpha) { + // composite against pink background + for (k = 0; k < 3; ++k) + px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255; + stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]); + break; + } + /* FALLTHROUGH */ + case 3: + stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]); + break; + } + if (write_alpha > 0) + stbiw__write1(s, d[comp - 1]); +} + +static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono) +{ + stbiw_uint32 zero = 0; + int i,j, j_end; + + if (y <= 0) + return; + + if (stbi__flip_vertically_on_write) + vdir *= -1; + + if (vdir < 0) { + j_end = -1; j = y-1; + } else { + j_end = y; j = 0; + } + + for (; j != j_end; j += vdir) { + for (i=0; i < x; ++i) { + unsigned char *d = (unsigned char *) data + (j*x+i)*comp; + stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d); + } + stbiw__write_flush(s); + s->func(s->context, &zero, scanline_pad); + } +} + +static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...) 
+{ + if (y < 0 || x < 0) { + return 0; + } else { + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); + stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono); + return 1; + } +} + +static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data) +{ + if (comp != 4) { + // write RGB bitmap + int pad = (-x*3) & 3; + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad, + "11 4 22 4" "4 44 22 444444", + 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header + 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header + } else { + // RGBA bitmaps need a v4 header + // use BI_BITFIELDS mode with 32bpp and alpha mask + // (straight BI_RGB with alpha mask doesn't work in most readers) + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0, + "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444", + 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header + 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header + } +} + +STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_bmp_core(&s, x, y, comp, data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_bmp_core(&s, x, y, comp, data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif //!STBI_WRITE_NO_STDIO + +static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data) +{ + int has_alpha = (comp == 2 || comp == 4); + int colorbytes = has_alpha ? comp-1 : comp; + int format = colorbytes < 2 ? 
3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3 + + if (y < 0 || x < 0) + return 0; + + if (!stbi_write_tga_with_rle) { + return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0, + "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8); + } else { + int i,j,k; + int jend, jdir; + + stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8); + + if (stbi__flip_vertically_on_write) { + j = 0; + jend = y; + jdir = 1; + } else { + j = y-1; + jend = -1; + jdir = -1; + } + for (; j != jend; j += jdir) { + unsigned char *row = (unsigned char *) data + j * x * comp; + int len; + + for (i = 0; i < x; i += len) { + unsigned char *begin = row + i * comp; + int diff = 1; + len = 1; + + if (i < x - 1) { + ++len; + diff = memcmp(begin, row + (i + 1) * comp, comp); + if (diff) { + const unsigned char *prev = begin; + for (k = i + 2; k < x && len < 128; ++k) { + if (memcmp(prev, row + k * comp, comp)) { + prev += comp; + ++len; + } else { + --len; + break; + } + } + } else { + for (k = i + 2; k < x && len < 128; ++k) { + if (!memcmp(begin, row + k * comp, comp)) { + ++len; + } else { + break; + } + } + } + } + + if (diff) { + unsigned char header = STBIW_UCHAR(len - 1); + stbiw__write1(s, header); + for (k = 0; k < len; ++k) { + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp); + } + } else { + unsigned char header = STBIW_UCHAR(len - 129); + stbiw__write1(s, header); + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin); + } + } + } + stbiw__write_flush(s); + } + return 1; +} + +STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_tga_core(&s, x, y, comp, (void *) data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_tga_core(&s, x, y, comp, (void *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR writer +// by Baldur Karlsson + +#define stbiw__max(a, b) ((a) > (b) ? 
(a) : (b)) + +#ifndef STBI_WRITE_NO_STDIO + +static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear) +{ + int exponent; + float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2])); + + if (maxcomp < 1e-32f) { + rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0; + } else { + float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp; + + rgbe[0] = (unsigned char)(linear[0] * normalize); + rgbe[1] = (unsigned char)(linear[1] * normalize); + rgbe[2] = (unsigned char)(linear[2] * normalize); + rgbe[3] = (unsigned char)(exponent + 128); + } +} + +static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte) +{ + unsigned char lengthbyte = STBIW_UCHAR(length+128); + STBIW_ASSERT(length+128 <= 255); + s->func(s->context, &lengthbyte, 1); + s->func(s->context, &databyte, 1); +} + +static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data) +{ + unsigned char lengthbyte = STBIW_UCHAR(length); + STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code + s->func(s->context, &lengthbyte, 1); + s->func(s->context, data, length); +} + +static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline) +{ + unsigned char scanlineheader[4] = { 2, 2, 0, 0 }; + unsigned char rgbe[4]; + float linear[3]; + int x; + + scanlineheader[2] = (width&0xff00)>>8; + scanlineheader[3] = (width&0x00ff); + + /* skip RLE for images too small or large */ + if (width < 8 || width >= 32768) { + for (x=0; x < width; x++) { + switch (ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + s->func(s->context, rgbe, 4); + } + } else { + int c,r; + /* encode into scratch buffer */ + for (x=0; x < width; x++) { + switch(ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + scratch[x + width*0] = rgbe[0]; + scratch[x + width*1] = rgbe[1]; + scratch[x + width*2] = rgbe[2]; + scratch[x + width*3] = rgbe[3]; + } + + s->func(s->context, scanlineheader, 4); + + /* RLE each component separately */ + for (c=0; c < 4; c++) { + unsigned char *comp = &scratch[width*c]; + + x = 0; + while (x < width) { + // find first run + r = x; + while (r+2 < width) { + if (comp[r] == comp[r+1] && comp[r] == comp[r+2]) + break; + ++r; + } + if (r+2 >= width) + r = width; + // dump up to first run + while (x < r) { + int len = r-x; + if (len > 128) len = 128; + stbiw__write_dump_data(s, len, &comp[x]); + x += len; + } + // if there's a run, output it + if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd + // find next byte after run + while (r < width && comp[r] == comp[x]) + ++r; + // output run up to r + while (x < r) { + int len = r-x; + if (len > 127) len = 127; + stbiw__write_run_data(s, len, comp[x]); + x += len; + } + } + } + } + } +} + +static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data) +{ + if (y <= 0 || x <= 0 || data == NULL) + return 0; + else { + // Each component is stored separately. 
Allocate scratch space for full output scanline. + unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4); + int i, len; + char buffer[128]; + char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n"; + s->func(s->context, header, sizeof(header)-1); + +#ifdef __STDC_LIB_EXT1__ + len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#else + len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#endif + s->func(s->context, buffer, len); + + for(i=0; i < y; i++) + stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i)); + STBIW_FREE(scratch); + return 1; + } +} + +STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_hdr_core(&s, x, y, comp, (float *) data); +} + +STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif // STBI_WRITE_NO_STDIO + + +////////////////////////////////////////////////////////////////////////////// +// +// PNG writer +// + +#ifndef STBIW_ZLIB_COMPRESS +// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size() +#define stbiw__sbraw(a) ((int *) (void *) (a) - 2) +#define stbiw__sbm(a) stbiw__sbraw(a)[0] +#define stbiw__sbn(a) stbiw__sbraw(a)[1] + +#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a)) +#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0) +#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a))) + +#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v)) +#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0) +#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0) + +static void *stbiw__sbgrowf(void **arr, int increment, int itemsize) +{ + int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1; + void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? 
(stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2); + STBIW_ASSERT(p); + if (p) { + if (!*arr) ((int *) p)[1] = 0; + *arr = (void *) ((int *) p + 2); + stbiw__sbm(*arr) = m; + } + return *arr; +} + +static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount) +{ + while (*bitcount >= 8) { + stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer)); + *bitbuffer >>= 8; + *bitcount -= 8; + } + return data; +} + +static int stbiw__zlib_bitrev(int code, int codebits) +{ + int res=0; + while (codebits--) { + res = (res << 1) | (code & 1); + code >>= 1; + } + return res; +} + +static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit) +{ + int i; + for (i=0; i < limit && i < 258; ++i) + if (a[i] != b[i]) break; + return i; +} + +static unsigned int stbiw__zhash(unsigned char *data) +{ + stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16); + hash ^= hash << 3; + hash += hash >> 5; + hash ^= hash << 4; + hash += hash >> 17; + hash ^= hash << 25; + hash += hash >> 6; + return hash; +} + +#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount)) +#define stbiw__zlib_add(code,codebits) \ + (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush()) +#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c) +// default huffman tables +#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8) +#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9) +#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7) +#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8) +#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n)) +#define stbiw__zlib_huffb(n) ((n) <= 143 ? 
stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n)) + +#define stbiw__ZHASH 16384 + +#endif // STBIW_ZLIB_COMPRESS + +STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality) +{ +#ifdef STBIW_ZLIB_COMPRESS + // user provided a zlib compress implementation, use that + return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality); +#else // use builtin + static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 }; + static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; + static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 }; + static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 }; + unsigned int bitbuf=0; + int i,j, bitcount=0; + unsigned char *out = NULL; + unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**)); + if (hash_table == NULL) + return NULL; + if (quality < 5) quality = 5; + + stbiw__sbpush(out, 0x78); // DEFLATE 32K window + stbiw__sbpush(out, 0x5e); // FLEVEL = 1 + stbiw__zlib_add(1,1); // BFINAL = 1 + stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman + + for (i=0; i < stbiw__ZHASH; ++i) + hash_table[i] = NULL; + + i=0; + while (i < data_len-3) { + // hash next 3 bytes of data to be compressed + int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3; + unsigned char *bestloc = 0; + unsigned char **hlist = hash_table[h]; + int n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32768) { // if entry lies within window + int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i); + if (d >= best) { best=d; bestloc=hlist[j]; } + } + } + // when hash table entry is too long, delete half the entries + if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) { + STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality); + stbiw__sbn(hash_table[h]) = quality; + } + stbiw__sbpush(hash_table[h],data+i); + + if (bestloc) { + // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal + h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1); + hlist = hash_table[h]; + n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32767) { + int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1); + if (e > best) { // if next match is better, bail on current match + bestloc = NULL; + break; + } + } + } + } + + if (bestloc) { + int d = (int) (data+i - bestloc); // distance back + STBIW_ASSERT(d <= 32767 && best <= 258); + for (j=0; best > lengthc[j+1]-1; ++j); + stbiw__zlib_huff(j+257); + if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]); + for (j=0; d > distc[j+1]-1; ++j); + stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5); + if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]); + i += best; + } else { + stbiw__zlib_huffb(data[i]); + ++i; + } + } + // write out final bytes + for (;i < data_len; ++i) + stbiw__zlib_huffb(data[i]); + stbiw__zlib_huff(256); // end of block + // pad with 0 bits to byte boundary + while (bitcount) + stbiw__zlib_add(0,1); + + for (i=0; i < stbiw__ZHASH; ++i) + (void) stbiw__sbfree(hash_table[i]); + STBIW_FREE(hash_table); + + // store uncompressed instead if compression was worse + if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) { + stbiw__sbn(out) = 2; // truncate to DEFLATE 
32K window and FLEVEL = 1 + for (j = 0; j < data_len;) { + int blocklen = data_len - j; + if (blocklen > 32767) blocklen = 32767; + stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression + stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN + stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN + stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8)); + memcpy(out+stbiw__sbn(out), data+j, blocklen); + stbiw__sbn(out) += blocklen; + j += blocklen; + } + } + + { + // compute adler32 on input + unsigned int s1=1, s2=0; + int blocklen = (int) (data_len % 5552); + j=0; + while (j < data_len) { + for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; } + s1 %= 65521; s2 %= 65521; + j += blocklen; + blocklen = 5552; + } + stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s2)); + stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s1)); + } + *out_len = stbiw__sbn(out); + // make returned pointer freeable + STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len); + return (unsigned char *) stbiw__sbraw(out); +#endif // STBIW_ZLIB_COMPRESS +} + +static unsigned int stbiw__crc32(unsigned char *buffer, int len) +{ +#ifdef STBIW_CRC32 + return STBIW_CRC32(buffer, len); +#else + static unsigned int crc_table[256] = + { + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, + 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, + 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, + 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, + 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, + 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, + 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, + 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, + 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, + 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, + 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, + 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, + 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, + 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, + 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, + 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD, + 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, + 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, + 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, + 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, + 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, + 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, + 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 
0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, + 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, + 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, + 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, + 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, + 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, + 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, + 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, + 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, + 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D + }; + + unsigned int crc = ~0u; + int i; + for (i=0; i < len; ++i) + crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)]; + return ~crc; +#endif +} + +#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4) +#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v)); +#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3]) + +static void stbiw__wpcrc(unsigned char **data, int len) +{ + unsigned int crc = stbiw__crc32(*data - len - 4, len+4); + stbiw__wp32(*data, crc); +} + +static unsigned char stbiw__paeth(int a, int b, int c) +{ + int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c); + if (pa <= pb && pa <= pc) return STBIW_UCHAR(a); + if (pb <= pc) return STBIW_UCHAR(b); + return STBIW_UCHAR(c); +} + +// @OPTIMIZE: provide an option that always forces left-predict or paeth predict +static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer) +{ + static int mapping[] = { 0,1,2,3,4 }; + static int firstmap[] = { 0,1,0,5,6 }; + int *mymap = (y != 0) ? mapping : firstmap; + int i; + int type = mymap[filter_type]; + unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y); + int signed_stride = stbi__flip_vertically_on_write ? 
-stride_bytes : stride_bytes; + + if (type==0) { + memcpy(line_buffer, z, width*n); + return; + } + + // first loop isn't optimized since it's just one pixel + for (i = 0; i < n; ++i) { + switch (type) { + case 1: line_buffer[i] = z[i]; break; + case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break; + case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break; + case 5: line_buffer[i] = z[i]; break; + case 6: line_buffer[i] = z[i]; break; + } + } + switch (type) { + case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break; + case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break; + case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break; + case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break; + case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break; + } +} + +STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len) +{ + int force_filter = stbi_write_force_png_filter; + int ctype[5] = { -1, 0, 4, 2, 6 }; + unsigned char sig[8] = { 137,80,78,71,13,10,26,10 }; + unsigned char *out,*o, *filt, *zlib; + signed char *line_buffer; + int j,zlen; + + if (stride_bytes == 0) + stride_bytes = x * n; + + if (force_filter >= 5) { + force_filter = -1; + } + + filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0; + line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; } + for (j=0; j < y; ++j) { + int filter_type; + if (force_filter > -1) { + filter_type = force_filter; + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer); + } else { // Estimate the best filter by running through all of them: + int best_filter = 0, best_filter_val = 0x7fffffff, est, i; + for (filter_type = 0; filter_type < 5; filter_type++) { + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer); + + // Estimate the entropy of the line using this filter; the less, the better. 
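// [Editor's note - explanatory comment, not part of the original patch.]
// The "entropy" estimate below is the minimum-sum-of-absolute-differences
// heuristic suggested by the PNG specification: treat each filtered byte as a
// signed delta and add up the magnitudes. Rows whose residuals cluster around
// zero tend to compress better under DEFLATE, so the filter with the smallest
// sum wins. Illustrative example with made-up values: a row filtered to
// {1, -2, 0, 3} scores est = 1 + 2 + 0 + 3 = 6 and beats a filter that yields
// {10, -9, 8, -7} with est = 34.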
+ est = 0; + for (i = 0; i < x*n; ++i) { + est += abs((signed char) line_buffer[i]); + } + if (est < best_filter_val) { + best_filter_val = est; + best_filter = filter_type; + } + } + if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer); + filter_type = best_filter; + } + } + // when we get here, filter_type contains the filter type, and line_buffer contains the data + filt[j*(x*n+1)] = (unsigned char) filter_type; + STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n); + } + STBIW_FREE(line_buffer); + zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level); + STBIW_FREE(filt); + if (!zlib) return 0; + + // each tag requires 12 bytes of overhead + out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12); + if (!out) return 0; + *out_len = 8 + 12+13 + 12+zlen + 12; + + o=out; + STBIW_MEMMOVE(o,sig,8); o+= 8; + stbiw__wp32(o, 13); // header length + stbiw__wptag(o, "IHDR"); + stbiw__wp32(o, x); + stbiw__wp32(o, y); + *o++ = 8; + *o++ = STBIW_UCHAR(ctype[n]); + *o++ = 0; + *o++ = 0; + *o++ = 0; + stbiw__wpcrc(&o,13); + + stbiw__wp32(o, zlen); + stbiw__wptag(o, "IDAT"); + STBIW_MEMMOVE(o, zlib, zlen); + o += zlen; + STBIW_FREE(zlib); + stbiw__wpcrc(&o, zlen); + + stbiw__wp32(o,0); + stbiw__wptag(o, "IEND"); + stbiw__wpcrc(&o,0); + + STBIW_ASSERT(o == out + *out_len); + + return out; +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes) +{ + FILE *f; + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + + f = stbiw__fopen(filename, "wb"); + if (!f) { STBIW_FREE(png); return 0; } + fwrite(png, 1, len, f); + fclose(f); + STBIW_FREE(png); + return 1; +} +#endif + +STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes) +{ + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + func(context, png, len); + STBIW_FREE(png); + return 1; +} + + +/* *************************************************************************** + * + * JPEG writer + * + * This is based on Jon Olick's jo_jpeg.cpp: + * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html + */ + +static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18, + 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 }; + +static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) { + int bitBuf = *bitBufP, bitCnt = *bitCntP; + bitCnt += bs[1]; + bitBuf |= bs[0] << (24 - bitCnt); + while(bitCnt >= 8) { + unsigned char c = (bitBuf >> 16) & 255; + stbiw__putc(s, c); + if(c == 255) { + stbiw__putc(s, 0); + } + bitBuf <<= 8; + bitCnt -= 8; + } + *bitBufP = bitBuf; + *bitCntP = bitCnt; +} + +static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) { + float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p; + float z1, z2, z3, z4, z5, z11, z13; + + float tmp0 = d0 + d7; + float tmp7 = d0 - d7; + float tmp1 = d1 + d6; + 
float tmp6 = d1 - d6; + float tmp2 = d2 + d5; + float tmp5 = d2 - d5; + float tmp3 = d3 + d4; + float tmp4 = d3 - d4; + + // Even part + float tmp10 = tmp0 + tmp3; // phase 2 + float tmp13 = tmp0 - tmp3; + float tmp11 = tmp1 + tmp2; + float tmp12 = tmp1 - tmp2; + + d0 = tmp10 + tmp11; // phase 3 + d4 = tmp10 - tmp11; + + z1 = (tmp12 + tmp13) * 0.707106781f; // c4 + d2 = tmp13 + z1; // phase 5 + d6 = tmp13 - z1; + + // Odd part + tmp10 = tmp4 + tmp5; // phase 2 + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + // The rotator is modified from fig 4-8 to avoid extra negations. + z5 = (tmp10 - tmp12) * 0.382683433f; // c6 + z2 = tmp10 * 0.541196100f + z5; // c2-c6 + z4 = tmp12 * 1.306562965f + z5; // c2+c6 + z3 = tmp11 * 0.707106781f; // c4 + + z11 = tmp7 + z3; // phase 5 + z13 = tmp7 - z3; + + *d5p = z13 + z2; // phase 6 + *d3p = z13 - z2; + *d1p = z11 + z4; + *d7p = z11 - z4; + + *d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6; +} + +static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) { + int tmp1 = val < 0 ? -val : val; + val = val < 0 ? val-1 : val; + bits[1] = 1; + while(tmp1 >>= 1) { + ++bits[1]; + } + bits[0] = val & ((1<0)&&(DU[end0pos]==0); --end0pos) { + } + // end0pos = first element in reverse order !=0 + if(end0pos == 0) { + stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); + return DU[0]; + } + for(i = 1; i <= end0pos; ++i) { + int startpos = i; + int nrzeroes; + unsigned short bits[2]; + for (; DU[i]==0 && i<=end0pos; ++i) { + } + nrzeroes = i-startpos; + if ( nrzeroes >= 16 ) { + int lng = nrzeroes>>4; + int nrmarker; + for (nrmarker=1; nrmarker <= lng; ++nrmarker) + stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes); + nrzeroes &= 15; + } + stbiw__jpg_calcBits(DU[i], bits); + stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]); + stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits); + } + if(end0pos != 63) { + stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); + } + return DU[0]; +} + +static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { + // Constants that don't pollute global namespace + static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; + static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; + static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d}; + static const unsigned char std_ac_luminance_values[] = { + 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08, + 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28, + 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59, + 0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89, + 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6, + 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2, + 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa + }; + static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0}; + static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; + static const unsigned char std_ac_chrominance_nrcodes[] = 
{0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77}; + static const unsigned char std_ac_chrominance_values[] = { + 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91, + 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26, + 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58, + 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87, + 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4, + 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda, + 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa + }; + // Huffman tables + static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}}; + static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}}; + static const unsigned short YAC_HT[256][2] = { + {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + 
{2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const unsigned short UVAC_HT[256][2] = { + {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22, + 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99}; + static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99, + 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99}; + static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f, + 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f }; + + int row, col, i, k, subsample; + float fdtbl_Y[64], fdtbl_UV[64]; + unsigned char YTable[64], UVTable[64]; + + if(!data || !width || !height || comp > 4 || comp < 1) { + return 0; + } + + quality = quality ? quality : 90; + subsample = quality <= 90 ? 1 : 0; + quality = quality < 1 ? 1 : quality > 100 ? 
100 : quality; + quality = quality < 50 ? 5000 / quality : 200 - quality * 2; + + for(i = 0; i < 64; ++i) { + int uvti, yti = (YQT[i]*quality+50)/100; + YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti); + uvti = (UVQT[i]*quality+50)/100; + UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti); + } + + for(row = 0, k = 0; row < 8; ++row) { + for(col = 0; col < 8; ++col, ++k) { + fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + } + } + + // Write Headers + { + static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 }; + static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 }; + const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width), + 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 }; + s->func(s->context, (void*)head0, sizeof(head0)); + s->func(s->context, (void*)YTable, sizeof(YTable)); + stbiw__putc(s, 1); + s->func(s->context, UVTable, sizeof(UVTable)); + s->func(s->context, (void*)head1, sizeof(head1)); + s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); + s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); + stbiw__putc(s, 0x10); // HTYACinfo + s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1); + s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values)); + stbiw__putc(s, 1); // HTUDCinfo + s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values)); + stbiw__putc(s, 0x11); // HTUACinfo + s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values)); + s->func(s->context, (void*)head2, sizeof(head2)); + } + + // Encode 8x8 macroblocks + { + static const unsigned short fillBits[] = {0x7F, 7}; + int DCY=0, DCU=0, DCV=0; + int bitBuf=0, bitCnt=0; + // comp == 2 is grey+alpha (alpha is ignored) + int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0; + const unsigned char *dataR = (const unsigned char *)data; + const unsigned char *dataG = dataR + ofsG; + const unsigned char *dataB = dataR + ofsB; + int x, y, pos; + if(subsample) { + for(y = 0; y < height; y += 16) { + for(x = 0; x < width; x += 16) { + float Y[256], U[256], V[256]; + for(row = y, pos = 0; row < y+16; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+16; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? 
col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + + // subsample U,V + { + float subU[64], subV[64]; + int yy, xx; + for(yy = 0, pos = 0; yy < 8; ++yy) { + for(xx = 0; xx < 8; ++xx, ++pos) { + int j = yy*32+xx*2; + subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f; + subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f; + } + } + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + } else { + for(y = 0; y < height; y += 8) { + for(x = 0; x < width; x += 8) { + float Y[64], U[64], V[64]; + for(row = y, pos = 0; row < y+8; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+8; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + + // Do the bit alignment of the EOI marker + stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits); + } + + // EOI + stbiw__putc(s, 0xFF); + stbiw__putc(s, 0xD9); + + return 1; +} + +STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); +} + + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +#endif // STB_IMAGE_WRITE_IMPLEMENTATION + +/* Revision history + 1.16 (2021-07-11) + make Deflate code emit uncompressed blocks when it would otherwise expand + support writing BMPs with alpha channel + 1.15 (2020-07-13) unknown + 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels + 1.13 + 1.12 + 1.11 (2019-08-11) + + 1.10 (2019-02-07) + support utf8 filenames in Windows; fix warnings and platform ifdefs + 1.09 (2018-02-11) + fix typo in zlib quality API, improve STB_I_W_STATIC in C++ + 1.08 (2018-01-29) + add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter + 1.07 (2017-07-24) + 
doc fix + 1.06 (2017-07-23) + writing JPEG (using Jon Olick's code) + 1.05 ??? + 1.04 (2017-03-03) + monochrome BMP expansion + 1.03 ??? + 1.02 (2016-04-02) + avoid allocating large structures on the stack + 1.01 (2016-01-16) + STBIW_REALLOC_SIZED: support allocators with no realloc support + avoid race-condition in crc initialization + minor compile issues + 1.00 (2015-09-14) + installable file IO function + 0.99 (2015-09-13) + warning fixes; TGA rle support + 0.98 (2015-04-08) + added STBIW_MALLOC, STBIW_ASSERT etc + 0.97 (2015-01-18) + fixed HDR asserts, rewrote HDR rle logic + 0.96 (2015-01-17) + add HDR output + fix monochrome BMP + 0.95 (2014-08-17) + add monochrome TGA output + 0.94 (2014-05-31) + rename private functions to avoid conflicts with stb_image.h + 0.93 (2014-05-27) + warning fixes + 0.92 (2010-08-01) + casts to unsigned char to fix warnings + 0.91 (2010-07-17) + first public release + 0.90 first internal release +*/ + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ diff --git a/gui/flake.lock b/gui/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..90c914452b6b7f84c8fccb61d25ccd2883117107 --- /dev/null +++ b/gui/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1744463964, + "narHash": "sha256-LWqduOgLHCFxiTNYi3Uj5Lgz0SR+Xhw3kr/3Xd0GPTM=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2631b0b7abcea6e640ce31cd78ea58910d31e650", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/gui/flake.nix b/gui/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..d053b7aa3676e734983102f678398fedb03d334e --- /dev/null +++ b/gui/flake.nix @@ -0,0 +1,75 @@ +{ + description = "INGP dev env"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + config = { + allowUnfree = true; + cudaSupport = true; + config.cudaVersion = "12"; + }; + }; + in + { + devShell = pkgs.mkShell { + buildInputs = with pkgs; [ + gcc13 + gdb + cmake + pkg-config + binutils + zlib + + xorg.libX11.dev + xorg.libXi.dev + xorg.libXrandr.dev + xorg.libXinerama.dev + xorg.libXcursor.dev + xorg.libXext.dev + xorg.libXfixes.dev + xorg.libXrender.dev + libGL + glew + + vulkan-loader + vulkan-headers + vulkan-validation-layers + vulkan-extension-layer + vulkan-tools + + python3 + stdenv.cc.cc.lib + + cudatoolkit + cudaPackages.cuda_cudart + cudaPackages.cuda_nvrtc + cudaPackages.cuda_nvtx + ]; + + shellHook = '' + # Set GCC 13 as the default compiler + export CC="${pkgs.gcc13}/bin/gcc" + export CXX="${pkgs.gcc13}/bin/g++" + export PATH="${pkgs.gcc13}/bin:$PATH" + + export CUDA_PATH="${pkgs.cudatoolkit}" + export CLANGD_CUDA_INCLUDE="${pkgs.cudatoolkit}" + + export LD_LIBRARY_PATH="/run/opengl-driver/lib:${pkgs.zlib}/lib:${pkgs.stdenv.cc.cc.lib}/lib:''${LD_LIBRARY_PATH:-}" + export VULKAN_SDK="${pkgs.vulkan-loader}" + + export VK_LAYER_PATH="${pkgs.vulkan-validation-layers}/share/vulkan/explicit_layer.d:${pkgs.vulkan-extension-layer}/share/vulkan/explicit_layer.d" + export 
VK_ICD_FILENAMES="/run/opengl-driver/share/vulkan/icd.d/nvidia_icd.x86_64.json" + ''; + }; + } + ); +} diff --git a/gui/include/neural-graphics-primitives/adam_optimizer.h b/gui/include/neural-graphics-primitives/adam_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..09fa5846558a79c35131a9d7121424691be20344 --- /dev/null +++ b/gui/include/neural-graphics-primitives/adam_optimizer.h @@ -0,0 +1,318 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file adam_optimizer.h + * @author Thomas Müller, NVIDIA + */ + +#pragma once + +#include +#include +#include + +#include + +namespace ngp { + +class VarAdamOptimizer { +public: + VarAdamOptimizer(size_t size = 0, float learning_rate = 1e-3, float epsilon = 1e-08f, float beta1 = 0.9f, float beta2 = 0.99f) : m_state{size} { + m_hparams = { learning_rate, epsilon, beta1, beta2 }; + } + + VarAdamOptimizer& operator=(const VarAdamOptimizer& arg) { + m_state = arg.m_state; + m_hparams = arg.m_hparams; + return *this; + } + + VarAdamOptimizer(const VarAdamOptimizer& arg) { + *this = arg; + } + + void step(const std::vector& gradient) { + ++m_state.iter; + + float actual_learning_rate = m_hparams.learning_rate * std::sqrt(1.0f - std::pow(m_hparams.beta2, (float)m_state.iter)) / (1.0f - std::pow(m_hparams.beta1, (float)m_state.iter)); + + for (size_t i = 0; i < m_state.first_moment.size(); ++i) { + m_state.first_moment[i] = m_hparams.beta1 * m_state.first_moment[i] + (1.0f - m_hparams.beta1) * gradient[i]; + m_state.second_moment[i] = m_hparams.beta2 * m_state.second_moment[i] + (1.0f - m_hparams.beta2) * gradient[i] * gradient[i]; + m_state.variable[i] -= actual_learning_rate * m_state.first_moment[i] / (std::sqrt(m_state.second_moment[i]) + m_hparams.epsilon); + } + } + + uint32_t step() const { + return m_state.iter; + } + + void set_learning_rate(float lr) { + m_hparams.learning_rate = lr; + } + + std::vector& variable() { + return m_state.variable; + } + + const std::vector& variable() const { + return m_state.variable; + } + + void reset_state() { + m_state = State{m_state.first_moment.size()}; + } + + void to_json(nlohmann::json& j) const { + j["iter"] = m_state.iter; + j["first_moment"] = m_state.first_moment; + j["second_moment"] = m_state.second_moment; + j["variable"] = m_state.variable; + j["learning_rate"] = m_hparams.learning_rate; + j["epsilon"] = m_hparams.epsilon; + j["beta1"] = m_hparams.beta1; + j["beta2"] = m_hparams.beta2; + } + + void from_json(const nlohmann::json& j) { + m_state.iter = j.at("iter"); + m_state.first_moment = j.at("first_moment").get>(); + m_state.second_moment = j.at("second_moment").get>(); + m_state.variable = j.at("variable").get>(); + m_hparams.learning_rate = j.at("learning_rate"); + m_hparams.epsilon = j.at("epsilon"); + m_hparams.beta1 = j.at("beta1"); + m_hparams.beta2 = j.at("beta2"); + } + +private: + struct State 
{ + State() = default; + State(const State&) = default; + State(size_t size) { + iter = 0; + first_moment = std::vector(size, 0.0f); + second_moment = std::vector(size, 0.0f); + variable = std::vector(size, 0.0f); + } + + uint32_t iter; + std::vector first_moment; + std::vector second_moment; + std::vector variable; + } m_state; + + struct Hyperparameters { + float learning_rate; + float epsilon; + float beta1; + float beta2; + } m_hparams; +}; + +inline void to_json(nlohmann::json& j, const VarAdamOptimizer& opt) { + opt.to_json(j); +} + +inline void from_json(const nlohmann::json& j, VarAdamOptimizer& opt) { + opt.from_json(j); +} + +template +class AdamOptimizer { +public: + AdamOptimizer(float learning_rate = 1e-3, float epsilon = 1e-08f, float beta1 = 0.9f, float beta2 = 0.99f) { + m_hparams = { learning_rate, epsilon, beta1, beta2 }; + } + + AdamOptimizer& operator=(const AdamOptimizer& arg) { + m_state = arg.m_state; + m_hparams = arg.m_hparams; + return *this; + } + + AdamOptimizer(const AdamOptimizer& arg) { + *this = arg; + } + + void step(const T& gradient) { + ++m_state.iter; + + float actual_learning_rate = m_hparams.learning_rate * std::sqrt(1.0f - std::pow(m_hparams.beta2, (float)m_state.iter)) / (1.0f - std::pow(m_hparams.beta1, (float)m_state.iter)); + m_state.first_moment = m_hparams.beta1 * m_state.first_moment + (1.0f - m_hparams.beta1) * gradient; + m_state.second_moment = m_hparams.beta2 * m_state.second_moment + (1.0f - m_hparams.beta2) * gradient * gradient; + m_state.variable -= actual_learning_rate * m_state.first_moment / (sqrt(m_state.second_moment) + T(m_hparams.epsilon)); + } + + uint32_t step() const { + return m_state.iter; + } + + void set_learning_rate(float lr) { + m_hparams.learning_rate = lr; + } + + T& variable() { + return m_state.variable; + } + + const T& variable() const { + return m_state.variable; + } + + void reset_state() { + m_state = {}; + } + + void to_json(nlohmann::json& j) const { + j["iter"] = m_state.iter; + j["first_moment"] = m_state.first_moment; + j["second_moment"] = m_state.second_moment; + j["variable"] = m_state.variable; + j["learning_rate"] = m_hparams.learning_rate; + j["epsilon"] = m_hparams.epsilon; + j["beta1"] = m_hparams.beta1; + j["beta2"] = m_hparams.beta2; + } + + void from_json(const nlohmann::json& j) { + m_state.iter = j.at("iter"); + m_state.first_moment = j.at("first_moment"); + m_state.second_moment = j.at("second_moment"); + m_state.variable = j.at("variable"); + m_hparams.learning_rate = j.at("learning_rate"); + m_hparams.epsilon = j.at("epsilon"); + m_hparams.beta1 = j.at("beta1"); + m_hparams.beta2 = j.at("beta2"); + } + +private: + struct State { + uint32_t iter = 0; + T first_moment = T(0.0f); + T second_moment = T(0.0f); + T variable = T(0.0f); + } m_state = {}; + + struct Hyperparameters { + float learning_rate; + float epsilon; + float beta1; + float beta2; + } m_hparams = {}; +}; + +template +inline void to_json(nlohmann::json& j, const AdamOptimizer& opt) { + opt.to_json(j); +} + +template +inline void from_json(const nlohmann::json& j, AdamOptimizer& opt) { + opt.from_json(j); +} + +class RotationAdamOptimizer { +public: + RotationAdamOptimizer(float learning_rate = 1e-3, float epsilon = 1e-08f, float beta1 = 0.9f, float beta2 = 0.99f) { + m_hparams = { learning_rate, epsilon, beta1, beta2 }; + } + + RotationAdamOptimizer& operator=(const RotationAdamOptimizer& arg) { + m_state = arg.m_state; + m_hparams = arg.m_hparams; + return *this; + } + + RotationAdamOptimizer(const RotationAdamOptimizer& arg) 
{ + *this = arg; + } + + void step(const vec3& gradient) { + ++m_state.iter; + + float actual_learning_rate = m_hparams.learning_rate * std::sqrt(1 - std::pow(m_hparams.beta2, m_state.iter)) / (1 - std::pow(m_hparams.beta1, m_state.iter)); + m_state.first_moment = m_hparams.beta1 * m_state.first_moment + (1 - m_hparams.beta1) * gradient; + m_state.second_moment = m_hparams.beta2 * m_state.second_moment + (1 - m_hparams.beta2) * gradient * gradient; + vec3 rot = actual_learning_rate * m_state.first_moment / (sqrt(m_state.second_moment) + m_hparams.epsilon); + + m_state.variable = rotvec(rotmat(-rot) * rotmat(variable())); + } + + uint32_t step() const { + return m_state.iter; + } + + void set_learning_rate(float lr) { + m_hparams.learning_rate = lr; + } + + const vec3& variable() const { + return m_state.variable; + } + + void reset_state() { + m_state = {}; + } + + void to_json(nlohmann::json& j) const { + j["iter"] = m_state.iter; + j["first_moment"] = m_state.first_moment; + j["second_moment"] = m_state.second_moment; + j["variable"] = m_state.variable; + j["learning_rate"] = m_hparams.learning_rate; + j["epsilon"] = m_hparams.epsilon; + j["beta1"] = m_hparams.beta1; + j["beta2"] = m_hparams.beta2; + } + + void from_json(const nlohmann::json& j) { + m_state.iter = j.at("iter"); + m_state.first_moment = j.at("first_moment"); + m_state.second_moment = j.at("second_moment"); + m_state.variable = j.at("variable"); + m_hparams.learning_rate = j.at("learning_rate"); + m_hparams.epsilon = j.at("epsilon"); + m_hparams.beta1 = j.at("beta1"); + m_hparams.beta2 = j.at("beta2"); + } + +private: + struct State { + uint32_t iter = 0; + vec3 first_moment = vec3(0.0f); + vec3 second_moment = vec3(0.0f); + vec3 variable = vec3(0.0f); + } m_state; + + struct Hyperparameters { + float learning_rate; + float epsilon; + float beta1; + float beta2; + } m_hparams; +}; + +inline void to_json(nlohmann::json& j, const RotationAdamOptimizer& opt) { + opt.to_json(j); +} + +inline void from_json(const nlohmann::json& j, RotationAdamOptimizer& opt) { + opt.from_json(j); +} + +} diff --git a/gui/include/neural-graphics-primitives/bounding_box.cuh b/gui/include/neural-graphics-primitives/bounding_box.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b919cbc244f5bfac100f015adf61d788094cc791 --- /dev/null +++ b/gui/include/neural-graphics-primitives/bounding_box.cuh @@ -0,0 +1,259 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file bounding_box.cuh + * @author Thomas Müller & Alex Evans, NVIDIA + * @brief CUDA/C++ AABB implementation. 
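 *
 * [Editor's note - added description, not part of the original patch.]
 * The BoundingBox struct below provides, among other helpers, a triangle/box
 * overlap test based on the separating axis theorem and a slab-method
 * ray/AABB intersection that returns the entry and exit distances along the ray.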
+ */ + +#pragma once + +#include +#include +#include + +namespace ngp { + +template +NGP_HOST_DEVICE inline void project(vec3 points[N_POINTS], const vec3& axis, float& min, float& max) { + min = std::numeric_limits::infinity(); + max = -std::numeric_limits::infinity(); + + NGP_PRAGMA_UNROLL + for (uint32_t i = 0; i < N_POINTS; ++i) { + float val = dot(axis, points[i]); + + if (val < min) { + min = val; + } + + if (val > max) { + max = val; + } + } +} + +struct BoundingBox { + NGP_HOST_DEVICE BoundingBox() {} + + NGP_HOST_DEVICE BoundingBox(const vec3& a, const vec3& b) : min{a}, max{b} {} + + NGP_HOST_DEVICE explicit BoundingBox(const Triangle& tri) { + min = max = tri.a; + enlarge(tri.b); + enlarge(tri.c); + } + + NGP_HOST_DEVICE BoundingBox(Triangle* begin, Triangle* end) { + min = max = begin->a; + for (auto it = begin; it != end; ++it) { + enlarge(*it); + } + } + + NGP_HOST_DEVICE void enlarge(const BoundingBox& other) { + min = tcnn::min(min, other.min); + max = tcnn::max(max, other.max); + } + + NGP_HOST_DEVICE void enlarge(const Triangle& tri) { + enlarge(tri.a); + enlarge(tri.b); + enlarge(tri.c); + } + + NGP_HOST_DEVICE void enlarge(const vec3& point) { + min = tcnn::min(min, point); + max = tcnn::max(max, point); + } + + NGP_HOST_DEVICE void inflate(float amount) { + min -= vec3(amount); + max += vec3(amount); + } + + NGP_HOST_DEVICE vec3 diag() const { + return max - min; + } + + NGP_HOST_DEVICE vec3 relative_pos(const vec3& pos) const { + return (pos - min) / diag(); + } + + NGP_HOST_DEVICE vec3 center() const { + return 0.5f * (max + min); + } + + NGP_HOST_DEVICE BoundingBox intersection(const BoundingBox& other) const { + BoundingBox result = *this; + result.min = tcnn::max(result.min, other.min); + result.max = tcnn::min(result.max, other.max); + return result; + } + + NGP_HOST_DEVICE bool intersects(const BoundingBox& other) const { + return !intersection(other).is_empty(); + } + + // Based on the separating axis theorem + // (https://fileadmin.cs.lth.se/cs/Personal/Tomas_Akenine-Moller/code/tribox_tam.pdf) + // Code adapted from a C# implementation at stack overflow + // https://stackoverflow.com/a/17503268 + NGP_HOST_DEVICE bool intersects(const Triangle& triangle) const { + float triangle_min, triangle_max; + float box_min, box_max; + + // Test the box normals (x-, y- and z-axes) + vec3 box_normals[3] = { + vec3{1.0f, 0.0f, 0.0f}, + vec3{0.0f, 1.0f, 0.0f}, + vec3{0.0f, 0.0f, 1.0f}, + }; + + vec3 triangle_normal = triangle.normal(); + vec3 triangle_verts[3]; + triangle.get_vertices(triangle_verts); + + for (int i = 0; i < 3; i++) { + project<3>(triangle_verts, box_normals[i], triangle_min, triangle_max); + if (triangle_max < min[i] || triangle_min > max[i]) { + return false; // No intersection possible. + } + } + + vec3 verts[8]; + get_vertices(verts); + + // Test the triangle normal + float triangle_offset = dot(triangle_normal, triangle.a); + project<8>(verts, triangle_normal, box_min, box_max); + if (box_max < triangle_offset || box_min > triangle_offset) { + return false; // No intersection possible. 
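// [Editor's note - explanatory comment, not part of the original patch.]
// Taken together, this separating-axis test checks 13 candidate axes: the 3
// box face normals (above), the triangle's face normal (above), and the 9
// cross products of a box axis with a triangle edge (below). A single axis on
// which the projected intervals do not overlap is enough to reject; if every
// axis overlaps, no separating axis exists and the triangle intersects the box.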
+ } + + // Test the nine edge cross-products + vec3 edges[3] = { + triangle.a - triangle.b, + triangle.a - triangle.c, + triangle.b - triangle.c, + }; + + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + // The box normals are the same as it's edge tangents + vec3 axis = cross(edges[i], box_normals[j]); + project<8>(verts, axis, box_min, box_max); + project<3>(triangle_verts, axis, triangle_min, triangle_max); + if (box_max < triangle_min || box_min > triangle_max) + return false; // No intersection possible + } + } + + // No separating axis found. + return true; + } + + NGP_HOST_DEVICE vec2 ray_intersect(const vec3& pos, const vec3& dir) const { + float tmin = (min.x - pos.x) / dir.x; + float tmax = (max.x - pos.x) / dir.x; + + if (tmin > tmax) { + host_device_swap(tmin, tmax); + } + + float tymin = (min.y - pos.y) / dir.y; + float tymax = (max.y - pos.y) / dir.y; + + if (tymin > tymax) { + host_device_swap(tymin, tymax); + } + + if (tmin > tymax || tymin > tmax) { + return { std::numeric_limits::max(), std::numeric_limits::max() }; + } + + if (tymin > tmin) { + tmin = tymin; + } + + if (tymax < tmax) { + tmax = tymax; + } + + float tzmin = (min.z - pos.z) / dir.z; + float tzmax = (max.z - pos.z) / dir.z; + + if (tzmin > tzmax) { + host_device_swap(tzmin, tzmax); + } + + if (tmin > tzmax || tzmin > tmax) { + return { std::numeric_limits::max(), std::numeric_limits::max() }; + } + + if (tzmin > tmin) { + tmin = tzmin; + } + + if (tzmax < tmax) { + tmax = tzmax; + } + + return { tmin, tmax }; + } + + NGP_HOST_DEVICE bool is_empty() const { + return max.x < min.x || max.y < min.y || max.z < min.z; + } + + NGP_HOST_DEVICE bool contains(const vec3& p) const { + return + p.x >= min.x && p.x <= max.x && + p.y >= min.y && p.y <= max.y && + p.z >= min.z && p.z <= max.z; + } + + /// Calculate the squared point-AABB distance + NGP_HOST_DEVICE float distance(const vec3& p) const { + return sqrt(distance_sq(p)); + } + + NGP_HOST_DEVICE float distance_sq(const vec3& p) const { + return length2(tcnn::max(tcnn::max(min - p, p - max), vec3(0.0f))); + } + + NGP_HOST_DEVICE float signed_distance(const vec3& p) const { + vec3 q = abs(p - min) - diag(); + return length(tcnn::max(q, vec3(0.0f))) + std::min(tcnn::max(q), 0.0f); + } + + NGP_HOST_DEVICE void get_vertices(vec3 v[8]) const { + v[0] = {min.x, min.y, min.z}; + v[1] = {min.x, min.y, max.z}; + v[2] = {min.x, max.y, min.z}; + v[3] = {min.x, max.y, max.z}; + v[4] = {max.x, min.y, min.z}; + v[5] = {max.x, min.y, max.z}; + v[6] = {max.x, max.y, min.z}; + v[7] = {max.x, max.y, max.z}; + } + + vec3 min = vec3(std::numeric_limits::infinity()); + vec3 max = vec3(-std::numeric_limits::infinity()); +}; + +} diff --git a/gui/include/neural-graphics-primitives/camera_path.h b/gui/include/neural-graphics-primitives/camera_path.h new file mode 100644 index 0000000000000000000000000000000000000000..66c86e5e358ad5c3aae78dad40814dc7af67c982 --- /dev/null +++ b/gui/include/neural-graphics-primitives/camera_path.h @@ -0,0 +1,222 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file camera_path.h + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#pragma once + +#include + +#include + +#ifdef NGP_GUI +# include +# include +#endif + +#include +#include + +struct ImDrawList; + +namespace ngp { + +struct CameraKeyframe { + quat R; + vec3 T; + float fov; + + float timestamp = 0; + + mat4x3 m() const { + auto rot = to_mat3(normalize(R)); + return mat4x3(rot[0], rot[1], rot[2], T); + } + + void from_m(const mat4x3& rv) { + T = rv[3]; + R = quat(mat3(rv)); + } + + CameraKeyframe() = default; + CameraKeyframe(const quat& r, const vec3& t, float fv, float time) : R(r), T(t), fov(fv), timestamp{time} {} + CameraKeyframe(mat4x3 m, float fv, float time) : fov(fv), timestamp(time) { from_m(m); } + CameraKeyframe operator*(float f) const { return {R * f, T * f, fov * f, timestamp}; } + CameraKeyframe operator+(const CameraKeyframe& rhs) const { + quat Rr = rhs.R; + if (dot(Rr, R) < 0.0f) { + Rr = -Rr; + } + return {R + Rr, T + rhs.T, fov + rhs.fov, rhs.timestamp}; + } + + bool same_pos_as(const CameraKeyframe& rhs) const { return distance(T, rhs.T) < 0.0001f && fabsf(dot(R, rhs.R)) >= 0.999f; } +}; + +CameraKeyframe lerp(const CameraKeyframe& p0, const CameraKeyframe& p1, float t, float t0, float t1); +CameraKeyframe spline_cm(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2, const CameraKeyframe& p3); +CameraKeyframe spline_cubic(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2, const CameraKeyframe& p3); +CameraKeyframe spline_quadratic(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2); +CameraKeyframe spline_linear(float t, const CameraKeyframe& p0, const CameraKeyframe& p1); + +enum class EEditingKernel : int { + None, + Gaussian, + Quartic, + Hat, + Box, +}; +static constexpr const char* EditingKernelStr = "None\0Gaussian\0Quartic\0Hat\0Box\0\0"; + +struct CameraPath { + std::vector keyframes; + bool update_cam_from_path = false; + float play_time = 0.f; + float auto_play_speed = 0.f; + float default_duration_seconds = 3.0f; + // If loop is set true, the last frame set will be more like "next to last," + // with animation then returning back to the first frame, making a continuous loop. + // Note that the user does not have to (and should not normally) duplicate the first frame to be the last frame. + bool loop = false; + + int keyframe_subsampling = 1; + + + EEditingKernel editing_kernel_type = EEditingKernel::None; + float editing_kernel_radius = 0.1f; + + // Cubic spline per default. Order 1 (p/w linear) is also supported. 
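// [Editor's note - explanatory comment, not part of the original patch.]
// spline_order selects how eval_camera_path() blends keyframes:
//   0 - snap to the nearest keyframe,
//   1 - piecewise-linear interpolation between the two surrounding keyframes,
//   2 - quadratic spline over three keyframes,
//   3 - cubic spline over four keyframes (the default below).
// Any other value makes eval_camera_path() throw.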
+ int spline_order = 3; + + struct RenderSettings { + ivec2 resolution = {1920, 1080}; + int spp = 8; + float fps = 60.0f; + float shutter_fraction = 0.5f; + int quality = 8; + + uint32_t n_frames(const float duration) const { return (uint32_t)((double)duration * fps); } + + float frame_seconds(const float duration) const { return 1.0f / (duration * fps); } + + float frame_milliseconds(const float duration) const { return 1000.0f / (duration * fps); } + + std::string filename = "video.mp4"; + }; + + RenderSettings render_settings; + bool rendering = false; + uint32_t render_frame_idx = 0; + std::chrono::time_point render_start_time; + + mat4x3 render_frame_end_camera; + + struct Pos { + int kfidx; + float t; + }; + + void clear() { + keyframes.clear(); + play_time = 0.0f; + } + + bool empty() const { return keyframes.empty(); } + + bool has_valid_timestamps() const; + + void make_keyframe_timestamps_equidistant(const float duration_seconds); + + float duration_seconds() const; + + void set_duration_seconds(const float duration); + + void sanitize_keyframes(); + + Pos get_pos(float playtime); + + float get_playtime(int i) { + if (i <= 0 || keyframes.size() < 2) { + return 0.0f; + } + + const auto& kf = keyframes[clamp(i - 1, 0, (int)keyframes.size() - 1)]; + const float duration = loop ? keyframes.back().timestamp : keyframes[keyframes.size() - 2].timestamp; + return kf.timestamp / duration; + } + + const CameraKeyframe& get_keyframe(int i) const { + if (loop) { + int size = (int)keyframes.size(); + // add size to ensure no negative value is generated by modulo + return keyframes[(i + size) % size]; + } else { + return keyframes[clamp(i, 0, (int)keyframes.size() - 1)]; + } + } + + CameraKeyframe eval_camera_path(float t) { + if (keyframes.empty()) { + return {}; + } + + auto p = get_pos(t); + switch (spline_order) { + case 0: return get_keyframe(p.kfidx + (int)round(p.t)); + case 1: return spline_linear(p.t, get_keyframe(p.kfidx), get_keyframe(p.kfidx + 1)); + case 2: return spline_quadratic(p.t, get_keyframe(p.kfidx - 1), get_keyframe(p.kfidx), get_keyframe(p.kfidx + 1)); + case 3: + return spline_cubic( + p.t, get_keyframe(p.kfidx - 1), get_keyframe(p.kfidx), get_keyframe(p.kfidx + 1), get_keyframe(p.kfidx + 2) + ); + default: throw std::runtime_error{fmt::format("Spline of order {} is not supported.", spline_order)}; + } + } + + void save(const fs::path& path); + void load(const fs::path& path, const mat4x3& first_xform); + + void add_camera(const mat4x3& camera, float fov, float timestamp); + +#ifdef NGP_GUI + ImGuizmo::MODE m_gizmo_mode = ImGuizmo::LOCAL; + ImGuizmo::OPERATION m_gizmo_op = ImGuizmo::TRANSLATE; + int imgui(char path_filename_buf[1024], float frame_milliseconds, const mat4x3& camera, float fov, const mat4x3& first_xform); + bool imgui_viz( + ImDrawList* list, + mat4& view2proj, + mat4& world2proj, + mat4& world2view, + vec2 focal, + float aspect, + float znear, + float zfar + ); +#endif +}; + +#ifdef NGP_GUI +void add_debug_line(ImDrawList* list, const mat4& proj, vec3 a, vec3 b, uint32_t col = 0xffffffff, float thickness = 1.0f); +void visualize_cube(ImDrawList* list, const mat4& world2proj, const vec3& a, const vec3& b, const mat3& render_aabb_to_local); +void visualize_camera( + ImDrawList* list, const mat4& world2proj, const mat4x3& xform, float aspect, uint32_t col = 0x80ffffff, float thickness = 1.0f +); +#endif + +} // namespace ngp diff --git a/gui/include/neural-graphics-primitives/common.h b/gui/include/neural-graphics-primitives/common.h new file mode 100644 
index 0000000000000000000000000000000000000000..f6ef742fe910c499c9d51932d14b62accf27c3f6 --- /dev/null +++ b/gui/include/neural-graphics-primitives/common.h @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common.h + * @author Thomas Müller, NVIDIA + * @brief Shared functionality among multiple neural-graphics-primitives components. + */ + +#pragma once + +#ifdef _WIN32 +# define NOMINMAX +#endif + +#include +using namespace tcnn; + +#if defined(__CUDA_ARCH__) +# define NGP_PRAGMA_UNROLL _Pragma("unroll") +# define NGP_PRAGMA_NO_UNROLL _Pragma("unroll 1") +#else +# define NGP_PRAGMA_UNROLL +# define NGP_PRAGMA_NO_UNROLL +#endif + +#if defined(__CUDACC__) || (defined(__clang__) && defined(__CUDA__)) +# define NGP_HOST_DEVICE __host__ __device__ +#else +# define NGP_HOST_DEVICE +#endif + +namespace ngp { + +enum class EMeshRenderMode : int { + Off, + VertexColors, + VertexNormals, + FaceIDs, +}; + +enum class EGroundTruthRenderMode : int { + Shade, + Depth, + NumRenderModes, +}; +static constexpr const char* GroundTruthRenderModeStr = "Shade\0Depth\0\0"; + +enum class ERenderMode : int { + AO, + Shade, + Normals, + Positions, + Depth, + Distortion, + Cost, + Slice, + NumRenderModes, + EncodingVis, // EncodingVis exists outside of the standard render modes +}; +static constexpr const char* RenderModeStr = "AO\0Shade\0Normals\0Positions\0Depth\0Distortion\0Cost\0Slice\0\0"; + +enum class EPmVizMode : int { + Shade, + Depth, + Offset, + Holes, +}; +static constexpr const char* PmVizModeStr = "Shade\0Depth\0Offset\0Holes\0\0"; + +enum class ERandomMode : int { + Random, + Halton, + Sobol, + Stratified, + NumImageRandomModes, +}; +static constexpr const char* RandomModeStr = "Random\0Halton\0Sobol\0Stratified\0\0"; + +enum class ELossType : int { + L2, + L1, + Mape, + Smape, + Huber, + LogL1, + RelativeL2, +}; +static constexpr const char* LossTypeStr = "L2\0L1\0MAPE\0SMAPE\0Huber\0LogL1\0RelativeL2\0\0"; + +enum class EMeshSdfMode : int { + Watertight, + Raystab, + PathEscape, +}; +static constexpr const char* MeshSdfModeStr = "Watertight\0Raystab\0PathEscape\0\0"; + +enum class EColorSpace : int { + Linear, + SRGB, + VisPosNeg, +}; +static constexpr const char* ColorSpaceStr = "Linear\0SRGB\0\0"; + +enum class ETonemapCurve : int { Identity, ACES, Hable, Reinhard }; +static constexpr const char* TonemapCurveStr = "Identity\0ACES\0Hable\0Reinhard\0\0"; + +enum class EDlssQuality : int { + UltraPerformance, + MaxPerformance, + Balanced, + MaxQuality, + UltraQuality, + NumDlssQualitySettings, + None, +}; +static constexpr const char* DlssQualityStr = "UltraPerformance\0MaxPerformance\0Balanced\0MaxQuality\0UltraQuality\0Invalid\0None\0\0"; +static constexpr const char* DlssQualityStrArray[] = { + "UltraPerformance", "MaxPerformance", "Balanced", "MaxQuality", "UltraQuality", "Invalid", "None" +}; + +enum class 
ETestbedMode : int { + Gen3c, + None, +}; + +enum class ESDFGroundTruthMode : int { + RaytracedMesh, + SpheretracedMesh, + SDFBricks, +}; + +enum EPmPixelState : uint8_t { + Hole = 0, + FillableHole, + FilledHole, + Reprojected, +}; + + +struct TrainingXForm { + NGP_HOST_DEVICE bool operator==(const TrainingXForm& other) const { return start == other.start && end == other.end; } + + mat4x3 start; + mat4x3 end; +}; + +enum class ELensMode : int { + Perspective, + OpenCV, + FTheta, + LatLong, + OpenCVFisheye, + Equirectangular, + Orthographic, +}; +static constexpr const char* LensModeStr = "Perspective\0OpenCV\0F-Theta\0LatLong\0OpenCV Fisheye\0Equirectangular\0Orthographic\0\0"; + +struct Lens { + ELensMode mode = ELensMode::Perspective; + float params[7] = {}; + + NGP_HOST_DEVICE bool is_360() const { return mode == ELensMode::Equirectangular || mode == ELensMode::LatLong; } + + NGP_HOST_DEVICE bool supports_dlss() { + return mode == ELensMode::LatLong || mode == ELensMode::Equirectangular || mode == ELensMode::Perspective || + mode == ELensMode::Orthographic || mode == ELensMode::OpenCV || mode == ELensMode::OpenCVFisheye; + } +}; + +enum class EGen3cCameraSource : int { + // Fake camera trajectory based on fixed translation and rotation speeds. + Fake = 0, + // Camera trajectory from the current viewpoint + predicted movement, + // including when using a VR headset. + Viewpoint, + // Camera trajectory from a path authored with the camera tools. + Authored +}; +static constexpr const char* Gen3cCameraSourceStr = "Fake\0Viewpoint\0Authored\0\0"; + +inline NGP_HOST_DEVICE uint32_t binary_search(float val, const float* data, uint32_t length) { + if (length == 0) { + return 0; + } + + uint32_t it; + uint32_t count, step; + count = length; + + uint32_t first = 0; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (data[it] < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + + return min(first, length - 1); +} + +template struct Buffer2DView { + T* data = nullptr; + ivec2 resolution = 0; + + // Lookup via integer pixel position (no bounds checking) + NGP_HOST_DEVICE T at(const ivec2& px) const { return data[px.x + px.y * resolution.x]; } + + // Lookup via UV coordinates in [0,1]^2 + NGP_HOST_DEVICE T at(const vec2& uv) const { + ivec2 px = clamp(ivec2(vec2(resolution) * uv), 0, resolution - 1); + return at(px); + } + + // Lookup via UV coordinates in [0,1]^2 and LERP the nearest texels + NGP_HOST_DEVICE T at_lerp(const vec2& uv) const { + const vec2 px_float = vec2(resolution) * uv; + const ivec2 px = ivec2(px_float); + + const vec2 weight = px_float - vec2(px); + + auto read_val = [&](ivec2 pos) { return at(clamp(pos, 0, resolution - 1)); }; + + return ( + (1 - weight.x) * (1 - weight.y) * read_val({px.x, px.y}) + (weight.x) * (1 - weight.y) * read_val({px.x + 1, px.y}) + + (1 - weight.x) * (weight.y) * read_val({px.x, px.y + 1}) + (weight.x) * (weight.y) * read_val({px.x + 1, px.y + 1}) + ); + } + + NGP_HOST_DEVICE operator bool() const { return data; } +}; + +} // namespace ngp diff --git a/gui/include/neural-graphics-primitives/common_device.cuh b/gui/include/neural-graphics-primitives/common_device.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e1e0b05e4d163ed9102764b15692d7745f9c87be --- /dev/null +++ b/gui/include/neural-graphics-primitives/common_device.cuh @@ -0,0 +1,866 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common.h + * @author Thomas Müller, NVIDIA + * @brief Shared functionality among multiple neural-graphics-primitives components. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace ngp { + +// The maximum depth that can be produced when rendering a frame. +// Chosen somewhat low (rather than std::numeric_limits::infinity()) +// to permit numerically stable reprojection and DLSS operation, +// even when rendering the infinitely distant horizon. +inline constexpr __device__ float MAX_DEPTH() { return 16384.0f; } + +inline NGP_HOST_DEVICE float srgb_to_linear(float srgb) { + if (srgb <= 0.04045f) { + return srgb / 12.92f; + } else { + return pow((srgb + 0.055f) / 1.055f, 2.4f); + } +} + +inline NGP_HOST_DEVICE vec3 srgb_to_linear(const vec3& x) { return {srgb_to_linear(x.x), srgb_to_linear(x.y), (srgb_to_linear(x.z))}; } + +inline NGP_HOST_DEVICE float srgb_to_linear_derivative(float srgb) { + if (srgb <= 0.04045f) { + return 1.0f / 12.92f; + } else { + return 2.4f / 1.055f * pow((srgb + 0.055f) / 1.055f, 1.4f); + } +} + +inline NGP_HOST_DEVICE vec3 srgb_to_linear_derivative(const vec3& x) { + return {srgb_to_linear_derivative(x.x), srgb_to_linear_derivative(x.y), (srgb_to_linear_derivative(x.z))}; +} + +inline NGP_HOST_DEVICE float linear_to_srgb(float linear) { + if (linear < 0.0031308f) { + return 12.92f * linear; + } else { + return 1.055f * pow(linear, 0.41666f) - 0.055f; + } +} + +inline NGP_HOST_DEVICE vec3 linear_to_srgb(const vec3& x) { return {linear_to_srgb(x.x), linear_to_srgb(x.y), (linear_to_srgb(x.z))}; } + +inline NGP_HOST_DEVICE float linear_to_srgb_derivative(float linear) { + if (linear < 0.0031308f) { + return 12.92f; + } else { + return 1.055f * 0.41666f * pow(linear, 0.41666f - 1.0f); + } +} + +inline NGP_HOST_DEVICE vec3 linear_to_srgb_derivative(const vec3& x) { + return {linear_to_srgb_derivative(x.x), linear_to_srgb_derivative(x.y), (linear_to_srgb_derivative(x.z))}; +} + +template +__device__ void deposit_image_gradient( + const vec2& value, T* __restrict__ gradient, T* __restrict__ gradient_weight, const ivec2& resolution, const vec2& pos +) { + const vec2 pos_float = vec2(resolution) * pos; + const ivec2 texel = {pos_float}; + + const vec2 weight = pos_float - vec2(texel); + + constexpr uint32_t N_DIMS = 2; + + auto deposit_val = [&](const vec2& value, T weight, ivec2 pos) { + pos.x = max(min(pos.x, resolution.x - 1), 0); + pos.y = max(min(pos.y, resolution.y - 1), 0); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 // atomicAdd(__half2) is only supported with compute capability 60 and above + if (std::is_same::value) { + for (uint32_t c = 0; c < N_DIMS; c += 2) { + atomicAdd((__half2*)&gradient[(pos.x + pos.y * resolution.x) * N_DIMS + c], {(T)value[c] * weight, (T)value[c + 1] * weight}); + atomicAdd((__half2*)&gradient_weight[(pos.x + pos.y * resolution.x) * N_DIMS + c], {weight, weight}); + } + } else +#endif + { + 
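// Fallback: one scalar atomicAdd per gradient channel, used when T is not __half or when
// packed __half2 atomics are unavailable (compute capability below 60).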
for (uint32_t c = 0; c < N_DIMS; ++c) { + atomicAdd(&gradient[(pos.x + pos.y * resolution.x) * N_DIMS + c], (T)value[c] * weight); + atomicAdd(&gradient_weight[(pos.x + pos.y * resolution.x) * N_DIMS + c], weight); + } + } + }; + + deposit_val(value, (1 - weight.x) * (1 - weight.y), {texel.x, texel.y}); + deposit_val(value, (weight.x) * (1 - weight.y), {texel.x + 1, texel.y}); + deposit_val(value, (1 - weight.x) * (weight.y), {texel.x, texel.y + 1}); + deposit_val(value, (weight.x) * (weight.y), {texel.x + 1, texel.y + 1}); +} + +struct FoveationPiecewiseQuadratic { + FoveationPiecewiseQuadratic() = default; + NGP_HOST_DEVICE FoveationPiecewiseQuadratic(float center_pixel_steepness, float center_inverse_piecewise_y, float center_radius) { + float center_inverse_radius = center_radius * center_pixel_steepness; + float left_inverse_piecewise_switch = center_inverse_piecewise_y - center_inverse_radius; + float right_inverse_piecewise_switch = center_inverse_piecewise_y + center_inverse_radius; + + if (left_inverse_piecewise_switch < 0) { + left_inverse_piecewise_switch = 0.0f; + } + + if (right_inverse_piecewise_switch > 1) { + right_inverse_piecewise_switch = 1.0f; + } + + float am = center_pixel_steepness; + float d = (right_inverse_piecewise_switch - left_inverse_piecewise_switch) / center_pixel_steepness / 2; + + // binary search for l,r,bm since analytical is very complex + float bm; + float m_min = 0.0f; + float m_max = 1.0f; + for (uint32_t i = 0; i < 20; i++) { + float m = (m_min + m_max) / 2.0f; + float l = m - d; + float r = m + d; + + bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1); + + float l_actual = (left_inverse_piecewise_switch - bm) / am; + float r_actual = (right_inverse_piecewise_switch - bm) / am; + float m_actual = (l_actual + r_actual) / 2; + + if (m_actual > m) { + m_min = m; + } else { + m_max = m; + } + } + + float l = (left_inverse_piecewise_switch - bm) / am; + float r = (right_inverse_piecewise_switch - bm) / am; + + // Full linear case. Default construction covers this. + if ((l == 0.0f && r == 1.0f) || (am == 1.0f)) { + return; + } + + // write out solution + switch_left = l; + switch_right = r; + this->am = am; + al = (am - 1) / (r * r - 2 * r + l * l + 1); + bl = (am * (r * r - 2 * r + 1) + am * l * l + (2 - 2 * am) * l) / (r * r - 2 * r + l * l + 1); + cl = 0; + this->bm = bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1); + ar = -(am - 1) / (r * r - 2 * r + l * l + 1); + br = (am * (r * r + 1) - 2 * r + am * l * l) / (r * r - 2 * r + l * l + 1); + cr = -(am * r * r - r * r + (am - 1) * l * l) / (r * r - 2 * r + l * l + 1); + + inv_switch_left = am * switch_left + bm; + inv_switch_right = am * switch_right + bm; + } + + // left parabola: al * x^2 + bl * x + cl + float al = 0.0f, bl = 0.0f, cl = 0.0f; + // middle linear piece: am * x + bm. am should give 1:1 pixel mapping between warped size and full size. 
+ float am = 1.0f, bm = 0.0f; + // right parabola: al * x^2 + bl * x + cl + float ar = 0.0f, br = 0.0f, cr = 0.0f; + + // points where left and right switch over from quadratic to linear + float switch_left = 0.0f, switch_right = 1.0f; + // same, in inverted space + float inv_switch_left = 0.0f, inv_switch_right = 1.0f; + + NGP_HOST_DEVICE float warp(float x) const { + x = clamp(x, 0.0f, 1.0f); + if (x < switch_left) { + return al * x * x + bl * x + cl; + } else if (x > switch_right) { + return ar * x * x + br * x + cr; + } else { + return am * x + bm; + } + } + + NGP_HOST_DEVICE float unwarp(float y) const { + y = clamp(y, 0.0f, 1.0f); + if (y < inv_switch_left) { + return (sqrt(-4 * al * cl + 4 * al * y + bl * bl) - bl) / (2 * al); + } else if (y > inv_switch_right) { + return (sqrt(-4 * ar * cr + 4 * ar * y + br * br) - br) / (2 * ar); + } else { + return (y - bm) / am; + } + } + + NGP_HOST_DEVICE float density(float x) const { + x = clamp(x, 0.0f, 1.0f); + if (x < switch_left) { + return 2 * al * x + bl; + } else if (x > switch_right) { + return 2 * ar * x + br; + } else { + return am; + } + } +}; + +struct Foveation { + Foveation() = default; + + NGP_HOST_DEVICE Foveation(const vec2& center_pixel_steepness, const vec2& center_inverse_piecewise_y, const vec2& center_radius) : + warp_x{center_pixel_steepness.x, center_inverse_piecewise_y.x, center_radius.x}, + warp_y{center_pixel_steepness.y, center_inverse_piecewise_y.y, center_radius.y} {} + + FoveationPiecewiseQuadratic warp_x, warp_y; + + NGP_HOST_DEVICE vec2 warp(const vec2& x) const { return {warp_x.warp(x.x), warp_y.warp(x.y)}; } + + NGP_HOST_DEVICE vec2 unwarp(const vec2& y) const { return {warp_x.unwarp(y.x), warp_y.unwarp(y.y)}; } + + NGP_HOST_DEVICE float density(const vec2& x) const { return warp_x.density(x.x) * warp_y.density(x.y); } +}; + +template NGP_HOST_DEVICE inline void opencv_lens_distortion_delta(const T* extra_params, const T u, const T v, T* du, T* dv) { + const T k1 = extra_params[0]; + const T k2 = extra_params[1]; + const T p1 = extra_params[2]; + const T p2 = extra_params[3]; + + const T u2 = u * u; + const T uv = u * v; + const T v2 = v * v; + const T r2 = u2 + v2; + const T radial = k1 * r2 + k2 * r2 * r2; + *du = u * radial + T(2) * p1 * uv + p2 * (r2 + T(2) * u2); + *dv = v * radial + T(2) * p2 * uv + p1 * (r2 + T(2) * v2); +} + +template +NGP_HOST_DEVICE inline void opencv_fisheye_lens_distortion_delta(const T* extra_params, const T u, const T v, T* du, T* dv) { + const T k1 = extra_params[0]; + const T k2 = extra_params[1]; + const T k3 = extra_params[2]; + const T k4 = extra_params[3]; + + const T r = sqrt(u * u + v * v); + + if (r > (T)std::numeric_limits::epsilon()) { + const T theta = atan(r); + const T theta2 = theta * theta; + const T theta4 = theta2 * theta2; + const T theta6 = theta4 * theta2; + const T theta8 = theta4 * theta4; + const T thetad = theta * (T(1) + k1 * theta2 + k2 * theta4 + k3 * theta6 + k4 * theta8); + *du = u * thetad / r - u; + *dv = v * thetad / r - v; + } else { + *du = T(0); + *dv = T(0); + } +} + +template NGP_HOST_DEVICE inline void iterative_lens_undistortion(const T* params, T* u, T* v, F distortion_fun) { + // Parameters for Newton iteration using numerical differentiation with + // central differences, 100 iterations should be enough even for complex + // camera models with higher order terms. 
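// In other words, the loop below solves x + delta(x) = x0 for the undistorted point x, where x0
// is the observed (distorted) point and delta is the model passed in as distortion_fun:
//
//     J = I + d(delta)/dx                      (Jacobian estimated with central differences)
//     x <- x - J^{-1} * (x + delta(x) - x0)    (Newton update)
//
// iterating until the squared step length drops below kMaxStepNorm or kNumIterations is reached.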
+ const uint32_t kNumIterations = 100; + const float kMaxStepNorm = 1e-10f; + const float kRelStepSize = 1e-6f; + + mat2 J; + const vec2 x0{*u, *v}; + vec2 x{*u, *v}; + vec2 dx; + vec2 dx_0b; + vec2 dx_0f; + vec2 dx_1b; + vec2 dx_1f; + + for (uint32_t i = 0; i < kNumIterations; ++i) { + const float step0 = max(std::numeric_limits::epsilon(), abs(kRelStepSize * x[0])); + const float step1 = max(std::numeric_limits::epsilon(), abs(kRelStepSize * x[1])); + distortion_fun(params, x[0], x[1], &dx[0], &dx[1]); + distortion_fun(params, x[0] - step0, x[1], &dx_0b[0], &dx_0b[1]); + distortion_fun(params, x[0] + step0, x[1], &dx_0f[0], &dx_0f[1]); + distortion_fun(params, x[0], x[1] - step1, &dx_1b[0], &dx_1b[1]); + distortion_fun(params, x[0], x[1] + step1, &dx_1f[0], &dx_1f[1]); + J[0][0] = 1 + (dx_0f[0] - dx_0b[0]) / (2 * step0); + J[1][0] = (dx_1f[0] - dx_1b[0]) / (2 * step1); + J[0][1] = (dx_0f[1] - dx_0b[1]) / (2 * step0); + J[1][1] = 1 + (dx_1f[1] - dx_1b[1]) / (2 * step1); + const vec2 step_x = inverse(J) * (x + dx - x0); + x -= step_x; + if (length2(step_x) < kMaxStepNorm) { + break; + } + } + + *u = x[0]; + *v = x[1]; +} + +template NGP_HOST_DEVICE inline void iterative_opencv_lens_undistortion(const T* params, T* u, T* v) { + iterative_lens_undistortion(params, u, v, opencv_lens_distortion_delta); +} + +template NGP_HOST_DEVICE inline void iterative_opencv_fisheye_lens_undistortion(const T* params, T* u, T* v) { + iterative_lens_undistortion(params, u, v, opencv_fisheye_lens_distortion_delta); +} + +inline NGP_HOST_DEVICE Ray pixel_to_ray_pinhole( + uint32_t spp, const ivec2& pixel, const ivec2& resolution, const vec2& focal_length, const mat4x3& camera_matrix, const vec2& screen_center +) { + const vec2 uv = vec2(pixel) / vec2(resolution); + + vec3 dir = { + (uv.x - screen_center.x) * (float)resolution.x / focal_length.x, (uv.y - screen_center.y) * (float)resolution.y / focal_length.y, 1.0f + }; + + dir = mat3(camera_matrix) * dir; + return {camera_matrix[3], dir}; +} + +inline NGP_HOST_DEVICE vec3 f_theta_undistortion(const vec2& uv, const float* params, const vec3& error_direction) { + // we take f_theta intrinsics to be: r0, r1, r2, r3, resx, resy; we rescale to whatever res the intrinsics specify. 
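// Concretely, seven parameters are consumed below: params[0..4] are the polynomial coefficients
// r0..r4 and params[5], params[6] are the calibration resolution (resx, resy). With
// n = sqrt(xpix^2 + ypix^2), the angle from the optical axis is
//
//     alpha = r0 + r1*n + r2*n^2 + r3*n^3 + r4*n^4
//
// and the returned direction is (sin(alpha)/n * xpix, sin(alpha)/n * ypix, cos(alpha)).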
+ float xpix = uv.x * params[5]; + float ypix = uv.y * params[6]; + float norm = sqrtf(xpix * xpix + ypix * ypix); + float alpha = params[0] + norm * (params[1] + norm * (params[2] + norm * (params[3] + norm * params[4]))); + float sin_alpha, cos_alpha; + sincosf(alpha, &sin_alpha, &cos_alpha); + if (cos_alpha <= std::numeric_limits::min() || norm == 0.f) { + return error_direction; + } + sin_alpha *= 1.f / norm; + return {sin_alpha * xpix, sin_alpha * ypix, cos_alpha}; +} + +inline NGP_HOST_DEVICE vec3 latlong_to_dir(const vec2& uv) { + float theta = (uv.y - 0.5f) * PI(); + float phi = (uv.x - 0.5f) * PI() * 2.0f; + float sp, cp, st, ct; + sincosf(theta, &st, &ct); + sincosf(phi, &sp, &cp); + return {sp * ct, st, cp * ct}; +} + +inline NGP_HOST_DEVICE vec3 equirectangular_to_dir(const vec2& uv) { + float ct = (uv.y - 0.5f) * 2.0f; + float st = sqrt(max(1.0f - ct * ct, 0.0f)); + float phi = (uv.x - 0.5f) * PI() * 2.0f; + float sp, cp; + sincosf(phi, &sp, &cp); + return {sp * st, ct, cp * st}; +} + +inline NGP_HOST_DEVICE vec2 dir_to_latlong(const vec3& dir) { + float theta = asin(dir.y); + float phi = atan2(dir.x, dir.z); + return {phi / (PI() * 2.0f) + 0.5f, theta / PI() + 0.5f}; +} + +inline NGP_HOST_DEVICE vec2 dir_to_equirectangular(const vec3& dir) { + float ct = dir.y; + float phi = atan2(dir.x, dir.z); + return {phi / (PI() * 2.0f) + 0.5f, ct / 2.0f + 0.5f}; +} + +inline NGP_HOST_DEVICE Ray uv_to_ray( + uint32_t spp, + const vec2& uv, + const ivec2& resolution, + const vec2& focal_length, + const mat4x3& camera_matrix, + const vec2& screen_center, + const vec3& parallax_shift = vec3(0.0f), + float near_distance = 0.0f, + float focus_z = 1.0f, + float aperture_size = 0.0f, + const Foveation& foveation = {}, + Buffer2DView hidden_area_mask = {}, + const Lens& lens = {}, + Buffer2DView distortion = {} +) { + vec2 warped_uv = foveation.warp(uv); + + // Check the hidden area mask _after_ applying foveation, because foveation will be undone + // before blitting to the framebuffer to which the hidden area mask corresponds. 
+ if (hidden_area_mask && !hidden_area_mask.at(warped_uv)) { + return Ray::invalid(); + } + + vec3 head_pos = {parallax_shift.x, parallax_shift.y, 0.f}; + vec3 dir; + if (lens.mode == ELensMode::FTheta) { + dir = f_theta_undistortion(warped_uv - screen_center, lens.params, {0.f, 0.f, 0.f}); + if (dir == vec3(0.0f)) { + return Ray::invalid(); + } + } else if (lens.mode == ELensMode::LatLong) { + dir = latlong_to_dir(warped_uv); + } else if (lens.mode == ELensMode::Equirectangular) { + dir = equirectangular_to_dir(warped_uv); + } else if (lens.mode == ELensMode::Orthographic) { + dir = {0.0f, 0.0f, 1.0f}; + head_pos += vec3{ + (warped_uv.x - screen_center.x) * (float)resolution.x / focal_length.x, + (warped_uv.y - screen_center.y) * (float)resolution.y / focal_length.y, + 0.0f, + }; + } else { + dir = { + (warped_uv.x - screen_center.x) * (float)resolution.x / focal_length.x, + (warped_uv.y - screen_center.y) * (float)resolution.y / focal_length.y, + 1.0f + }; + + if (lens.mode == ELensMode::OpenCV) { + iterative_opencv_lens_undistortion(lens.params, &dir.x, &dir.y); + } else if (lens.mode == ELensMode::OpenCVFisheye) { + iterative_opencv_fisheye_lens_undistortion(lens.params, &dir.x, &dir.y); + } + } + + if (distortion) { + dir.xy() += distortion.at_lerp(warped_uv); + } + + if (lens.mode != ELensMode::Orthographic && lens.mode != ELensMode::LatLong && lens.mode != ELensMode::Equirectangular) { + dir -= head_pos * parallax_shift.z; // we could use focus_z here in the denominator. for now, we pack m_scale in here. + } + + dir = mat3(camera_matrix) * dir; + + vec3 origin = mat3(camera_matrix) * head_pos + camera_matrix[3]; + if (aperture_size != 0.0f) { + vec3 lookat = origin + dir * focus_z; + auto px = ivec2(uv * vec2(resolution)); + vec2 blur = aperture_size * square2disk_shirley(ld_random_val_2d(spp, px.x * 19349663 + px.y * 96925573) * 2.0f - 1.0f); + origin += mat2x3(camera_matrix) * blur; + dir = (lookat - origin) / focus_z; + } + + origin += dir * near_distance; + return {origin, dir}; +} + +inline NGP_HOST_DEVICE Ray pixel_to_ray( + uint32_t spp, + const ivec2& pixel, + const ivec2& resolution, + const vec2& focal_length, + const mat4x3& camera_matrix, + const vec2& screen_center, + const vec3& parallax_shift = vec3(0.0f), + bool snap_to_pixel_centers = false, + float near_distance = 0.0f, + float focus_z = 1.0f, + float aperture_size = 0.0f, + const Foveation& foveation = {}, + Buffer2DView hidden_area_mask = {}, + const Lens& lens = {}, + Buffer2DView distortion = {} +) { + return uv_to_ray( + spp, + (vec2(pixel) + ld_random_pixel_offset(snap_to_pixel_centers ? 
0 : spp)) / vec2(resolution), + resolution, + focal_length, + camera_matrix, + screen_center, + parallax_shift, + near_distance, + focus_z, + aperture_size, + foveation, + hidden_area_mask, + lens, + distortion + ); +} + +inline NGP_HOST_DEVICE vec2 pos_to_uv( + const vec3& pos, + const ivec2& resolution, + const vec2& focal_length, + const mat4x3& camera_matrix, + const vec2& screen_center, + const vec3& parallax_shift, + const Foveation& foveation = {}, + const Lens& lens = {} +) { + vec3 head_pos = {parallax_shift.x, parallax_shift.y, 0.f}; + vec2 uv; + + if (lens.mode == ELensMode::Orthographic) { + vec3 rel_pos = inverse(mat3(camera_matrix)) * (pos - camera_matrix[3]) - head_pos; + uv = rel_pos.xy() * focal_length / vec2(resolution) + screen_center; + } else { + // Express ray in terms of camera frame + vec3 origin = mat3(camera_matrix) * head_pos + camera_matrix[3]; + + vec3 dir = pos - origin; + dir = inverse(mat3(camera_matrix)) * dir; + dir /= lens.is_360() ? length(dir) : dir.z; + + if (lens.mode == ELensMode::Equirectangular) { + uv = dir_to_equirectangular(dir); + } else if (lens.mode == ELensMode::LatLong) { + uv = dir_to_latlong(dir); + } else { + // Perspective with potential distortions applied on top + dir += head_pos * parallax_shift.z; + + float du = 0.0f, dv = 0.0f; + if (lens.mode == ELensMode::OpenCV) { + opencv_lens_distortion_delta(lens.params, dir.x, dir.y, &du, &dv); + } else if (lens.mode == ELensMode::OpenCVFisheye) { + opencv_fisheye_lens_distortion_delta(lens.params, dir.x, dir.y, &du, &dv); + } else { + // No other type of distortion is permitted. + assert(lens.mode == ELensMode::Perspective); + } + + dir.x += du; + dir.y += dv; + + uv = dir.xy() * focal_length / vec2(resolution) + screen_center; + } + } + + return foveation.unwarp(uv); +} + +inline NGP_HOST_DEVICE vec2 pos_to_pixel( + const vec3& pos, + const ivec2& resolution, + const vec2& focal_length, + const mat4x3& camera_matrix, + const vec2& screen_center, + const vec3& parallax_shift, + const Foveation& foveation = {}, + const Lens& lens = {} +) { + return pos_to_uv(pos, resolution, focal_length, camera_matrix, screen_center, parallax_shift, foveation, lens) * vec2(resolution); +} + +inline NGP_HOST_DEVICE vec2 motion_vector( + const uint32_t sample_index, + const ivec2& pixel, + const ivec2& resolution, + const vec2& focal_length, + const mat4x3& camera, + const mat4x3& prev_camera, + const vec2& screen_center, + const vec3& parallax_shift, + const bool snap_to_pixel_centers, + const float depth, + const Foveation& foveation = {}, + const Foveation& prev_foveation = {}, + const Lens& lens = {} +) { + vec2 pxf = vec2(pixel) + ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index); + Ray ray = uv_to_ray( + sample_index, + pxf / vec2(resolution), + resolution, + focal_length, + camera, + screen_center, + parallax_shift, + 0.0f, + 1.0f, + 0.0f, + foveation, + {}, // No hidden area mask + lens + ); + + vec2 prev_pxf = pos_to_pixel(ray(depth), resolution, focal_length, prev_camera, screen_center, parallax_shift, prev_foveation, lens); + + return prev_pxf - pxf; +} + +// Maps view-space depth (physical units) in the range [znear, zfar] hyperbolically to +// the interval [1, 0]. This is the reverse-z-component of "normalized device coordinates", +// which are commonly used in rasterization, where linear interpolation in screen space +// has to be equivalent to linear interpolation in real space (which, in turn, is +// guaranteed by the hyperbolic mapping of depth). 
This format is commonly found in +// z-buffers, and hence expected by downstream image processing functions, such as DLSS +// and VR reprojection. +inline NGP_HOST_DEVICE float to_ndc_depth(float z, float n, float f) { + // View depth outside of the view frustum leads to output outside of [0, 1] + z = clamp(z, n, f); + + float scale = n / (n - f); + float bias = -f * scale; + return clamp((z * scale + bias) / z, 0.0f, 1.0f); +} + +inline NGP_HOST_DEVICE float fov_to_focal_length(int resolution, float degrees) { + return 0.5f * (float)resolution / tanf(0.5f * degrees * PI() / 180.0f); +} + +inline NGP_HOST_DEVICE vec2 fov_to_focal_length(const ivec2& resolution, const vec2& degrees) { + return 0.5f * vec2(resolution) / tan(0.5f * degrees * (PI() / 180.0f)); +} + +inline NGP_HOST_DEVICE float focal_length_to_fov(int resolution, float focal_length) { + return 2.0f * 180.0f / PI() * atanf(float(resolution) / (focal_length * 2.0f)); +} + +inline NGP_HOST_DEVICE vec2 focal_length_to_fov(const ivec2& resolution, const vec2& focal_length) { + return 2.0f * 180.0f / PI() * atan(vec2(resolution) / (focal_length * 2.0f)); +} + +inline NGP_HOST_DEVICE vec2 relative_focal_length_to_fov(const vec2& rel_focal_length) { + return 2.0f * 180.0f / PI() * atan(vec2(1.0f) / (rel_focal_length * 2.0f)); +} + +inline NGP_HOST_DEVICE mat4x3 camera_log_lerp(const mat4x3& a, const mat4x3& b, float t) { + return mat_exp(mat_log(mat4(b) * inverse(mat4(a))) * t) * mat4(a); +} + +inline NGP_HOST_DEVICE mat4x3 camera_slerp(const mat4x3& a, const mat4x3& b, float t) { + mat3 rot = slerp(mat3(a), mat3(b), t); + return {rot[0], rot[1], rot[2], mix(a[3], b[3], t)}; +} + +inline NGP_HOST_DEVICE mat4x3 + get_xform_given_rolling_shutter(const TrainingXForm& training_xform, const vec4& rolling_shutter, const vec2& uv, float motionblur_time) { + float pixel_t = rolling_shutter.x + rolling_shutter.y * uv.x + rolling_shutter.z * uv.y + rolling_shutter.w * motionblur_time; + return camera_slerp(training_xform.start, training_xform.end, pixel_t); +} + +template +__global__ void from_rgba32( + const uint64_t num_pixels, + const uint8_t* __restrict__ pixels, + T* __restrict__ out, + bool white_2_transparent = false, + bool black_2_transparent = false, + uint32_t mask_color = 0 +) { + const uint64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_pixels) { + return; + } + + uint8_t rgba[4]; + *((uint32_t*)&rgba[0]) = *((uint32_t*)&pixels[i * 4]); + + float alpha = rgba[3] * (1.0f / 255.0f); + // NSVF dataset has 'white = transparent' madness + if (white_2_transparent && rgba[0] == 255 && rgba[1] == 255 && rgba[2] == 255) { + alpha = 0.f; + } + if (black_2_transparent && rgba[0] == 0 && rgba[1] == 0 && rgba[2] == 0) { + alpha = 0.f; + } + + tvec rgba_out; + rgba_out[0] = (T)(srgb_to_linear(rgba[0] * (1.0f / 255.0f)) * alpha); + rgba_out[1] = (T)(srgb_to_linear(rgba[1] * (1.0f / 255.0f)) * alpha); + rgba_out[2] = (T)(srgb_to_linear(rgba[2] * (1.0f / 255.0f)) * alpha); + rgba_out[3] = (T)alpha; + + if (mask_color != 0 && mask_color == *((uint32_t*)&rgba[0])) { + rgba_out[0] = rgba_out[1] = rgba_out[2] = rgba_out[3] = (T)-1.0f; + } + + *((tvec*)&out[i * 4]) = rgba_out; +} + +// Foley & van Dam p593 / http://en.wikipedia.org/wiki/HSL_and_HSV +inline NGP_HOST_DEVICE vec3 hsv_to_rgb(const vec3& hsv) { + float h = hsv.x, s = hsv.y, v = hsv.z; + if (s == 0.0f) { + return vec3(v); + } + + h = fmodf(h, 1.0f) * 6.0f; + int i = (int)h; + float f = h - (float)i; + float p = v * (1.0f - s); + float q = v * (1.0f - s * f); + float t = v * 
(1.0f - s * (1.0f - f)); + + switch (i) { + case 0: return {v, t, p}; + case 1: return {q, v, p}; + case 2: return {p, v, t}; + case 3: return {p, q, v}; + case 4: return {t, p, v}; + case 5: + default: return {v, p, q}; + } +} + +inline NGP_HOST_DEVICE vec3 to_rgb(const vec2& dir) { return hsv_to_rgb({atan2f(dir.y, dir.x) / (2.0f * PI()) + 0.5f, 1.0f, length(dir)}); } + +enum class EImageDataType { + None, + Byte, + Half, + Float, +}; + +enum class EDepthDataType { + UShort, + Float, +}; + +inline NGP_HOST_DEVICE ivec2 image_pos(const vec2& pos, const ivec2& resolution) { + return clamp(ivec2(pos * vec2(resolution)), 0, resolution - 1); +} + +inline NGP_HOST_DEVICE uint64_t pixel_idx(const ivec2& px, const ivec2& resolution, uint32_t img) { + return px.x + px.y * resolution.x + img * (uint64_t)resolution.x * resolution.y; +} + +inline NGP_HOST_DEVICE uint64_t pixel_idx(const vec2& uv, const ivec2& resolution, uint32_t img) { + return pixel_idx(image_pos(uv, resolution), resolution, img); +} + +// inline NGP_HOST_DEVICE vec3 composit_and_lerp(vec2 pos, const ivec2& resolution, uint32_t img, const __half* training_images, const vec3& +// background_color, const vec3& exposure_scale = vec3(1.0f)) { +// pos = (pos.cwiseProduct(vec2(resolution)) - 0.5f).cwiseMax(0.0f).cwiseMin(vec2(resolution) - (1.0f + 1e-4f)); + +// const ivec2 pos_int = pos.cast(); +// const vec2 weight = pos - pos_int.cast(); + +// const ivec2 idx = pos_int.cwiseMin(resolution - 2).cwiseMax(0); + +// auto read_val = [&](const ivec2& p) { +// __half val[4]; +// *(uint64_t*)&val[0] = ((uint64_t*)training_images)[pixel_idx(p, resolution, img)]; +// return vec3{val[0], val[1], val[2]} * exposure_scale + background_color * (1.0f - (float)val[3]); +// }; + +// return ( +// (1 - weight.x) * (1 - weight.y) * read_val({idx.x, idx.y}) + +// (weight.x) * (1 - weight.y) * read_val({idx.x+1, idx.y}) + +// (1 - weight.x) * (weight.y) * read_val({idx.x, idx.y+1}) + +// (weight.x) * (weight.y) * read_val({idx.x+1, idx.y+1}) +// ); +// } + +// inline NGP_HOST_DEVICE vec3 composit(vec2 pos, const ivec2& resolution, uint32_t img, const __half* training_images, const vec3& +// background_color, const vec3& exposure_scale = vec3(1.0f)) { +// auto read_val = [&](const ivec2& p) { +// __half val[4]; +// *(uint64_t*)&val[0] = ((uint64_t*)training_images)[pixel_idx(p, resolution, img)]; +// return vec3{val[0], val[1], val[2]} * exposure_scale + background_color * (1.0f - (float)val[3]); +// }; + +// return read_val(image_pos(pos, resolution)); +// } + +inline NGP_HOST_DEVICE uint32_t rgba_to_rgba32(const vec4& rgba) { + return ((uint32_t)(clamp(rgba.r, 0.0f, 1.0f) * 255.0f + 0.5f) << 0) | ((uint32_t)(clamp(rgba.g, 0.0f, 1.0f) * 255.0f + 0.5f) << 8) | + ((uint32_t)(clamp(rgba.b, 0.0f, 1.0f) * 255.0f + 0.5f) << 16) | ((uint32_t)(clamp(rgba.a, 0.0f, 1.0f) * 255.0f + 0.5f) << 24); +} + +inline NGP_HOST_DEVICE float rgba32_to_a(uint32_t rgba32) { return ((rgba32 & 0xFF000000) >> 24) * (1.0f / 255.0f); } + +inline NGP_HOST_DEVICE vec3 rgba32_to_rgb(uint32_t rgba32) { + return vec3{ + ((rgba32 & 0x000000FF) >> 0) * (1.0f / 255.0f), + ((rgba32 & 0x0000FF00) >> 8) * (1.0f / 255.0f), + ((rgba32 & 0x00FF0000) >> 16) * (1.0f / 255.0f), + }; +} + +inline NGP_HOST_DEVICE vec4 rgba32_to_rgba(uint32_t rgba32) { + return vec4{ + ((rgba32 & 0x000000FF) >> 0) * (1.0f / 255.0f), + ((rgba32 & 0x0000FF00) >> 8) * (1.0f / 255.0f), + ((rgba32 & 0x00FF0000) >> 16) * (1.0f / 255.0f), + ((rgba32 & 0xFF000000) >> 24) * (1.0f / 255.0f), + }; +} + +inline NGP_HOST_DEVICE vec4 
read_rgba(ivec2 px, const ivec2& resolution, const void* pixels, EImageDataType image_data_type, uint32_t img = 0) { + switch (image_data_type) { + default: + // This should never happen. Bright red to indicate this. + return vec4{5.0f, 0.0f, 0.0f, 1.0f}; + case EImageDataType::Byte: { + uint32_t val = ((uint32_t*)pixels)[pixel_idx(px, resolution, img)]; + if (val == 0x00FF00FF) { + return vec4(-1.0f); + } + + vec4 result = rgba32_to_rgba(val); + result.rgb() = srgb_to_linear(result.rgb()) * result.a; + return result; + } + case EImageDataType::Half: { + __half val[4]; + *(uint64_t*)&val[0] = ((uint64_t*)pixels)[pixel_idx(px, resolution, img)]; + return vec4{(float)val[0], (float)val[1], (float)val[2], (float)val[3]}; + } + case EImageDataType::Float: return ((vec4*)pixels)[pixel_idx(px, resolution, img)]; + } +} + +inline NGP_HOST_DEVICE vec4 read_rgba(vec2 pos, const ivec2& resolution, const void* pixels, EImageDataType image_data_type, uint32_t img = 0) { + return read_rgba(image_pos(pos, resolution), resolution, pixels, image_data_type, img); +} + +inline NGP_HOST_DEVICE float read_depth(vec2 pos, const ivec2& resolution, const float* depth, uint32_t img = 0) { + auto read_val = [&](const ivec2& p) { return depth[pixel_idx(p, resolution, img)]; }; + + return read_val(image_pos(pos, resolution)); +} + +inline __device__ int float_to_ordered_int(float f) { + int i = __float_as_int(f); + return (i >= 0) ? i : i ^ 0x7FFFFFFF; +} + +inline __device__ float ordered_int_to_float(int i) { return __int_as_float(i >= 0 ? i : i ^ 0x7FFFFFFF); } + +inline __device__ vec3 colormap_turbo(float x) { + const vec4 kRedVec4 = {0.13572138f, 4.61539260f, -42.66032258f, 132.13108234f}; + const vec4 kGreenVec4 = {0.09140261f, 2.19418839f, 4.84296658f, -14.18503333f}; + const vec4 kBlueVec4 = {0.10667330f, 12.64194608f, -60.58204836f, 110.36276771f}; + const vec2 kRedVec2 = {-152.94239396f, 59.28637943f}; + const vec2 kGreenVec2 = {4.27729857f, 2.82956604f}; + const vec2 kBlueVec2 = {-89.90310912f, 27.34824973f}; + + x = __saturatef(x); + vec4 v4 = {1.0f, x, x * x, x * x * x}; + vec2 v2 = {v4.w * x, v4.w * v4.z}; + return { + dot(v4, kRedVec4) + dot(v2, kRedVec2), + dot(v4, kGreenVec4) + dot(v2, kGreenVec2), + dot(v4, kBlueVec4) + dot(v2, kBlueVec2), + }; +} + +} // namespace ngp diff --git a/gui/include/neural-graphics-primitives/common_host.h b/gui/include/neural-graphics-primitives/common_host.h new file mode 100644 index 0000000000000000000000000000000000000000..0e9d871ee28b159a2b13233ee6d503184fa79093 --- /dev/null +++ b/gui/include/neural-graphics-primitives/common_host.h @@ -0,0 +1,203 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common_host.h + * @author Thomas Müller, NVIDIA + * @brief Shared functionality among multiple neural-graphics-primitives components. 
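 *
 * Example (illustrative sketch of the Ema helper declared further down in this header;
 * `loss` stands for whatever scalar you are tracking):
 *
 *     Ema<float> loss_ema{EEmaType::Time, 100.0f}; // half-life in the units counted by current_progress()
 *     loss_ema.update(loss);                       // call once per new measurement
 *     float smoothed = loss_ema.ema_val();         // exponentially smoothed value; val() returns the raw one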
+ */ + +#pragma once + +#include + +#include + +#include +#include + +#include + +#include +#include + +namespace ngp { + +namespace fs = filesystem; + +bool is_wsl(); + +fs::path discover_executable_dir(); +fs::path discover_root_dir(); + +#ifdef _WIN32 +std::string utf16_to_utf8(const std::wstring& utf16); +std::wstring utf8_to_utf16(const std::string& utf16); +std::wstring native_string(const fs::path& path); +#else +std::string native_string(const fs::path& path); +#endif + +bool ends_with(const std::string& str, const std::string& ending); +bool ends_with_case_insensitive(const std::string& str, const std::string& ending); + +ETestbedMode mode_from_scene(const std::string& scene); +ETestbedMode mode_from_string(const std::string& str); +std::string to_string(ETestbedMode); + +inline std::string replace_all(std::string str, const std::string& a, const std::string& b) { + std::string::size_type n = 0; + while ((n = str.find(a, n)) != std::string::npos) { + str.replace(n, a.length(), b); + n += b.length(); + } + return str; +} + +template T snap_to_nearest(T val, const std::vector& candidates) { + T best_dist = std::numeric_limits::max(); + T result = candidates.empty() ? val : candidates[0]; + for (T c : candidates) { + T dist = abs(val - c); + if (dist < best_dist) { + best_dist = dist; + result = c; + } + } + + return result; +} + +enum class EEmaType { + Time, + Step, +}; + +template class Ema { +public: + Ema(EEmaType type, float half_life) : + m_type{type}, m_decay{std::pow(0.5f, 1.0f / max(half_life, 0.000001f))}, m_creation_time{std::chrono::steady_clock::now()} {} + + int64_t current_progress() { + if (m_type == EEmaType::Time) { + auto now = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(now - m_creation_time).count(); + } else { + return m_last_progress + 1; + } + } + + void update(const T& val) { + int64_t cur = current_progress(); + int64_t elapsed = cur - m_last_progress; + m_last_progress = cur; + + float decay = std::pow(m_decay, elapsed); + m_val = val; + m_ema_val = decay * m_ema_val + (1.0f - decay) * val; + } + + void set(const T& val) { + m_last_progress = current_progress(); + m_val = m_ema_val = val; + } + + T val() const { return m_val; } + + T ema_val() const { return m_ema_val; } + +private: + T m_val = 0.0f; + T m_ema_val = 0.0f; + EEmaType m_type; + float m_decay; + + int64_t m_last_progress = 0; + std::chrono::time_point m_creation_time; +}; + + +uint8_t* load_stbi(const fs::path& path, int* width, int* height, int* comp, int req_comp); +float* load_stbi_float(const fs::path& path, int* width, int* height, int* comp, int req_comp); +uint16_t* load_stbi_16(const fs::path& path, int* width, int* height, int* comp, int req_comp); +bool is_hdr_stbi(const fs::path& path); +int write_stbi(const fs::path& path, int width, int height, int comp, const uint8_t* pixels, int quality = 100); + +FILE* native_fopen(const fs::path& path, const char* mode); + +GPUMemory load_exr_gpu(const fs::path& path, int* width, int* height); +GPUMemory load_stbi_gpu(const fs::path& path, int* width, int* height); + +template class Buffer2D { +public: + Buffer2D() = default; + Buffer2D(const ivec2& resolution) { resize(resolution); } + + T* data() const { return m_data.data(); } + + size_t bytes() const { return m_data.bytes(); } + + void resize(const ivec2& resolution) { + m_data.resize(product(resolution)); + m_resolution = resolution; + } + + const ivec2& resolution() const { return m_resolution; } + + Buffer2DView view() const { + // Row major for now. 
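// i.e. element (x, y) lives at data()[x + y * m_resolution.x], matching Buffer2DView::at(const ivec2&).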
+ return {data(), m_resolution}; + } + + Buffer2DView const_view() const { + // Row major for now. + return {data(), m_resolution}; + } + +private: + GPUMemory m_data; + ivec2 m_resolution; +}; + +template struct GPUImage { + GPUImage() : image{}, padding{0} {} + + GPUImage(ivec2 resolution, uint32_t padding, cudaStream_t stream) : + image{resolution.y + padding * 2, resolution.x + padding * 2, stream}, padding{padding} {} + + GPUImage(ivec2 resolution, cudaStream_t stream) : GPUImage(resolution, 0, stream) {} + + MatrixView view() const { return image.slice(padding, image.m() - 2 * padding, padding, image.n() - 2 * padding).view(); } + + T* data() const { return image.data(); } + + size_t n_elements_padded() const { return image.n_elements(); } + size_t n_elements() const { return product(resolution()); } + + ivec2 resolution_padded() const { return {(int)image.n(), (int)image.m()}; } + ivec2 resolution() const { return {(int)(image.n() - 2 * padding), (int)(image.m() - 2 * padding)}; } + + explicit operator bool() const { return image.data() != nullptr; } + + GPUMatrix image; + uint32_t padding; +}; + +struct BoundingBox; +struct Triangle; +std::ostream& operator<<(std::ostream& os, const BoundingBox& triangle); +std::ostream& operator<<(std::ostream& os, const Triangle& triangle); +} // namespace ngp diff --git a/gui/include/neural-graphics-primitives/discrete_distribution.h b/gui/include/neural-graphics-primitives/discrete_distribution.h new file mode 100644 index 0000000000000000000000000000000000000000..8798f5b2c3f6b6c59e12faa0b641c412e134a9b1 --- /dev/null +++ b/gui/include/neural-graphics-primitives/discrete_distribution.h @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file discrete_distribution.h + * @author Thomas Müller, NVIDIA + */ + +#pragma once + +#include + +namespace ngp { + +struct DiscreteDistribution { + void build(std::vector weights) { + float total_weight = 0; + for (float w : weights) { + total_weight += w; + } + float inv_total_weight = 1 / total_weight; + + float cdf_accum = 0; + cdf.clear(); + for (float w : weights) { + float norm = w * inv_total_weight; + cdf_accum += norm; + pmf.emplace_back(norm); + cdf.emplace_back(cdf_accum); + } + cdf.back() = 1.0f; // Prevent precision problems from causing overruns in the end + } + + uint32_t sample(float val) { + return std::min(binary_search(val, cdf.data(), (uint32_t)cdf.size()), (uint32_t)cdf.size()-1); + } + + std::vector pmf; + std::vector cdf; +}; + +} diff --git a/gui/include/neural-graphics-primitives/dlss.h b/gui/include/neural-graphics-primitives/dlss.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9d0e2aa006e19a878ce7614f60201b1c923667 --- /dev/null +++ b/gui/include/neural-graphics-primitives/dlss.h @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file dlss.h + * @author Thomas Müller, NVIDIA + */ + +#pragma once + +#include + +#include + +namespace ngp { + +class IDlss { +public: + virtual ~IDlss() {} + + virtual void update_feature( + const ivec2& in_resolution, + bool is_hdr, + bool sharpen + ) = 0; + virtual void run( + const ivec2& in_resolution, + bool is_hdr, + float sharpening, + const vec2& jitter_offset, + bool shall_reset + ) = 0; + + virtual cudaSurfaceObject_t frame() = 0; + virtual cudaSurfaceObject_t depth() = 0; + virtual cudaSurfaceObject_t mvec() = 0; + virtual cudaSurfaceObject_t exposure() = 0; + virtual cudaSurfaceObject_t output() = 0; + + virtual ivec2 clamp_resolution(const ivec2& resolution) const = 0; + virtual ivec2 out_resolution() const = 0; + virtual ivec2 max_out_resolution() const = 0; + + virtual bool is_hdr() const = 0; + virtual bool sharpen() const = 0; + virtual EDlssQuality quality() const = 0; +}; + +class IDlssProvider { +public: + virtual ~IDlssProvider() {} + + virtual size_t allocated_bytes() const = 0; + virtual std::unique_ptr init_dlss(const ivec2& out_resolution) = 0; +}; + +#ifdef NGP_VULKAN +std::shared_ptr init_vulkan_and_ngx(); +#endif + +} diff --git a/gui/include/neural-graphics-primitives/json_binding.h b/gui/include/neural-graphics-primitives/json_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..d42a2001a9717f91a4016cdf07ff416fe85f3c8c --- /dev/null +++ b/gui/include/neural-graphics-primitives/json_binding.h @@ -0,0 +1,119 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +/** @file json_binding.h + * @author Thomas Müller, NVIDIA + * @brief Conversion between some ngp types and nlohmann::json. 
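 *
 * Usage sketch (illustrative; relies on the to_json/from_json overloads defined below being
 * found by nlohmann::json via argument-dependent lookup):
 *
 *     nlohmann::json j = lens;                 // writes k1/k2/p1/p2, ftheta_*, etc. depending on lens.mode
 *     ngp::Lens restored = j.get<ngp::Lens>(); // infers the mode from whichever keys are present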
+ */ + +#pragma once + +#include +#include + +#include + +#include + +namespace ngp { + +inline void to_json(nlohmann::json& j, const BoundingBox& box) { + j["min"] = box.min; + j["max"] = box.max; +} + +inline void from_json(const nlohmann::json& j, BoundingBox& box) { + box.min = j.at("min"); + box.max = j.at("max"); +} + +inline void to_json(nlohmann::json& j, const Lens& lens) { + if (lens.mode == ELensMode::OpenCV) { + j["is_fisheye"] = false; + j["k1"] = lens.params[0]; + j["k2"] = lens.params[1]; + j["p1"] = lens.params[2]; + j["p2"] = lens.params[3]; + } else if (lens.mode == ELensMode::OpenCVFisheye) { + j["is_fisheye"] = true; + j["k1"] = lens.params[0]; + j["k2"] = lens.params[1]; + j["k3"] = lens.params[2]; + j["k4"] = lens.params[3]; + } else if (lens.mode == ELensMode::FTheta) { + j["ftheta_p0"] = lens.params[0]; + j["ftheta_p1"] = lens.params[1]; + j["ftheta_p2"] = lens.params[2]; + j["ftheta_p3"] = lens.params[3]; + j["ftheta_p4"] = lens.params[4]; + j["w"] = lens.params[5]; + j["h"] = lens.params[6]; + } else if (lens.mode == ELensMode::LatLong) { + j["latlong"] = true; + } else if (lens.mode == ELensMode::Equirectangular) { + j["equirectangular"] = true; + } else if (lens.mode == ELensMode::Orthographic) { + j["orthographic"] = true; + } +} + +inline void from_json(const nlohmann::json& j, Lens& lens) { + if (j.contains("k1")) { + if (j.value("is_fisheye", false)) { + lens.mode = ELensMode::OpenCVFisheye; + lens.params[0] = j.at("k1"); + lens.params[1] = j.at("k2"); + lens.params[2] = j.at("k3"); + lens.params[3] = j.at("k4"); + } else { + lens.mode = ELensMode::OpenCV; + lens.params[0] = j.at("k1"); + lens.params[1] = j.at("k2"); + lens.params[2] = j.at("p1"); + lens.params[3] = j.at("p2"); + } + } else if (j.contains("ftheta_p0")) { + lens.mode = ELensMode::FTheta; + lens.params[0] = j.at("ftheta_p0"); + lens.params[1] = j.at("ftheta_p1"); + lens.params[2] = j.at("ftheta_p2"); + lens.params[3] = j.at("ftheta_p3"); + lens.params[4] = j.at("ftheta_p4"); + lens.params[5] = j.at("w"); + lens.params[6] = j.at("h"); + } else if (j.contains("latlong")) { + lens.mode = ELensMode::LatLong; + } else if (j.contains("equirectangular")) { + lens.mode = ELensMode::Equirectangular; + } else if (j.contains("orthographic")) { + lens.mode = ELensMode::Orthographic; + } else { + lens.mode = ELensMode::Perspective; + } +} + +inline void from_json(const nlohmann::json& j, TrainingXForm& x) { + x.start = j.at("start"); + x.end = j.at("end"); +} + +inline void to_json(nlohmann::json& j, const TrainingXForm& x) { + j["start"] = x.start; + j["end"] = x.end; +} + +} diff --git a/gui/include/neural-graphics-primitives/openxr_hmd.h b/gui/include/neural-graphics-primitives/openxr_hmd.h new file mode 100644 index 0000000000000000000000000000000000000000..befa7b9902f8b76b5dbd934b06f20eeebc0d6be8 --- /dev/null +++ b/gui/include/neural-graphics-primitives/openxr_hmd.h @@ -0,0 +1,298 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file openxr_hmd.h + * @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA + * @brief Wrapper around the OpenXR API, providing access to + * per-eye framebuffers, lens parameters, visible area, + * view, hand, and eye poses, as well as controller inputs. + */ + +#pragma once + +#ifdef _WIN32 +# include +#else +# include +#endif + +#define XR_USE_GRAPHICS_API_OPENGL + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" //TODO: XR struct are uninitiaized apart from their type +#endif + +namespace ngp { + +enum class EEnvironmentBlendMode { + Opaque = XR_ENVIRONMENT_BLEND_MODE_OPAQUE, + Additive = XR_ENVIRONMENT_BLEND_MODE_ADDITIVE, + AlphaBlend = XR_ENVIRONMENT_BLEND_MODE_ALPHA_BLEND, +}; + +inline std::string to_string(EEnvironmentBlendMode mode) { + switch (mode) { + case EEnvironmentBlendMode::Opaque: return "Opaque"; + case EEnvironmentBlendMode::Additive: return "Additive"; + case EEnvironmentBlendMode::AlphaBlend: return "Blend"; + default: throw std::runtime_error{"Invalid blend mode."}; + } +} + +class OpenXRHMD { +public: + enum class EControlFlow { + Continue, + Restart, + Quit, + }; + + struct FrameInfo { + struct View { + GLuint framebuffer; + XrCompositionLayerProjectionView view{XR_TYPE_COMPOSITION_LAYER_PROJECTION_VIEW}; + XrCompositionLayerDepthInfoKHR depth_info{XR_TYPE_COMPOSITION_LAYER_DEPTH_INFO_KHR}; + std::shared_ptr> hidden_area_mask = nullptr; + mat4x3 pose; + }; + struct Hand { + mat4x3 pose; + bool pose_active = false; + vec2 thumbstick = vec2(0.0f); + float grab_strength = 0.0f; + bool grabbing = false; + bool pressing = false; + vec3 grab_pos; + vec3 prev_grab_pos; + vec3 drag() const { + return grab_pos - prev_grab_pos; + } + }; + std::vector views; + Hand hands[2]; + }; + using FrameInfoPtr = std::shared_ptr; + + // RAII OpenXRHMD with OpenGL +#if defined(XR_USE_PLATFORM_WIN32) + OpenXRHMD(HDC hdc, HGLRC hglrc); +#elif defined(XR_USE_PLATFORM_XLIB) + OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext); +#elif defined(XR_USE_PLATFORM_WAYLAND) + OpenXRHMD(wl_display* display); +#endif + + virtual ~OpenXRHMD(); + + // disallow copy / move + OpenXRHMD(const OpenXRHMD&) = delete; + OpenXRHMD& operator=(const OpenXRHMD&) = delete; + OpenXRHMD(OpenXRHMD&&) = delete; + OpenXRHMD& operator=(OpenXRHMD&&) = delete; + + void clear(); + + // poll events, handle state changes, return control flow information + EControlFlow poll_events(); + + // begin OpenXR frame, return views to render + FrameInfoPtr begin_frame(); + // must be called for each begin_frame + void end_frame(FrameInfoPtr frame_info, float znear, float zfar, bool submit_depth); + + void set_environment_blend_mode(EEnvironmentBlendMode mode) { + m_environment_blend_mode = mode; + } + + EEnvironmentBlendMode environment_blend_mode() const { + return m_environment_blend_mode; + } + + const std::vector& supported_environment_blend_modes() const { + return m_supported_environment_blend_modes; + } + + const char* supported_environment_blend_modes_imgui_string() const { + return m_supported_environment_blend_modes_imgui_string.data(); + } + + // if true call begin_frame and end_frame - does not imply visibility + bool must_run_frame_loop() const { + return + m_session_state 
== XR_SESSION_STATE_READY || + m_session_state == XR_SESSION_STATE_SYNCHRONIZED || + m_session_state == XR_SESSION_STATE_VISIBLE || + m_session_state == XR_SESSION_STATE_FOCUSED; + } + + // if true, VR is being rendered to the HMD + bool is_visible() const { + // XR_SESSION_STATE_VISIBLE -> app content is shown in HMD + // XR_SESSION_STATE_FOCUSED -> VISIBLE + input is send to app + return m_session_state == XR_SESSION_STATE_VISIBLE || m_session_state == XR_SESSION_STATE_FOCUSED; + } + +private: + // steps of the init process, called from the constructor + void init_create_xr_instance(); + void init_get_xr_system(); + void init_configure_xr_views(); + void init_check_for_xr_blend_mode(); + void init_xr_actions(); + +#if defined(XR_USE_PLATFORM_WIN32) + void init_open_gl(HDC hdc, HGLRC hglrc); +#elif defined(XR_USE_PLATFORM_XLIB) + void init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext); +#elif defined(XR_USE_PLATFORM_WAYLAND) + void init_open_gl(wl_display* display); +#endif + + void init_xr_session(); + void init_xr_spaces(); + void init_xr_swapchain_open_gl(); + void init_open_gl_shaders(); + + // session state change + void session_state_change(XrSessionState state, EControlFlow& flow); + + std::shared_ptr> rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view); + // system/instance + XrInstance m_instance{XR_NULL_HANDLE}; + XrSystemId m_system_id = {}; + XrInstanceProperties m_instance_properties = {XR_TYPE_INSTANCE_PROPERTIES}; + XrSystemProperties m_system_properties = {XR_TYPE_SYSTEM_PROPERTIES}; + std::vector m_api_layer_properties; + std::vector m_instance_extension_properties; + + // view and blending + XrViewConfigurationType m_view_configuration_type = {}; + XrViewConfigurationProperties m_view_configuration_properties = {XR_TYPE_VIEW_CONFIGURATION_PROPERTIES}; + std::vector m_view_configuration_views; + std::vector m_supported_environment_blend_modes; + std::vector m_supported_environment_blend_modes_imgui_string; + EEnvironmentBlendMode m_environment_blend_mode = EEnvironmentBlendMode::Opaque; + + // actions + std::array m_hand_paths; + std::array m_hand_spaces; + XrAction m_pose_action{XR_NULL_HANDLE}; + XrAction m_press_action{XR_NULL_HANDLE}; + XrAction m_grab_action{XR_NULL_HANDLE}; + + // Two separate actions for Xbox controller support + std::array m_thumbstick_actions; + + XrActionSet m_action_set{XR_NULL_HANDLE}; + +#if defined(XR_USE_PLATFORM_WIN32) + XrGraphicsBindingOpenGLWin32KHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WIN32_KHR}; +#elif defined(XR_USE_PLATFORM_XLIB) + XrGraphicsBindingOpenGLXlibKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_XLIB_KHR}; +#elif defined(XR_USE_PLATFORM_WAYLAND) + XrGraphicsBindingOpenGLWaylandKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WAYLAND_KHR}; +#endif + + XrSession m_session{XR_NULL_HANDLE}; + XrSessionState m_session_state{XR_SESSION_STATE_UNKNOWN}; + + // reference space + std::vector m_reference_spaces; + XrSpace m_space{XR_NULL_HANDLE}; + XrExtent2Df m_bounds; + + // swap chains + struct Swapchain { + Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& xr_instance); + Swapchain(const Swapchain&) = delete; + Swapchain& operator=(const Swapchain&) = delete; + Swapchain(Swapchain&& other) { + *this = std::move(other); + } + Swapchain& operator=(Swapchain&& other) { + std::swap(handle, other.handle); + 
std::swap(depth_handle, other.depth_handle); + std::swap(width, other.width); + std::swap(height, other.height); + images_gl = std::move(other.images_gl); + depth_images_gl = std::move(other.depth_images_gl); + framebuffers_gl = std::move(other.framebuffers_gl); + return *this; + } + virtual ~Swapchain(); + + void clear(); + + XrSwapchain handle{XR_NULL_HANDLE}; + XrSwapchain depth_handle{XR_NULL_HANDLE}; + + int32_t width = 0; + int32_t height = 0; + std::vector images_gl; + std::vector depth_images_gl; + std::vector framebuffers_gl; + }; + + int64_t m_swapchain_rgba_format = 0; + std::vector m_swapchains; + + bool m_supports_composition_layer_depth = false; + int64_t m_swapchain_depth_format = 0; + + bool m_supports_hidden_area_mask = false; + std::vector>> m_hidden_area_masks; + + bool m_supports_eye_tracking = false; + + // frame data + XrFrameState m_frame_state{XR_TYPE_FRAME_STATE}; + FrameInfoPtr m_previous_frame_info; + + GLuint m_hidden_area_mask_program = 0; + + // print more debug info during OpenXRs init: + const bool m_print_api_layers = false; + const bool m_print_extensions = false; + const bool m_print_system_properties = false; + const bool m_print_instance_properties = false; + const bool m_print_view_configuration_types = false; + const bool m_print_view_configuration_properties = false; + const bool m_print_view_configuration_view = false; + const bool m_print_environment_blend_modes = false; + const bool m_print_available_swapchain_formats = false; + const bool m_print_reference_spaces = false; +}; + +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/gui/include/neural-graphics-primitives/pybind11_vec.hpp b/gui/include/neural-graphics-primitives/pybind11_vec.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6ba55329b8efc36d0c81b390748624a4bb76962a --- /dev/null +++ b/gui/include/neural-graphics-primitives/pybind11_vec.hpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * NVIDIA CORPORATION and its licensors retain all intellectual property + * and proprietary rights in and to this software, related documentation + * and any modifications thereto. Any use, reproduction, disclosure or + * distribution of this software and related documentation without an express + * license agreement from NVIDIA CORPORATION is strictly prohibited. + */ + +/** @file pybind11_vec.cuh + * @author Thomas Müller, NVIDIA + * @brief pybind11 bindings for NGP's vector and matrix types. Adapted from + * Patrik Huber's glm binding code per the BSD license of pybind11. + */ + +#pragma once + +#include + +#include + +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +NAMESPACE_BEGIN(pybind11) +NAMESPACE_BEGIN(detail) + +/** + * @file utils/pybind11_glm.hpp + * @brief Transparent conversion to and from Python for glm vector and matrix types. + * + * All converters for matrices assume col-major storage of glm, the default. + * Things will likely break if non-default storage order is used. 
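+ *
+ * Added note (not part of the original comment): the casters below transpose
+ * on load and flip the strides on cast, so a row-major numpy array passed in
+ * from Python round-trips through ngp's col-major storage unchanged. As a
+ * hypothetical illustration, a bound function taking an ngp 4x4 matrix could
+ * be called with np.eye(4, dtype=np.float32) and would observe the identity
+ * matrix on the C++ side.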
+ */ + +template +struct type_caster> { + using vector_type = ngp::tvec; + using Scalar = T; + static constexpr std::size_t num_elements = N; + + bool load(handle src, bool) + { + array_t buf(src, true); + if (!buf.check()) + return false; + + if (buf.ndim() == 1) // a 1-dimensional vector + { + if (buf.shape(0) != num_elements) { + return false; // not a 2-elements vector + } + if (buf.strides(0) != sizeof(Scalar)) + { + std::cout << "An array with non-standard strides is given. Please pass a contiguous array." << std::endl; + return false; + } + value = vector_type(buf.mutable_data()); // make_vec* copies the data (unnecessarily) + } + else { // buf.ndim() != 1 + return false; + } + return true; + } + + static handle cast(const vector_type& src, return_value_policy /* policy */, handle /* parent */) + { + return array( + num_elements, // shape + src.data() // data + ).release(); + } + + // Specifies the doc-string for the type in Python: + PYBIND11_TYPE_CASTER(vector_type, _("vec")); +}; + +template +struct type_caster> { + using matrix_type = ngp::tmat; + using Scalar = T; + static constexpr std::size_t num_rows = M; + static constexpr std::size_t num_cols = N; + + bool load(handle src, bool) + { + array_t buf(src, true); + if (!buf.check()) + return false; + + if (buf.ndim() == 2) // a 2-dimensional matrix + { + if (buf.shape(0) != num_rows || buf.shape(1) != num_cols) { + return false; // not a 4x4 matrix + } + if (buf.strides(0) / sizeof(Scalar) != num_cols || buf.strides(1) != sizeof(Scalar)) + { + std::cout << "An array with non-standard strides is given. Please pass a contiguous array." << std::endl; + return false; + } + // What we get from Python is laid out in row-major memory order, while GLM's + // storage is col-major, thus, we transpose. + value = ngp::transpose(matrix_type(buf.mutable_data())); // make_mat*() copies the data (unnecessarily) + } + else { // buf.ndim() != 2 + return false; + } + return true; + } + + static handle cast(const matrix_type& src, return_value_policy /* policy */, handle /* parent */) + { + return array( + { num_rows, num_cols }, // shape + { sizeof(Scalar), sizeof(Scalar) * num_rows }, // strides - flip the row/col layout! + src.data() // data + ).release(); + } + + // Specifies the doc-string for the type in Python: + PYBIND11_TYPE_CASTER(matrix_type, _("mat")); +}; + +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/gui/include/neural-graphics-primitives/random_val.cuh b/gui/include/neural-graphics-primitives/random_val.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9bc6ec8a2a145a46f61fad0dd6cbadb7364ba939 --- /dev/null +++ b/gui/include/neural-graphics-primitives/random_val.cuh @@ -0,0 +1,332 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file random_val.cuh + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#pragma once + +#include + +#include + +namespace ngp { + +using default_rng_t = pcg32; + +inline constexpr NGP_HOST_DEVICE float PI() { return 3.14159265358979323846f; } + +template +inline __host__ __device__ float random_val(RNG& rng) { + return rng.next_float(); +} + +template +inline __host__ __device__ uint32_t random_uint(RNG& rng) { + return rng.next_uint(); +} + +template +inline __host__ __device__ vec2 random_val_2d(RNG& rng) { + return {rng.next_float(), rng.next_float()}; +} + +inline __host__ __device__ vec3 cylindrical_to_dir(const vec2& p) { + const float cos_theta = -2.0f * p.x + 1.0f; + const float phi = 2.0f * PI() * (p.y - 0.5f); + + const float sin_theta = sqrtf(fmaxf(1.0f - cos_theta * cos_theta, 0.0f)); + float sin_phi, cos_phi; + sincosf(phi, &sin_phi, &cos_phi); + + return {sin_theta * cos_phi, sin_theta * sin_phi, cos_theta}; +} + +inline __host__ __device__ vec2 dir_to_cylindrical(const vec3& d) { + const float cos_theta = fminf(fmaxf(-d.z, -1.0f), 1.0f); + float phi = atan2(d.y, d.x); + return {(cos_theta + 1.0f) / 2.0f, (phi / (2.0f * PI())) + 0.5f}; +} + +inline __host__ __device__ vec2 dir_to_spherical(const vec3& d) { + const float cos_theta = fminf(fmaxf(d.z, -1.0f), 1.0f); + const float theta = acosf(cos_theta); + float phi = atan2(d.y, d.x); + return {theta, phi}; +} + +inline __host__ __device__ vec2 dir_to_spherical_unorm(const vec3& d) { + vec2 spherical = dir_to_spherical(d); + return {spherical.x / PI(), (spherical.y / (2.0f * PI()) + 0.5f)}; +} + +template +inline __host__ __device__ vec3 random_dir(RNG& rng) { + return cylindrical_to_dir(random_val_2d(rng)); +} + +inline __host__ __device__ float fractf(float x) { + return x - floorf(x); +} + +template +__device__ __host__ vec3 fibonacci_dir(uint32_t i, const vec2& offset) { + // Fibonacci lattice with offset + float epsilon; + if (N_DIRS >= 11000) { + epsilon = 27; + } else if (N_DIRS >= 890) { + epsilon = 10; + } else if (N_DIRS >= 177) { + epsilon = 3.33; + } else if (N_DIRS >= 24) { + epsilon = 1.33; + } else { + epsilon = 0.33; + } + + static constexpr float GOLDEN_RATIO = 1.6180339887498948482045868343656f; + return cylindrical_to_dir(vec2{fractf((i+epsilon) / (N_DIRS-1+2*epsilon) + offset.x), fractf(i / GOLDEN_RATIO + offset.y)}); +} + +template +inline __host__ __device__ vec2 random_uniform_disc(RNG& rng) { + vec2 sample = random_val_2d(rng); + float r = sqrtf(sample.x); + float sin_phi, cos_phi; + sincosf(2.0f * PI() * sample.y, &sin_phi, &cos_phi); + return vec2{ r * sin_phi, r * cos_phi }; +} + +inline __host__ __device__ vec2 square2disk_shirley(const vec2& square) { + float phi, r; + float a = square.x; + float b = square.y; + if (a*a > b*b) { // use squares instead of absolute values + r = a; + phi = (PI()/4.0f) * (b/a); + } else { + r = b; + phi = (PI()/2.0f) - (PI()/4.0f) * (a/b); + } + + float sin_phi, cos_phi; + sincosf(phi, &sin_phi, &cos_phi); + + return {r*cos_phi, r*sin_phi}; +} + +inline __host__ __device__ __device__ vec3 cosine_hemisphere(const vec2& u) { + // Uniformly sample disk + const float r = sqrtf(u.x); + const float phi = 2.0f * PI() * u.y; + + vec3 p; + p.x = r * cosf(phi); + p.y = r * sinf(phi); + + // Project up to hemisphere + p.z = sqrtf(fmaxf(0.0f, 1.0f - p.x*p.x - p.y*p.y)); + + return p; +} + +template +inline __host__ __device__ vec3 random_dir_cosine(RNG& rng) { + return cosine_hemisphere(random_val_2d(rng)); +} + +template +inline __host__ __device__ vec3 
random_val_3d(RNG& rng) { + return {rng.next_float(), rng.next_float(), rng.next_float()}; +} + +template +inline __host__ __device__ vec4 random_val_4d(RNG& rng) { + return {rng.next_float(), rng.next_float(), rng.next_float(), rng.next_float()}; +} + +// The below code has been adapted from Burley [2019] https://www.jcgt.org/published/0009/04/01/paper.pdf + +inline __host__ __device__ uint32_t sobol(uint32_t index, uint32_t dim) { + static constexpr uint32_t directions[5][32] = { + 0x80000000, 0x40000000, 0x20000000, 0x10000000, + 0x08000000, 0x04000000, 0x02000000, 0x01000000, + 0x00800000, 0x00400000, 0x00200000, 0x00100000, + 0x00080000, 0x00040000, 0x00020000, 0x00010000, + 0x00008000, 0x00004000, 0x00002000, 0x00001000, + 0x00000800, 0x00000400, 0x00000200, 0x00000100, + 0x00000080, 0x00000040, 0x00000020, 0x00000010, + 0x00000008, 0x00000004, 0x00000002, 0x00000001, + + 0x80000000, 0xc0000000, 0xa0000000, 0xf0000000, + 0x88000000, 0xcc000000, 0xaa000000, 0xff000000, + 0x80800000, 0xc0c00000, 0xa0a00000, 0xf0f00000, + 0x88880000, 0xcccc0000, 0xaaaa0000, 0xffff0000, + 0x80008000, 0xc000c000, 0xa000a000, 0xf000f000, + 0x88008800, 0xcc00cc00, 0xaa00aa00, 0xff00ff00, + 0x80808080, 0xc0c0c0c0, 0xa0a0a0a0, 0xf0f0f0f0, + 0x88888888, 0xcccccccc, 0xaaaaaaaa, 0xffffffff, + + 0x80000000, 0xc0000000, 0x60000000, 0x90000000, + 0xe8000000, 0x5c000000, 0x8e000000, 0xc5000000, + 0x68800000, 0x9cc00000, 0xee600000, 0x55900000, + 0x80680000, 0xc09c0000, 0x60ee0000, 0x90550000, + 0xe8808000, 0x5cc0c000, 0x8e606000, 0xc5909000, + 0x6868e800, 0x9c9c5c00, 0xeeee8e00, 0x5555c500, + 0x8000e880, 0xc0005cc0, 0x60008e60, 0x9000c590, + 0xe8006868, 0x5c009c9c, 0x8e00eeee, 0xc5005555, + + 0x80000000, 0xc0000000, 0x20000000, 0x50000000, + 0xf8000000, 0x74000000, 0xa2000000, 0x93000000, + 0xd8800000, 0x25400000, 0x59e00000, 0xe6d00000, + 0x78080000, 0xb40c0000, 0x82020000, 0xc3050000, + 0x208f8000, 0x51474000, 0xfbea2000, 0x75d93000, + 0xa0858800, 0x914e5400, 0xdbe79e00, 0x25db6d00, + 0x58800080, 0xe54000c0, 0x79e00020, 0xb6d00050, + 0x800800f8, 0xc00c0074, 0x200200a2, 0x50050093, + + 0x80000000, 0x40000000, 0x20000000, 0xb0000000, + 0xf8000000, 0xdc000000, 0x7a000000, 0x9d000000, + 0x5a800000, 0x2fc00000, 0xa1600000, 0xf0b00000, + 0xda880000, 0x6fc40000, 0x81620000, 0x40bb0000, + 0x22878000, 0xb3c9c000, 0xfb65a000, 0xddb2d000, + 0x78022800, 0x9c0b3c00, 0x5a0fb600, 0x2d0ddb00, + 0xa2878080, 0xf3c9c040, 0xdb65a020, 0x6db2d0b0, + 0x800228f8, 0x400b3cdc, 0x200fb67a, 0xb00ddb9d, + }; + + uint32_t X = 0; + + NGP_PRAGMA_UNROLL + for (uint32_t bit = 0; bit < 32; bit++) { + uint32_t mask = (index >> bit) & 1; + X ^= mask * directions[dim][bit]; + } + + return X; +} + +inline __host__ __device__ uvec2 sobol2d(uint32_t index) { + return {sobol(index, 0), sobol(index, 1)}; +} + +inline __host__ __device__ uvec4 sobol4d(uint32_t index) { + return {sobol(index, 0), sobol(index, 1), sobol(index, 2), sobol(index, 3)}; +} + +inline __host__ __device__ uint32_t hash_combine(uint32_t seed, uint32_t v) { + return seed ^ (v + (seed << 6) + (seed >> 2)); +} + +inline __host__ __device__ uint32_t reverse_bits(uint32_t x) { + x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); + x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2)); + x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4)); + x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8)); + return ((x >> 16) | (x << 16)); +} + +inline __host__ __device__ uint32_t laine_karras_permutation(uint32_t x, uint32_t seed) { + x += seed; + x ^= x * 0x6c50b47cu; + x ^= x * 
0xb82f1e52u; + x ^= x * 0xc7afe638u; + x ^= x * 0x8d22f6e6u; + return x; +} + +inline __host__ __device__ uint32_t nested_uniform_scramble_base2(uint32_t x, uint32_t seed) { + x = reverse_bits(x); + x = laine_karras_permutation(x, seed); + x = reverse_bits(x); + return x; +} + +inline __host__ __device__ uvec4 shuffled_scrambled_sobol4d(uint32_t index, uint32_t seed) { + index = nested_uniform_scramble_base2(index, seed); + auto X = sobol4d(index); + for (uint32_t i = 0; i < 4; i++) { + X[i] = nested_uniform_scramble_base2(X[i], hash_combine(seed, i)); + } + return X; +} + +inline __host__ __device__ uvec2 shuffled_scrambled_sobol2d(uint32_t index, uint32_t seed) { + index = nested_uniform_scramble_base2(index, seed); + auto X = sobol2d(index); + for (uint32_t i = 0; i < 2; ++i) { + X[i] = nested_uniform_scramble_base2(X[i], hash_combine(seed, i)); + } + return X; +} + +inline __host__ __device__ vec4 ld_random_val_4d(uint32_t index, uint32_t seed) { + constexpr float S = float(1.0/(1ull<<32)); + uvec4 x = shuffled_scrambled_sobol4d(index, seed); + return {(float)x.x * S, (float)x.y * S, (float)x.z * S, (float)x.w * S}; +} + +inline __host__ __device__ vec2 ld_random_val_2d(uint32_t index, uint32_t seed) { + constexpr float S = float(1.0/(1ull<<32)); + uvec2 x = shuffled_scrambled_sobol2d(index, seed); + return {(float)x.x * S, (float)x.y * S}; +} + +inline __host__ __device__ float ld_random_val(uint32_t index, uint32_t seed, uint32_t dim = 0) { + constexpr float S = float(1.0/(1ull<<32)); + index = nested_uniform_scramble_base2(index, seed); + return (float)nested_uniform_scramble_base2(sobol(index, dim), hash_combine(seed, dim)) * S; +} + +template +__host__ __device__ float halton(size_t idx) { + float f = 1; + float result = 0; + + while (idx > 0) { + f /= base; + result += f * (idx % base); + idx /= base; + } + + return result; +} + +inline __host__ __device__ vec2 halton23(size_t idx) { + return {halton<2>(idx), halton<3>(idx)}; +} + +// Halton +// inline __host__ __device__ vec2 ld_random_pixel_offset(const uint32_t spp) { +// vec2 offset = vec2(0.5f) - halton23(0) + halton23(spp); +// offset.x = fractf(offset.x); +// offset.y = fractf(offset.y); +// return offset; +// } + +// Scrambled Sobol +inline __host__ __device__ vec2 ld_random_pixel_offset(const uint32_t spp) { + vec2 offset = vec2(0.5f) - ld_random_val_2d(0, 0xdeadbeef) + ld_random_val_2d(spp, 0xdeadbeef); + offset.x = fractf(offset.x); + offset.y = fractf(offset.y); + return offset; +} + +} + diff --git a/gui/include/neural-graphics-primitives/render_buffer.h b/gui/include/neural-graphics-primitives/render_buffer.h new file mode 100644 index 0000000000000000000000000000000000000000..f7bdfcc318d184e38833145cfea14ef17c713225 --- /dev/null +++ b/gui/include/neural-graphics-primitives/render_buffer.h @@ -0,0 +1,329 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file render_buffer.h + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace ngp { + +typedef unsigned int GLenum; +typedef int GLint; +typedef unsigned int GLuint; + +class SurfaceProvider { +public: + virtual cudaSurfaceObject_t surface() = 0; + virtual cudaArray_t array() = 0; + virtual ivec2 resolution() const = 0; + virtual void resize(const ivec2&, int n_channels = 4) = 0; +}; + +class CudaSurface2D : public SurfaceProvider { +public: + CudaSurface2D() { + m_array = nullptr; + m_surface = 0; + } + + ~CudaSurface2D() { + free(); + } + + void free(); + + void resize(const ivec2& size, int n_channels) override; + + cudaSurfaceObject_t surface() override { + return m_surface; + } + + cudaArray_t array() override { + return m_array; + } + + ivec2 resolution() const override { + return m_size; + } + +private: + ivec2 m_size = ivec2(0); + int m_n_channels = 0; + cudaArray_t m_array; + cudaSurfaceObject_t m_surface; +}; + +#ifdef NGP_GUI +class GLTexture : public SurfaceProvider { +public: + GLTexture() = default; + GLTexture(const std::string& texture_name) + : m_texture_name(texture_name), m_texture_id(0) + { } + + GLTexture(const GLTexture& other) = delete; + + GLTexture(GLTexture&& other) + : m_texture_name(move(other.m_texture_name)), m_texture_id(other.m_texture_id) { + other.m_texture_id = 0; + } + + GLTexture& operator=(GLTexture&& other) { + m_texture_name = move(other.m_texture_name); + std::swap(m_texture_id, other.m_texture_id); + return *this; + } + + ~GLTexture(); + + GLuint texture(); + + cudaSurfaceObject_t surface() override; + + cudaArray_t array() override; + + void blit_from_cuda_mapping(); + + const std::string& texture_name() const { return m_texture_name; } + + bool is_8bit() { return m_is_8bit; } + + void load(const fs::path& path); + + void load(const float* data, ivec2 new_size, int n_channels); + + void load(const uint8_t* data, ivec2 new_size, int n_channels); + + void resize(const ivec2& new_size, int n_channels, bool is_8bit); + + void resize(const ivec2& new_size, int n_channels) override { + resize(new_size, n_channels, false); + } + + ivec2 resolution() const override { + return m_size; + } + +private: + class CUDAMapping { + public: + CUDAMapping(GLuint texture_id, const ivec2& size, int n_channels); + ~CUDAMapping(); + + cudaSurfaceObject_t surface() const { return m_cuda_surface ? m_cuda_surface->surface() : m_surface; } + + cudaArray_t array() const { return m_cuda_surface ? 
m_cuda_surface->array() : m_mapped_array; } + + bool is_interop() const { return !m_cuda_surface; } + + const float* data_cpu(); + + private: + cudaGraphicsResource_t m_graphics_resource = {}; + cudaArray_t m_mapped_array = {}; + cudaSurfaceObject_t m_surface = {}; + + ivec2 m_size; + int m_n_channels; + std::vector m_data_cpu; + + std::unique_ptr m_cuda_surface; + }; + + std::string m_texture_name; + GLuint m_texture_id = 0; + ivec2 m_size = ivec2(0); + int m_n_channels = 0; + GLint m_internal_format; + GLenum m_format; + bool m_is_8bit = false; + std::unique_ptr m_cuda_mapping; +}; + +bool check_shader(uint32_t handle, const char* desc, bool program); +uint32_t compile_shader(bool pixel, const char* code); +#endif //NGP_GUI + +struct CudaRenderBufferView { + vec4* frame_buffer = nullptr; + float* depth_buffer = nullptr; + ivec2 resolution = ivec2(0); + uint32_t spp = 0; + + std::shared_ptr> hidden_area_mask = nullptr; + + void clear(cudaStream_t stream) const; +}; + +class CudaRenderBuffer { +public: + CudaRenderBuffer(const std::shared_ptr& rgba, const std::shared_ptr& depth = nullptr) : m_rgba_target{rgba}, m_depth_target{depth} {} + + CudaRenderBuffer(const CudaRenderBuffer& other) = delete; + CudaRenderBuffer& operator=(const CudaRenderBuffer& other) = delete; + CudaRenderBuffer(CudaRenderBuffer&& other) = default; + CudaRenderBuffer& operator=(CudaRenderBuffer&& other) = default; + + cudaSurfaceObject_t surface() { + return m_rgba_target->surface(); + } + + ivec2 in_resolution() const { + return m_in_resolution; + } + + ivec2 out_resolution() const { + return m_rgba_target->resolution(); + } + + void resize(const ivec2& res); + + void reset_accumulation() { + m_spp = 0; + } + + uint32_t spp() const { + return m_spp; + } + + void set_spp(uint32_t value) { + m_spp = value; + } + + vec4* frame_buffer() const { + return m_frame_buffer.data(); + } + + float* depth_buffer() const { + return m_depth_buffer.data(); + } + + vec4* accumulate_buffer() const { + return m_accumulate_buffer.data(); + } + + CudaRenderBufferView view() const { + return { + frame_buffer(), + depth_buffer(), + in_resolution(), + spp(), + hidden_area_mask(), + }; + } + + void clear_frame(cudaStream_t stream); + + void accumulate(float exposure, cudaStream_t stream); + + void tonemap(float exposure, const vec4& background_color, EColorSpace output_color_space, float znear, float zfar, bool snap_to_pixel_centers, cudaStream_t stream); + + void overlay_image( + float alpha, + const vec3& exposure, + const vec4& background_color, + EColorSpace output_color_space, + const void* __restrict__ image, + EImageDataType image_data_type, + const ivec2& resolution, + int fov_axis, + float zoom, + const vec2& screen_center, + cudaStream_t stream + ); + + void overlay_depth( + float alpha, + const float* __restrict__ depth, + float depth_scale, + const ivec2& resolution, + int fov_axis, + float zoom, + const vec2& screen_center, + cudaStream_t stream + ); + + void overlay_false_color(ivec2 training_resolution, bool to_srgb, int fov_axis, cudaStream_t stream, const float *error_map, ivec2 error_map_resolution, const float *average, float brightness, bool viridis); + + SurfaceProvider& surface_provider() { + return *m_rgba_target; + } + + void set_color_space(EColorSpace color_space) { + if (color_space != m_color_space) { + m_color_space = color_space; + reset_accumulation(); + } + } + + void set_tonemap_curve(ETonemapCurve tonemap_curve) { + if (tonemap_curve != m_tonemap_curve) { + m_tonemap_curve = tonemap_curve; + 
reset_accumulation(); + } + } + + void enable_dlss(IDlssProvider& dlss_provider, const ivec2& max_out_res); + void disable_dlss(); + void set_dlss_sharpening(float value) { + m_dlss_sharpening = value; + } + + const std::unique_ptr& dlss() const { + return m_dlss; + } + + void set_hidden_area_mask(const std::shared_ptr>& hidden_area_mask) { + m_hidden_area_mask = hidden_area_mask; + } + + const std::shared_ptr>& hidden_area_mask() const { + return m_hidden_area_mask; + } + +private: + uint32_t m_spp = 0; + EColorSpace m_color_space = EColorSpace::Linear; + ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity; + + std::unique_ptr m_dlss; + float m_dlss_sharpening = 0.0f; + + ivec2 m_in_resolution = ivec2(0); + + GPUMemory m_frame_buffer; + GPUMemory m_depth_buffer; + GPUMemory m_accumulate_buffer; + + std::shared_ptr> m_hidden_area_mask = nullptr; + + std::shared_ptr m_rgba_target; + std::shared_ptr m_depth_target; +}; + +} diff --git a/gui/include/neural-graphics-primitives/shared_queue.h b/gui/include/neural-graphics-primitives/shared_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..6b1c01efeab50ee9c3957e3dacc3da564177acdb --- /dev/null +++ b/gui/include/neural-graphics-primitives/shared_queue.h @@ -0,0 +1,127 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file was taken from the tev image viewer and is re-released here +// under the NVIDIA Source Code License with permission from the author. 
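+
+// Added usage sketch (not part of the original header): a minimal
+// producer/consumer pattern for the SharedQueue defined below. The lambda
+// body is hypothetical and the whole snippet is illustrative only.
+//
+//   ngp::SharedQueue<std::unique_ptr<ngp::ICallable>> queue;
+//
+//   // Producer thread: wrap work in an ICallable and enqueue it.
+//   queue.push(ngp::callable([]() { /* do some work */ }));
+//
+//   // Consumer thread: block until a task is available, then run it.
+//   auto task = queue.waitAndPop();
+//   (*task)();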
+ +#pragma once + +#include +#include +#include + +namespace ngp { + +class ICallable { +public: + virtual ~ICallable() {} + virtual void operator()() = 0; +}; + +template +class Callable : public ICallable { +public: + Callable() = default; + Callable(const T& callable) : m_callable{callable} {} + Callable(T&& callable) : m_callable{std::forward(callable)} {} + Callable(const Callable& other) = delete; + Callable& operator=(Callable&& other) { std::swap(m_callable, other.m_callable); return *this; } + Callable(Callable&& other) { *this = std::move(other); } + + void operator()() override { + m_callable(); + } +private: + T m_callable; +}; + +template +std::unique_ptr callable(T&& callable) { + return std::make_unique>(std::forward(callable)); +} + +class SharedQueueEmptyException {}; + +template +class SharedQueue { +public: + bool empty() const { + std::lock_guard lock{mMutex}; + return mRawQueue.empty(); + } + + size_t size() const { + std::lock_guard lock{mMutex}; + return mRawQueue.size(); + } + + void push(T&& newElem) { + std::lock_guard lock{mMutex}; + mRawQueue.emplace_back(std::forward(newElem)); + mDataCondition.notify_one(); + } + + void clear() { + std::lock_guard lock{mMutex}; + mRawQueue.clear(); + } + + void clearAndPush(T&& newElem) { + std::lock_guard lock{mMutex}; + mRawQueue.clear(); + mRawQueue.emplace_back(std::forward(newElem)); + mDataCondition.notify_one(); + } + + T waitAndPop() { + std::unique_lock lock{mMutex}; + + while (mRawQueue.empty()) { + mDataCondition.wait(lock); + } + + T result = std::move(mRawQueue.front()); + mRawQueue.pop_front(); + + return result; + } + + T tryPop(bool back = false) { + std::unique_lock lock{mMutex}; + + if (mRawQueue.empty()) { + throw SharedQueueEmptyException{}; + } + + if (back) { + T result = std::move(mRawQueue.back()); + mRawQueue.pop_back(); + return result; + } else { + T result = std::move(mRawQueue.front()); + mRawQueue.pop_front(); + return result; + } + } + +private: + std::deque mRawQueue; + mutable std::mutex mMutex; + std::condition_variable mDataCondition; +}; + +} diff --git a/gui/include/neural-graphics-primitives/testbed.h b/gui/include/neural-graphics-primitives/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..4b48831a08ec094cdadc79ee6837ff31f2fd83bb --- /dev/null +++ b/gui/include/neural-graphics-primitives/testbed.h @@ -0,0 +1,635 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file testbed.h + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef NGP_GUI +# include +#endif + +#include +#include + +#include + +#ifdef NGP_PYTHON +# include +# include +#endif + +#include +#include + +struct GLFWwindow; + +namespace ngp { + +struct Triangle; +class GLTexture; + +struct ViewIdx { + i16vec2 px; + uint32_t view; +}; + +class Testbed { +public: + Testbed(ETestbedMode mode = ETestbedMode::None); + ~Testbed(); + + bool clear_tmp_dir(); + void update_imgui_paths(); + + void set_mode(ETestbedMode mode); + + using distance_fun_t = std::function; + using normals_fun_t = std::function; + + struct LevelStats { + float mean() { return count ? (x / (float)count) : 0.f; } + float variance() { return count ? (xsquared - (x * x) / (float)count) / (float)count : 0.f; } + float sigma() { return sqrtf(variance()); } + float fraczero() { return (float)numzero / float(count + numzero); } + float fracquant() { return (float)numquant / float(count); } + + float x; + float xsquared; + float min; + float max; + int numzero; + int numquant; + int count; + }; + + class CudaDevice; + + struct View { + std::shared_ptr render_buffer = nullptr; + ivec2 full_resolution = {1, 1}; + int visualized_dimension = 0; + + mat4x3 camera0 = mat4x3::identity(); + mat4x3 camera1 = mat4x3::identity(); + mat4x3 prev_camera = mat4x3::identity(); + + Foveation foveation; + Foveation prev_foveation; + + vec2 relative_focal_length; + vec2 screen_center; + + Lens lens; + + CudaDevice* device = nullptr; + + GPUImage index_field; + GPUImage hole_mask; + GPUImage depth_buffer; + + + vec2 fov() const { return relative_focal_length_to_fov(relative_focal_length); } + + uint32_t uid = 0; + }; + + void render_by_reprojection(cudaStream_t stream, std::vector& views); + + void render_frame( + cudaStream_t stream, + const mat4x3& camera_matrix0, + const mat4x3& camera_matrix1, + const mat4x3& prev_camera_matrix, + const vec2& screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Foveation& prev_foveation, + const Lens& lens, + int visualized_dimension, + CudaRenderBuffer& render_buffer, + bool to_srgb = true, + CudaDevice* device = nullptr + ); + void render_frame_main( + CudaDevice& device, + const mat4x3& camera_matrix0, + const mat4x3& camera_matrix1, + const vec2& screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Lens& lens, + int visualized_dimension + ); + void render_frame_epilogue( + cudaStream_t stream, + const mat4x3& camera_matrix0, + const mat4x3& prev_camera_matrix, + const vec2& screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Foveation& prev_foveation, + const Lens& lens, + CudaRenderBuffer& render_buffer, + bool to_srgb = true + ); + + void init_camera_path_from_reproject_src_cameras(); + void visualize_reproject_src_cameras(ImDrawList* list, const mat4& world2proj); + void clear_src_views(); + + void reset_accumulation(bool due_to_camera_movement = false, bool immediate_redraw = true, bool reset_pip = false); + void redraw_next_frame() { m_render_skip_due_to_lack_of_camera_movement_counter = 0; } + bool reprojection_available() { return m_dlss; } + void load_mesh(const fs::path& data_path); + void set_exposure(float exposure) { m_exposure = exposure; } + void translate_camera(const vec3& rel, const mat3& rot, bool allow_up_down = true); + mat3 rotation_from_angles(const vec2& 
angles) const; + void mouse_drag(); + void mouse_wheel(); + void load_file(const fs::path& path); + vec3 look_at() const; + void set_look_at(const vec3& pos); + float scale() const { return m_scale; } + void set_scale(float scale); + vec3 view_pos() const { return m_camera[3]; } + vec3 view_dir() const { return m_camera[2]; } + vec3 view_up() const { return m_camera[1]; } + vec3 view_side() const { return m_camera[0]; } + void set_view_dir(const vec3& dir); + void reset_camera(); + bool keyboard_event(); + void update_density_grid_mean_and_bitfield(cudaStream_t stream); + void mark_density_grid_in_sphere_empty(const vec3& pos, float radius, cudaStream_t stream); + + void prepare_next_camera_path_frame(); + void overlay_fps(); + void imgui(); + vec2 calc_focal_length(const ivec2& resolution, const vec2& relative_focal_length, int fov_axis, float zoom) const; + vec2 render_screen_center(const vec2& screen_center) const; + void optimise_mesh_step(uint32_t N_STEPS); + void compute_mesh_vertex_colors(); + + float get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const vec2& uv); + vec3 get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const vec2& focus_pixel); + void autofocus(); + +#ifdef NGP_PYTHON + std::pair, pybind11::array_t> + render_to_cpu(int width, int height, int spp, bool linear, float start_t, float end_t, float fps, float shutter_fraction); + pybind11::array_t + render_to_cpu_rgba(int width, int height, int spp, bool linear, float start_t, float end_t, float fps, float shutter_fraction); + pybind11::array_t view(bool linear, size_t view) const; + std::pair, pybind11::array_t> + reproject(const mat4x3& src, const pybind11::array_t& src_img, const pybind11::array_t& src_depth, const mat4x3& dst); + uint32_t add_src_view( + mat4x3 camera_to_world, + float fx, + float fy, + float cx, + float cy, + Lens lens, + pybind11::array_t img, + pybind11::array_t depth, + float timestamp, + bool is_srgb = false + ); + pybind11::array_t src_view_ids() const; +# ifdef NGP_GUI + pybind11::array_t screenshot(bool linear, bool front_buffer) const; +# endif +#endif + + mat4x3 view_camera(size_t view) const; + + + void draw_visualizations(ImDrawList* list, const mat4x3& camera_matrix); + void reproject_views(const std::vector src, View& dst); + void render(bool skip_rendering); + void init_window(int resw, int resh, bool hidden = false, bool second_window = false); + void destroy_window(); + void init_vr(); + void update_vr_performance_settings(); + void apply_camera_smoothing(float elapsed_ms); + bool begin_frame(); + void handle_user_input(); + vec3 vr_to_world(const vec3& pos) const; + void begin_vr_frame_and_handle_vr_input(); + void draw_gui(); + bool frame(); + bool want_repl(); + void load_image(const fs::path& data_path); + void load_exr_image(const fs::path& data_path); + void load_stbi_image(const fs::path& data_path); + void load_binary_image(const fs::path& data_path); + float fov() const; + void set_fov(float val); + vec2 fov_xy() const; + void set_fov_xy(const vec2& val); + CameraKeyframe copy_camera_to_keyframe() const; + void set_camera_from_keyframe(const CameraKeyframe& k); + void set_camera_from_time(float t); + void load_camera_path(const fs::path& path); + bool loop_animation(); + void set_loop_animation(bool value); + + fs::path root_dir(); + void set_root_dir(const fs::path& dir); + + bool m_want_repl = false; + + bool m_render_window = false; + bool m_gather_histograms = false; + + bool m_render_ground_truth = false; + EGroundTruthRenderMode 
m_ground_truth_render_mode = EGroundTruthRenderMode::Shade; + float m_ground_truth_alpha = 1.0f; + + bool m_render = true; + int m_max_spp = 0; + ETestbedMode m_testbed_mode = ETestbedMode::None; + + // Rendering stuff + ivec2 m_window_res = ivec2(0); + bool m_dynamic_res = false; + float m_dynamic_res_target_fps = 20.0f; + int m_fixed_res_factor = 8; + float m_scale = 1.0; + float m_aperture_size = 0.0f; + vec2 m_relative_focal_length = vec2(1.0f); + uint32_t m_fov_axis = 1; + float m_zoom = 1.f; // 2d zoom factor (for insets?) + vec2 m_screen_center = vec2(0.5f); // center of 2d zoom + + float m_ndc_znear = 1.0f / 32.0f; + float m_ndc_zfar = 128.0f; + + mat4x3 m_camera = mat4x3::identity(); + mat4x3 m_default_camera = transpose(mat3x4{1.0f, 0.0f, 0.0f, 0.5f, 0.0f, -1.0f, 0.0f, 0.5f, 0.0f, 0.0f, -1.0f, 0.5f}); + mat4x3 m_smoothed_camera = mat4x3::identity(); + size_t m_render_skip_due_to_lack_of_camera_movement_counter = 0; + + bool m_fps_camera = false; + bool m_camera_smoothing = false; + bool m_autofocus = false; + vec3 m_autofocus_target = vec3(0.5f); + + bool m_render_with_lens_distortion = false; + Lens m_render_lens = {}; + + CameraPath m_camera_path = {}; + bool m_record_camera_path = false; + + vec3 m_up_dir = {0.0f, 1.0f, 0.0f}; + vec3 m_sun_dir = normalize(vec3(1.0f)); + float m_bounding_radius = 1; + float m_exposure = 0.f; + + ERenderMode m_render_mode = ERenderMode::Shade; + + uint32_t m_seed = 1337; + +#ifdef NGP_GUI + GLFWwindow* m_glfw_window = nullptr; + struct SecondWindow { + GLFWwindow* window = nullptr; + GLuint program = 0; + GLuint vao = 0, vbo = 0; + void draw(GLuint texture); + } m_second_window; + + float m_drag_depth = 1.0f; + + // The VAO will be empty, but we need a valid one for attribute-less rendering + GLuint m_blit_vao = 0; + GLuint m_blit_program = 0; + + void init_opengl_shaders(); + void blit_texture( + const Foveation& foveation, + GLint rgba_texture, + GLint rgba_filter_mode, + GLint depth_texture, + GLint framebuffer, + const ivec2& offset, + const ivec2& resolution + ); + + void create_second_window(); + + std::unique_ptr m_hmd; + OpenXRHMD::FrameInfoPtr m_vr_frame_info; + + bool m_vr_use_depth_reproject = false; + bool m_vr_use_hidden_area_mask = false; + + std::deque m_reproject_src_views; + View m_reproject_pending_view; + + int m_reproject_min_src_view_index = 0; + int m_reproject_max_src_view_index = 1; + int m_reproject_max_src_view_count = -1; // -1 indicates unlimited + uint32_t m_reproject_selected_src_view = 0; + bool m_reproject_freeze_src_views = false; + int m_reproject_n_views_to_cache = 1; + bool m_reproject_visualize_src_views = false; + + float m_reproject_min_t = 0.1f; + float m_reproject_step_factor = 1.05f; + vec3 m_reproject_parallax = vec3(0.0f, 0.0f, 0.0f); + bool m_reproject_enable = false; + bool m_reproject_reuse_last_frame = true; + + float m_reproject_lazy_render_ms = 100.0f; + float m_reproject_lazy_render_res_factor = 1.25f; + + + bool m_pm_enable = false; + EPmVizMode m_pm_viz_mode = EPmVizMode::Shade; + + void set_n_views(size_t n_views); + + // Callback invoked when a keyboard event is detected. + // If the callback returns `true`, the event is considered handled and the default behavior will not occur. + std::function m_keyboard_event_callback; + + // Callback invoked when a file is dropped onto the window. + // If the callback returns `true`, the files are considered handled and the default behavior will not occur. 
+ std::function&)> m_file_drop_callback; + + std::shared_ptr m_pip_render_texture; + std::vector> m_rgba_render_textures; + std::vector> m_depth_render_textures; +#endif + + std::shared_ptr m_pip_render_buffer; + + SharedQueue> m_task_queue; + + void redraw_gui_next_frame() { m_gui_redraw = true; } + + bool m_gui_redraw = true; + + enum EDataType { + Float, + Half, + }; + + struct VolPayload { + vec3 dir; + vec4 col; + uint32_t pixidx; + }; + + float m_camera_velocity = 1.0f; + EColorSpace m_color_space = EColorSpace::Linear; + ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity; + bool m_dlss = false; + std::shared_ptr m_dlss_provider; + float m_dlss_sharpening = 0.0f; + + // 3D stuff + float m_render_near_distance = 0.0f; + float m_slice_plane_z = 0.0f; + bool m_floor_enable = false; + inline float get_floor_y() const { return m_floor_enable ? m_aabb.min.y + 0.001f : -10000.f; } + BoundingBox m_raw_aabb; + BoundingBox m_aabb = {vec3(0.0f), vec3(1.0f)}; + BoundingBox m_render_aabb = {vec3(0.0f), vec3(1.0f)}; + mat3 m_render_aabb_to_local = mat3::identity(); + + // Rendering/UI bookkeeping + Ema m_render_ms = {EEmaType::Time, 100}; + // The frame contains everything, i.e. rendering + GUI and buffer swapping + Ema m_frame_ms = {EEmaType::Time, 100}; + std::chrono::time_point m_last_frame_time_point; + std::chrono::time_point m_last_gui_draw_time_point; + vec4 m_background_color = {0.0f, 0.0f, 0.0f, 1.0f}; + + bool m_vsync = true; + bool m_render_transparency_as_checkerboard = false; + + // Visualization of neuron activations + int m_visualized_dimension = -1; + int m_visualized_layer = 0; + + std::vector m_views; + ivec2 m_n_views = {1, 1}; + + float m_picture_in_picture_res = 0.f; // if non zero, requests a small second picture :) + + enum class ImGuiMode : uint32_t { + Enabled, + FpsOverlay, + Disabled, + // Don't set the below + NumModes, + }; + + struct ImGuiVars { + static const uint32_t MAX_PATH_LEN = 1024; + + ImGuiMode mode = ImGuiMode::Enabled; // tab to cycle + char cam_path_path[MAX_PATH_LEN] = "cam.json"; + char video_path[MAX_PATH_LEN] = "video.mp4"; + char cam_export_path[MAX_PATH_LEN] = "cam_export.json"; + + void* overlay_font = nullptr; + } m_imgui; + + fs::path m_root_dir = ""; + + bool m_visualize_unit_cube = false; + bool m_edit_render_aabb = false; + bool m_edit_world_transform = true; + + bool m_snap_to_pixel_centers = false; + + vec3 m_parallax_shift = {0.0f, 0.0f, 0.0f}; // to shift the viewer's origin by some amount in camera space + + StreamAndEvent m_stream; + + class CudaDevice { + public: + struct Data { + std::shared_ptr> hidden_area_mask; + }; + + CudaDevice(int id, bool is_primary); + + CudaDevice(const CudaDevice&) = delete; + CudaDevice& operator=(const CudaDevice&) = delete; + + CudaDevice(CudaDevice&&) = default; + CudaDevice& operator=(CudaDevice&&) = default; + + ScopeGuard device_guard(); + + int id() const { return m_id; } + + bool is_primary() const { return m_is_primary; } + + std::string name() const { return cuda_device_name(m_id); } + + int compute_capability() const { return cuda_compute_capability(m_id); } + + cudaStream_t stream() const { return m_stream->get(); } + + void wait_for(cudaStream_t stream) const { + CUDA_CHECK_THROW(cudaEventRecord(m_primary_device_event.event, stream)); + m_stream->wait_for(m_primary_device_event.event); + } + + void signal(cudaStream_t stream) const { m_stream->signal(stream); } + + const CudaRenderBufferView& render_buffer_view() const { return m_render_buffer_view; } + + void set_render_buffer_view(const 
CudaRenderBufferView& view) { m_render_buffer_view = view; } + + Data& data() const { return *m_data; } + + bool dirty() const { return m_dirty; } + + void set_dirty(bool value) { m_dirty = value; } + + void clear() { + m_data = std::make_unique(); + m_render_buffer_view = {}; + set_dirty(true); + } + + template auto enqueue_task(F&& f) -> std::future> { + if (is_primary()) { + return std::async(std::launch::deferred, std::forward(f)); + } else { + return m_render_worker->enqueue_task(std::forward(f)); + } + } + + private: + int m_id; + bool m_is_primary; + std::unique_ptr m_stream; + struct Event { + Event() { CUDA_CHECK_THROW(cudaEventCreate(&event)); } + + ~Event() { cudaEventDestroy(event); } + + Event(const Event&) = delete; + Event& operator=(const Event&) = delete; + Event(Event&& other) { *this = std::move(other); } + Event& operator=(Event&& other) { + std::swap(event, other.event); + return *this; + } + + cudaEvent_t event = {}; + }; + Event m_primary_device_event; + std::unique_ptr m_data; + CudaRenderBufferView m_render_buffer_view = {}; + + bool m_dirty = true; + + std::unique_ptr m_render_worker; + }; + + void sync_device(CudaRenderBuffer& render_buffer, CudaDevice& device); + ScopeGuard use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, CudaDevice& device); + void set_all_devices_dirty(); + + std::vector m_devices; + CudaDevice& primary_device() { return m_devices.front(); } + + ThreadPool m_thread_pool; + std::vector> m_render_futures; + + bool m_use_aux_devices = false; + bool m_foveated_rendering = false; + bool m_dynamic_foveated_rendering = true; + float m_foveated_rendering_full_res_diameter = 0.55f; + float m_foveated_rendering_scaling = 1.0f; + float m_foveated_rendering_max_scaling = 2.0f; + bool m_foveated_rendering_visualize = false; + + default_rng_t m_rng; + + CudaRenderBuffer m_windowless_render_surface{std::make_shared()}; + + // ---------- Gen3C stuff + /** + * Common signature for Gen3C-related UI callback functions, to be implemented + * in Python. + * + * Inputs: + * name: name of the UI event (e.g. name of the button pressed). + * + * Returns: bool, whether the operation was successful. + */ + using gen3c_cb_t = std::function; + gen3c_cb_t m_gen3c_cb; + + // Info string to be displayed in the Gen3C UI window. + std::string m_gen3c_info; + // Path to an image or directory to use to seed the generative model. + // The specific format is guessed based on what the path points to. + std::string m_gen3c_seed_path; + // Whether to automatically launch new inference requests. + bool m_gen3c_auto_inference = false; + + EGen3cCameraSource m_gen3c_camera_source = EGen3cCameraSource::Authored; + // Fake translation speed in scene unit / frame. + vec3 m_gen3c_translation_speed = {0.05f, 0.f, 0.f}; + // Fake rotation speed around (x, y, z) in radians / frame. + vec3 m_gen3c_rotation_speed = {0.f, 0.05f, 0.f}; + + // Number of frames to request for each inference request. + std::string m_gen3c_inference_info = ""; + + // Progress of seeding-related things (scale 0..1). Set to a negative value to hide the progress bar. + float m_gen3c_seeding_progress = -1.0f; + // Progress of inference-related things (scale 0..1). Set to a negative value to hide the progress bar. + float m_gen3c_inference_progress = -1.0f; + + // Saving Gen3C inference outputs + bool m_gen3c_save_frames = false; + // Whether or not to display generated frames in the UI. 
+ // No display means that we can save some time by not de-compressing + // the result video from the server, and even skip depth prediction for most frames. + bool m_gen3c_display_frames = false; + std::string m_gen3c_output_dir = ""; + + // When rendering with Gen3C, whether to include the rendered cache in the generated video (for debugging / visualization) + bool m_gen3c_show_cache_renderings = false; + + bool m_gen3c_inference_is_connected = false; + // Either we render the camera path from the local pointcloud or we use the inference server to get a photoreal video + bool m_gen3c_render_with_gen3c = true; +}; + +} // namespace ngp diff --git a/gui/include/neural-graphics-primitives/thread_pool.h b/gui/include/neural-graphics-primitives/thread_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..60c035a3b378097aaf3951ab385dbce5faadfa89 --- /dev/null +++ b/gui/include/neural-graphics-primitives/thread_pool.h @@ -0,0 +1,116 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file was taken from the tev image viewer and is re-released here +// under the NVIDIA Source Code License with permission from the author. 
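+
+// Added usage sketch (not part of the original header): ThreadPool::parallel_for,
+// declared below, splits the index range [start, end) into one contiguous chunk
+// per worker thread and blocks until all chunks have finished. The buffer and
+// lambda here are hypothetical and shown only as an illustration.
+//
+//   ngp::ThreadPool pool;
+//   std::vector<float> data(1024);
+//   pool.parallel_for<size_t>(0, data.size(), [&](size_t i) {
+//       data[i] = 0.5f * (float)i;
+//   });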
+ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace ngp { + +template +void wait_all(T&& futures) { + for (auto& f : futures) { + f.get(); + } +} + +class ThreadPool { +public: + ThreadPool(); + ThreadPool(size_t maxNum_threads, bool force = false); + virtual ~ThreadPool(); + + template + auto enqueue_task(F&& f, bool high_priority = false) -> std::future> { + using return_type = std::result_of_t; + + auto task = std::make_shared>(std::forward(f)); + + auto res = task->get_future(); + + { + std::lock_guard lock{m_task_queue_mutex}; + + if (high_priority) { + m_task_queue.emplace_front([task]() { (*task)(); }); + } else { + m_task_queue.emplace_back([task]() { (*task)(); }); + } + } + + m_worker_condition.notify_one(); + return res; + } + + void start_threads(size_t num); + void shutdown_threads(size_t num); + void set_n_threads(size_t num); + + void wait_until_queue_completed(); + void flush_queue(); + + template + void parallel_for_async(Int start, Int end, F body, std::vector>& futures) { + Int local_num_threads = (Int)m_num_threads; + + Int range = end - start; + Int chunk = (range / local_num_threads) + 1; + + for (Int i = 0; i < local_num_threads; ++i) { + futures.emplace_back(enqueue_task([i, chunk, start, end, body] { + Int inner_start = start + i * chunk; + Int inner_end = std::min(end, start + (i + 1) * chunk); + for (Int j = inner_start; j < inner_end; ++j) { + body(j); + } + })); + } + } + + template + std::vector> parallel_for_async(Int start, Int end, F body) { + std::vector> futures; + parallel_for_async(start, end, body, futures); + return futures; + } + + template + void parallel_for(Int start, Int end, F body) { + wait_all(parallel_for_async(start, end, body)); + } + +private: + size_t m_num_threads = 0; + std::vector m_threads; + + std::deque> m_task_queue; + std::mutex m_task_queue_mutex; + std::condition_variable m_worker_condition; + std::condition_variable m_task_queue_completed_condition; +}; + +} diff --git a/gui/include/neural-graphics-primitives/triangle.cuh b/gui/include/neural-graphics-primitives/triangle.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5813cefe1696237ff289443a3b4a34a135403d70 --- /dev/null +++ b/gui/include/neural-graphics-primitives/triangle.cuh @@ -0,0 +1,216 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file triangle.cuh + * @author Thomas Müller & Alex Evans, NVIDIA + * @brief CUDA/C++ triangle implementation. 
+ */ + +#pragma once + +#include +#include + +#include + +namespace ngp { + +inline NGP_HOST_DEVICE float normdot(const vec3 &a, const vec3 &b) { + float div = length(a) * length(b); + if (div == 0.0f) { + return 1.0f; + } + + return dot(a, b) / div; +} + +inline NGP_HOST_DEVICE float angle(const vec3 &a, const vec3 &b) { + return acosf(clamp(normdot(a, b), -1.0f, 1.0f)); +} + +struct Triangle { + NGP_HOST_DEVICE vec3 sample_uniform_position(const vec2& sample) const { + float sqrt_x = sqrt(sample.x); + float factor0 = 1.0f - sqrt_x; + float factor1 = sqrt_x * (1.0f - sample.y); + float factor2 = sqrt_x * sample.y; + + return factor0 * a + factor1 * b + factor2 * c; + } + + NGP_HOST_DEVICE float surface_area() const { + return 0.5f * length(cross(b - a, c - a)); + } + + NGP_HOST_DEVICE vec3 normal() const { + return normalize(cross(b - a, c - a)); + } + + NGP_HOST_DEVICE const vec3 &operator[](uint32_t i) const { + return i == 0 ? a : (i == 1 ? b : c); + } + + NGP_HOST_DEVICE float angle_at_vertex(uint32_t i) const { + vec3 v1 = (*this)[i] - (*this)[(i + 1) % 3]; + vec3 v2 = (*this)[i] - (*this)[(i + 2) % 3]; + return angle(v1, v2); + } + + NGP_HOST_DEVICE uint32_t closest_vertex_idx(const vec3 &pos) const { + float mag1 = length2(pos - a); + float mag2 = length2(pos - b); + float mag3 = length2(pos - c); + + float minv = min(vec3{ mag1, mag2, mag3 }); + + if (minv == mag1) { + return 0; + } else if (minv == mag2) { + return 1; + } else { + return 2; + } + } + + NGP_HOST_DEVICE float angle_at_pos(const vec3 &pos) const { + return angle_at_vertex(closest_vertex_idx(pos)); + } + + // based on https://www.iquilezles.org/www/articles/intersectors/intersectors.htm + NGP_HOST_DEVICE float ray_intersect(const vec3 &ro, const vec3 &rd, vec3& n) const { + vec3 v1v0 = b - a; + vec3 v2v0 = c - a; + vec3 rov0 = ro - a; + n = cross(v1v0, v2v0); + vec3 q = cross(rov0, rd); + float d = 1.0f / dot(rd, n); + float u = d * -dot(q, v2v0); + float v = d * dot(q, v1v0); + float t = d * -dot(n, rov0); + if (u < 0.0f || u > 1.0f || v < 0.0f || (u+v) > 1.0f || t < 0.0f) { + t = std::numeric_limits::max(); + } + return t; + } + + NGP_HOST_DEVICE float ray_intersect(const vec3 &ro, const vec3 &rd) const { + vec3 n; + return ray_intersect(ro, rd, n); + } + + // based on https://www.iquilezles.org/www/articles/distfunctions/distfunctions.htm + NGP_HOST_DEVICE float distance_sq(const vec3& pos) const { + vec3 v21 = b - a; vec3 p1 = pos - a; + vec3 v32 = c - b; vec3 p2 = pos - b; + vec3 v13 = a - c; vec3 p3 = pos - c; + vec3 nor = cross(v21, v13); + + return + // inside/outside test + (sign(dot(cross(v21, nor), p1)) + sign(dot(cross(v32, nor), p2)) + sign(dot(cross(v13, nor), p3)) < 2.0f) + ? + // 3 edges + min(vec3{ + length2(v21 * clamp(dot(v21, p1) / length2(v21), 0.0f, 1.0f)-p1), + length2(v32 * clamp(dot(v32, p2) / length2(v32), 0.0f, 1.0f)-p2), + length2(v13 * clamp(dot(v13, p3) / length2(v13), 0.0f, 1.0f)-p3), + }) + : + // 1 face + dot(nor, p1) * dot(nor, p1) / length2(nor); + } + + NGP_HOST_DEVICE float distance(const vec3& pos) const { + return sqrt(distance_sq(pos)); + } + + NGP_HOST_DEVICE bool point_in_triangle(const vec3& p) const { + // Move the triangle so that the point becomes the + // triangles origin + vec3 local_a = a - p; + vec3 local_b = b - p; + vec3 local_c = c - p; + + // The point should be moved too, so they are both + // relative, but because we don't use p in the + // equation anymore, we don't need it! 
+ // p -= p; + + // Compute the normal vectors for triangles: + // u = normal of PBC + // v = normal of PCA + // w = normal of PAB + + vec3 u = cross(local_b, local_c); + vec3 v = cross(local_c, local_a); + vec3 w = cross(local_a, local_b); + + // Test to see if the normals are facing the same direction. + // If yes, the point is inside, otherwise it isn't. + return dot(u, v) >= 0.0f && dot(u, w) >= 0.0f; + } + + NGP_HOST_DEVICE vec3 closest_point_to_line(const vec3& a, const vec3& b, const vec3& c) const { + float t = dot(c - a, b - a) / dot(b - a, b - a); + t = max(min(t, 1.0f), 0.0f); + return a + t * (b - a); + } + + NGP_HOST_DEVICE vec3 closest_point(vec3 point) const { + point -= dot(normal(), point - a) * normal(); + + if (point_in_triangle(point)) { + return point; + } + + vec3 c1 = closest_point_to_line(a, b, point); + vec3 c2 = closest_point_to_line(b, c, point); + vec3 c3 = closest_point_to_line(c, a, point); + + float mag1 = length2(point - c1); + float mag2 = length2(point - c2); + float mag3 = length2(point - c3); + + float min = tcnn::min(vec3{mag1, mag2, mag3}); + + if (min == mag1) { + return c1; + } else if (min == mag2) { + return c2; + } else { + return c3; + } + } + + NGP_HOST_DEVICE vec3 centroid() const { + return (a + b + c) / 3.0f; + } + + NGP_HOST_DEVICE float centroid(int axis) const { + return (a[axis] + b[axis] + c[axis]) / 3; + } + + NGP_HOST_DEVICE void get_vertices(vec3 v[3]) const { + v[0] = a; + v[1] = b; + v[2] = c; + } + + vec3 a, b, c; +}; + +} diff --git a/gui/include/tiny-cuda-nn/common.h b/gui/include/tiny-cuda-nn/common.h new file mode 100644 index 0000000000000000000000000000000000000000..5df4db84ea2cdfc54876c62342e9979ca3abf19f --- /dev/null +++ b/gui/include/tiny-cuda-nn/common.h @@ -0,0 +1,440 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common.h + * @author Thomas Müller and Nikolaus Binder, NVIDIA + * @brief Common utilities that are needed by pretty much every component of this framework. 
+ */ + +#pragma once + +#if defined(_WIN32) && !defined(NOMINMAX) +# define NOMINMAX +#endif + +#include +#include +#include +#include + +#if defined(__CUDACC__) +# include +#endif + +////////////////////////////////////// +// CUDA ERROR HANDLING (EXCEPTIONS) // +////////////////////////////////////// + +#define STRINGIFY(x) #x +#define STR(x) STRINGIFY(x) +#define FILE_LINE __FILE__ ":" STR(__LINE__) + +#if defined(__CUDA_ARCH__) + #define TCNN_PRAGMA_UNROLL _Pragma("unroll") + #define TCNN_PRAGMA_NO_UNROLL _Pragma("unroll 1") +#else + #define TCNN_PRAGMA_UNROLL + #define TCNN_PRAGMA_NO_UNROLL +#endif + +#ifdef __CUDACC__ +# ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ +# pragma nv_diag_suppress = unsigned_compare_with_zero +# else +# pragma diag_suppress = unsigned_compare_with_zero +# endif +#endif + +#if defined(__CUDACC__) || (defined(__clang__) && defined(__CUDA__)) +#define TCNN_HOST_DEVICE __host__ __device__ +#define TCNN_DEVICE __device__ +#define TCNN_HOST __host__ +#else +#define TCNN_HOST_DEVICE +#define TCNN_DEVICE +#define TCNN_HOST +#endif + +#ifndef TCNN_MIN_GPU_ARCH +#warning TCNN_MIN_GPU_ARCH was not defined. Using default value 75. +#define TCNN_MIN_GPU_ARCH 75 +#endif + +#include + +#if defined(__CUDA_ARCH__) +static_assert(__CUDA_ARCH__ >= TCNN_MIN_GPU_ARCH * 10, "MIN_GPU_ARCH=" STR(TCNN_MIN_GPU_ARCH) "0 must bound __CUDA_ARCH__=" STR(__CUDA_ARCH__) " from below, but doesn't."); +#endif + +namespace tcnn { + +static constexpr uint32_t MIN_GPU_ARCH = TCNN_MIN_GPU_ARCH; + +// When TCNN managed its model parameters, they are always aligned, +// which yields performance benefits in practice. However, parameters +// supplied by PyTorch are not necessarily aligned. The following +// variable controls whether TCNN must deal with unaligned data. +#if defined(TCNN_PARAMS_UNALIGNED) +static constexpr bool PARAMS_ALIGNED = false; +#else +static constexpr bool PARAMS_ALIGNED = true; +#endif + +#define TCNN_HALF_PRECISION (!(TCNN_MIN_GPU_ARCH == 61 || TCNN_MIN_GPU_ARCH <= 52)) + +// TCNN has the following behavior depending on GPU arch. +// Refer to the first row of the table at the following URL for information about +// when to pick fp16 versus fp32 precision for maximum performance. +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions__throughput-native-arithmetic-instructions +// +// GPU Arch | FullyFusedMLP supported | CUTLASS SmArch supported | Precision +// ----------|-------------------------|--------------------------|-------------------------- +// 80-90 | yes | 80 | __half +// 75 | yes | 75 | __half +// 70 | no | 70 | __half +// 53-60, 62 | no | 70 | __half (no tensor cores) +// <=52, 61 | no | 70 | float (no tensor cores) + +#if defined(__CUDACC__) +# if TCNN_HALF_PRECISION +using network_precision_t = __half; +# else +using network_precision_t = float; +# endif + +// Optionally: set the precision to `float` to disable tensor cores and debug potential +// problems with mixed-precision training. 
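+// To do so, uncomment the override below: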
+// using network_precision_t = float; +#endif + +enum class Activation { + ReLU, + LeakyReLU, + Exponential, + Sine, + Sigmoid, + Squareplus, + Softplus, + Tanh, + None, +}; + +enum class GridType { + Hash, + Dense, + Tiled, +}; + +enum class HashType { + Prime, + CoherentPrime, + ReversedPrime, + Rng, + BaseConvert, +}; + +enum class InterpolationType { + Nearest, + Linear, + Smoothstep, +}; + +enum class MatrixLayout { + RowMajor = 0, + SoA = 0, // For data matrices TCNN's convention is RowMajor == SoA (struct of arrays) + ColumnMajor = 1, + AoS = 1, +}; + +static constexpr MatrixLayout RM = MatrixLayout::RowMajor; +static constexpr MatrixLayout SoA = MatrixLayout::SoA; +static constexpr MatrixLayout CM = MatrixLayout::ColumnMajor; +static constexpr MatrixLayout AoS = MatrixLayout::AoS; + +enum class ReductionType { + Concatenation, + Sum, + Product, +}; + +////////////////// +// Misc helpers // +////////////////// + +inline constexpr TCNN_HOST_DEVICE float PI() { return 3.14159265358979323846f; } + +template +TCNN_HOST_DEVICE void host_device_swap(T& a, T& b) { + T c(a); a=b; b=c; +} + +template +TCNN_HOST_DEVICE T gcd(T a, T b) { + while (a != 0) { + b %= a; + host_device_swap(a, b); + } + return b; +} + +template +TCNN_HOST_DEVICE T lcm(T a, T b) { + T tmp = gcd(a, b); + return tmp ? (a / tmp) * b : 0; +} + +template +TCNN_HOST_DEVICE T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +TCNN_HOST_DEVICE T next_multiple(T val, T divisor) { + return div_round_up(val, divisor) * divisor; +} + +template +TCNN_HOST_DEVICE T previous_multiple(T val, T divisor) { + return (val / divisor) * divisor; +} + +template +constexpr TCNN_HOST_DEVICE bool is_pot(T val) { + return (val & (val - 1)) == 0; +} + +inline constexpr TCNN_HOST_DEVICE uint32_t next_pot(uint32_t v) { + --v; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + return v+1; +} + +template constexpr TCNN_HOST_DEVICE float default_loss_scale(); +template <> constexpr TCNN_HOST_DEVICE float default_loss_scale() { return 1.0f; } +#ifdef __CUDACC__ +template <> constexpr TCNN_HOST_DEVICE float default_loss_scale<__half>() { return 128.0f; } +#endif + +constexpr uint32_t BATCH_SIZE_GRANULARITY = 256; +constexpr uint32_t N_THREADS_LINEAR = 128; +constexpr uint32_t WARP_SIZE = 32; + +// Lower-case constants kept for backward compatibility with user code. 
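+// (Older code that still refers to `batch_size_granularity` or `n_threads_linear` keeps compiling unchanged.)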
+constexpr uint32_t batch_size_granularity = BATCH_SIZE_GRANULARITY; +constexpr uint32_t n_threads_linear = N_THREADS_LINEAR; + +template +constexpr TCNN_HOST_DEVICE uint32_t n_blocks_linear(T n_elements, uint32_t n_threads = N_THREADS_LINEAR) { + return (uint32_t)div_round_up(n_elements, (T)n_threads); +} + +template +struct PitchedPtr { + TCNN_HOST_DEVICE PitchedPtr() : ptr{nullptr}, stride_in_bytes{sizeof(T)} {} + TCNN_HOST_DEVICE PitchedPtr(T* ptr, size_t stride_in_elements, size_t offset = 0, size_t extra_stride_bytes = 0) : ptr{ptr + offset}, stride_in_bytes{stride_in_elements * sizeof(T) + extra_stride_bytes} {} + + template + TCNN_HOST_DEVICE explicit PitchedPtr(PitchedPtr other) : ptr{(T*)other.ptr}, stride_in_bytes{other.stride_in_bytes} {} + + TCNN_HOST_DEVICE T* operator()(uint32_t y) const { + return (T*)((const char*)ptr + y * stride_in_bytes); + } + + TCNN_HOST_DEVICE void operator+=(uint32_t y) { + ptr = (T*)((const char*)ptr + y * stride_in_bytes); + } + + TCNN_HOST_DEVICE void operator-=(uint32_t y) { + ptr = (T*)((const char*)ptr - y * stride_in_bytes); + } + + TCNN_HOST_DEVICE explicit operator bool() const { + return ptr; + } + + T* ptr; + size_t stride_in_bytes; +}; + +template +struct MatrixView { + TCNN_HOST_DEVICE MatrixView() : data{nullptr}, stride_i{0}, stride_j{0} {} + TCNN_HOST_DEVICE MatrixView(T* data, STRIDE_T stride_i, STRIDE_T stride_j) : data{data}, stride_i{stride_i}, stride_j{stride_j} {} + TCNN_HOST_DEVICE MatrixView(const MatrixView>& other) : data{other.data}, stride_i{other.stride_i}, stride_j{other.stride_j} {} + + using signed_index_t = std::make_signed_t; + using unsigned_index_t = std::make_unsigned_t; + + // Signed indexing + TCNN_HOST_DEVICE T& operator()(signed_index_t i, signed_index_t j = 0) const { + return data[i * (std::ptrdiff_t)stride_i + j * (std::ptrdiff_t)stride_j]; + } + + TCNN_HOST_DEVICE void advance(signed_index_t m, signed_index_t n) { + data += m * (std::ptrdiff_t)stride_i + n * (std::ptrdiff_t)stride_j; + } + + TCNN_HOST_DEVICE void advance_rows(signed_index_t m) { + advance(m, 0); + } + + TCNN_HOST_DEVICE void advance_cols(signed_index_t n) { + advance(0, n); + } + + // Unsigned indexing + TCNN_HOST_DEVICE T& operator()(unsigned_index_t i, unsigned_index_t j = 0) const { + return data[i * (size_t)stride_i + j * (size_t)stride_j]; + } + + TCNN_HOST_DEVICE void advance(unsigned_index_t m, unsigned_index_t n) { + data += m * (size_t)stride_i + n * (size_t)stride_j; + } + + TCNN_HOST_DEVICE void advance_rows(unsigned_index_t m) { + advance(m, (unsigned_index_t)0); + } + + TCNN_HOST_DEVICE void advance_cols(unsigned_index_t n) { + advance((unsigned_index_t)0, n); + } + + template + TCNN_HOST_DEVICE tvec, N> row(unsigned_index_t m) const { + tvec, N> result; + TCNN_PRAGMA_UNROLL + for (unsigned_index_t i = 0; i < N; ++i) { + result[i] = (*this)(m, i); + } + return result; + } + + template + TCNN_HOST_DEVICE tvec, N> col(unsigned_index_t n) const { + tvec, N> result; + TCNN_PRAGMA_UNROLL + for (unsigned_index_t i = 0; i < N; ++i) { + result[i] = (*this)(i, n); + } + return result; + } + + template + TCNN_HOST_DEVICE void set_row(unsigned_index_t m, const tvec& val) { + TCNN_PRAGMA_UNROLL + for (unsigned_index_t i = 0; i < N; ++i) { + (*this)(m, i) = val[i]; + } + } + + template + TCNN_HOST_DEVICE void set_col(unsigned_index_t n, const tvec& val) { + TCNN_PRAGMA_UNROLL + for (unsigned_index_t i = 0; i < N; ++i) { + (*this)(i, n) = val[i]; + } + } + + TCNN_HOST_DEVICE explicit operator bool() const { + return data; + } + + T* data; 
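+	// Strides along the i (first) and j (second) index, measured in elements of T, not bytes.
+	// E.g. a densely packed row-major m×n matrix indexed as (row, col) would use stride_i == n and stride_j == 1.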
+ STRIDE_T stride_i, stride_j; +}; + +template +struct Interval { + // Inclusive start, exclusive end + T start, end; + + TCNN_HOST_DEVICE bool operator<(const Interval& other) const { + // This operator is used to sort non-overlapping intervals. Since intervals + // may be empty, the second half of the following expression is required to + // resolve ambiguity when `end` of adjacent empty intervals is equal. + return end < other.end || (end == other.end && start < other.start); + } + + TCNN_HOST_DEVICE bool overlaps(const Interval& other) const { + return !intersect(other).empty(); + } + + TCNN_HOST_DEVICE Interval intersect(const Interval& other) const { + return {std::max(start, other.start), std::min(end, other.end)}; + } + + TCNN_HOST_DEVICE bool valid() const { + return end >= start; + } + + TCNN_HOST_DEVICE bool empty() const { + return end <= start; + } + + TCNN_HOST_DEVICE T size() const { + return end - start; + } +}; + +struct Ray { + vec3 o; + vec3 d; + + TCNN_HOST_DEVICE vec3 operator()(float t) const { + return o + t * d; + } + + TCNN_HOST_DEVICE void advance(float t) { + o += d * t; + } + + TCNN_HOST_DEVICE float distance_to(const vec3& p) const { + vec3 nearest = p - o; + nearest -= d * dot(nearest, d) / length2(d); + return length(nearest); + } + + TCNN_HOST_DEVICE bool is_valid() const { + return d != vec3(0.0f); + } + + static TCNN_HOST_DEVICE Ray invalid() { + return {{0.0f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.0f}}; + } +}; + +// Helpful data structure to represent ray-object intersections +template +struct PayloadAndIdx { + T t; + int64_t idx; + + // Sort in descending order + TCNN_HOST_DEVICE bool operator<(const PayloadAndIdx& other) { + return t < other.t; + } +}; + +using DistAndIdx = PayloadAndIdx; +using IntervalAndIdx = PayloadAndIdx>; + + +} diff --git a/gui/include/tiny-cuda-nn/common_device.h b/gui/include/tiny-cuda-nn/common_device.h new file mode 100644 index 0000000000000000000000000000000000000000..63c7f082625658ae9e7b2c5b45f0e210a7739af7 --- /dev/null +++ b/gui/include/tiny-cuda-nn/common_device.h @@ -0,0 +1,1284 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common_device.h + * @author Thomas Müller & Nikolaus Binder, NVIDIA + * @brief Implementation of various miscellaneous CUDA kernels and + device functions. 
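+ *
+ *          Example (illustrative) of the host/device activation helpers defined in this file:
+ *
+ *              float y  = tcnn::activation(tcnn::Activation::ReLU, -0.5f);                    // 0.0f
+ *              float dy = tcnn::activation_backward_in(tcnn::Activation::ReLU, 1.0f, -0.5f);  // 0.0f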
+ */ + +#pragma once + +#include + +#include + +namespace tcnn { + +__forceinline__ __device__ unsigned lane_id() { + unsigned ret; + asm volatile("mov.u32 %0, %laneid;" : "=r"(ret)); + return ret; +} + +static constexpr float SQRT2 = 1.41421356237309504880f; + +__host__ __device__ inline float logistic(const float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__host__ __device__ inline float logit(const float x) { + return -logf(1.0f / (fminf(fmaxf(x, 1e-9f), 1.0f - 1e-9f)) - 1.0f); +} + +template +__host__ __device__ inline void softmax(float vals[N]) { + float total = 0; + + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + vals[i] = expf(vals[i]); + total += vals[i]; + } + + const float inv_total = 1.0f / total; + + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + vals[i] *= inv_total; + } +} + +template +__host__ __device__ inline float softmax(const float vals[N], uint32_t idx) { + float total = 0; + + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + total += expf(vals[i]); + } + + return expf(vals[idx]) / total; +} + +template +struct VectorFragment { + static const uint32_t num_elements = V::size(); + V x; +}; + +template +using vector_fragment_t = VectorFragment>; + +template +__host__ __device__ T relu(T val) { + return (T)max((float)val, 0.0f); +} + +template <> +inline __host__ __device__ half relu(half val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmax(val, (half)0.0f); +#else + return (half)relu((float)val); +#endif +} + +static constexpr float K_ACT = 10.0f; + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + result = frag; +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = relu((T)frag.x[t]); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f ? 
1.0f : 0.01f); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(expf((float)frag.x[t])); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(sinf((float)frag.x[t])); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(logistic((float)frag.x[t])); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float x = (float)frag.x[t] * K_ACT; + result.x[t] = (T)(0.5f * (x + sqrtf(x * x + 4)) / K_ACT); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(logf(expf((float)frag.x[t] * K_ACT) + 1.0f) / K_ACT); + } +} + +template = 0> +__host__ __device__ void warp_activation(const fragment_t& frag, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(tanhf((float)frag.x[t])); + } +} + +template +__host__ __device__ void warp_activation(Activation activation, const fragment_t& frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: warp_activation(frag, result); return; + case Activation::LeakyReLU: warp_activation(frag, result); return; + case Activation::Exponential: warp_activation(frag, result); return; + case Activation::Sine: warp_activation(frag, result); return; + case Activation::Sigmoid: warp_activation(frag, result); return; + case Activation::Squareplus: warp_activation(frag, result); return; + case Activation::Softplus: warp_activation(frag, result); return; + case Activation::Tanh: warp_activation(frag, result); return; + case Activation::None: warp_activation(frag, result); return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation(Activation activation, const fragment_t& frag) { + fragment_t result; + warp_activation(activation, frag, result); + return result; +} + +template +__host__ __device__ tvec vec_activation(tvec& v) { + using fragment_t = vector_fragment_t; + warp_activation(*(fragment_t*)&v, *(fragment_t*)&v); +} + +template +__host__ __device__ tvec vec_activation(Activation activation, const tvec& v) { + auto result = warp_activation(activation, vector_fragment_t{v}); + return result.x; +} + +template +__host__ __device__ T activation(Activation activation, T val) { + return vec_activation(activation, tvec{val})[0]; +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + result = frag; +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f); + } +} + +template = 0> +__host__ __device__ void 
warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f ? 1.0f : 0.01f); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(expf(forward_frag_in.x[t])); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(cosf(forward_frag_in.x[t])); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float x = logistic(forward_frag_in.x[t]); + result.x[t] = frag.x[t] * (T)(x * (1.0f - x)); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float x = (float)forward_frag_in.x[t] * K_ACT; + float y = 0.5f * (x + sqrtf(x * x + 4)); + result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1)); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float tmp = expf((float)forward_frag_in.x[t] * K_ACT); + result.x[t] = frag.x[t] * (T)(tmp / (tmp + 1)); + } +} + +template = 0> +__host__ __device__ void warp_activation_backward_in(const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float x = tanhf(forward_frag_in.x[t]); + result.x[t] = frag.x[t] * (T)(1.0f - x * x); + } +} + +template +__host__ __device__ void warp_activation_backward_in(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + switch (activation) { + case Activation::ReLU: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::LeakyReLU: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Exponential: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Sine: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Sigmoid: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Squareplus: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Softplus: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::Tanh: warp_activation_backward_in(frag, forward_frag_in, result); return; + case Activation::None: warp_activation_backward_in(frag, forward_frag_in, result); return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation_backward_in(Activation activation, const 
fragment_t& frag, const forward_fragment_t& forward_frag_in) { + fragment_t result; + warp_activation_backward_in(activation, frag, forward_frag_in, result); + return result; +} + +template +__host__ __device__ tvec vec_activation_backward_in(tvec& v, const tvec& forward_v_in) { + using fragment_t = vector_fragment_t; + warp_activation_backward_in(*(fragment_t*)&v, *(fragment_t*)&forward_v_in, *(fragment_t*)&v); +} + +template +__host__ __device__ tvec vec_activation_backward_in(Activation activation, const tvec& v, const tvec& forward_v_in) { + auto result = warp_activation_backward_in(activation, vector_fragment_t{v}, vector_fragment_t{forward_v_in}); + return result.x; +} + +template +__host__ __device__ T activation_backward_in(Activation activation, T val, T forward_val_in) { + return vec_activation_backward_in(activation, tvec{val}, tvec{forward_val_in})[0]; +} + +template +__host__ __device__ void warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f); + } + return; + case Activation::LeakyReLU: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f ? 1.0f : 0.01f); + } + return; + case Activation::Exponential: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * forward_frag.x[t]; + } + return; + case Activation::Sine: + // Sine requires stored pre-activations, which we don't have. We only + // write out the post-activations. + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + case Activation::Sigmoid: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] * (T)(1.0f - (float)forward_frag.x[t])); + } + return; + case Activation::Squareplus: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + float y = (float)forward_frag.x[t] * K_ACT; + result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1)); + } + return; + case Activation::Softplus: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(1.0f - expf(-(float)forward_frag.x[t] * K_ACT)); + } + return; + case Activation::Tanh: + TCNN_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(1.0f - ((float)forward_frag.x[t] * (float)forward_frag.x[t])); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag) { + fragment_t result; + warp_activation_backward(activation, frag, forward_frag, result); + return result; +} + +template +__host__ __device__ tvec vec_activation_backward(Activation activation, const tvec& v, const tvec& forward_v) { + auto result = warp_activation_backward(activation, vector_fragment_t{v}, vector_fragment_t{forward_v}); + return result.x; +} + +template +__host__ __device__ T activation_backward(Activation activation, T val, T forward_val) { + return vec_activation_backward(activation, tvec{val}, tvec{forward_val})[0]; +} + +#define 
IQ_DEFAULT_STATE 0x853c49e6748fea9bULL + +/// Based on https://www.iquilezles.org/www/articles/sfrand/sfrand.htm +struct iqrand { + /// Initialize the pseudorandom number generator with default seed + TCNN_HOST_DEVICE iqrand() : state((uint32_t)IQ_DEFAULT_STATE) {} + + /// Initialize the pseudorandom number generator with the \ref seed() function + TCNN_HOST_DEVICE iqrand(uint32_t initstate) : state(initstate) {} + + /// Generate a single precision floating point value on the interval [0, 1) + TCNN_HOST_DEVICE float next_float() { + union { + float fres; + unsigned int ires; + }; + + state *= 16807; + ires = ((((unsigned int)state)>>9 ) | 0x3f800000); + return fres - 1.0f; + } + + uint32_t state; // RNG state. All values are possible. +}; + +using default_rng_t = pcg32; + +__device__ inline float random_val(uint32_t seed, uint32_t idx) { + default_rng_t rng{seed}; + rng.advance(idx); + return rng.next_float(); +} + +template +__device__ void sh_enc(uint32_t degree, float x, float y, float z, ARRAY_T& data_out) { + // Let compiler figure out how to sequence/reorder these calculations w.r.t. branches + float xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z; + float x4=x2*x2, y4=y2*y2, z4=z2*z2; + float x6=x4*x2, y6=y4*y2, z6=z4*z2; + + // SH polynomials generated using scripts/gen_sh.py based on the recurrence relations in appendix A1 of https://www.ppsloan.org/publications/StupidSH36.pdf + data_out(0) = (T)(0.28209479177387814f); // 1/(2*sqrt(pi)) + if (degree <= 1) { return; } + data_out(1) = (T)(-0.48860251190291987f*y); // -sqrt(3)*y/(2*sqrt(pi)) + data_out(2) = (T)(0.48860251190291987f*z); // sqrt(3)*z/(2*sqrt(pi)) + data_out(3) = (T)(-0.48860251190291987f*x); // -sqrt(3)*x/(2*sqrt(pi)) + if (degree <= 2) { return; } + data_out(4) = (T)(1.0925484305920792f*xy); // sqrt(15)*xy/(2*sqrt(pi)) + data_out(5) = (T)(-1.0925484305920792f*yz); // -sqrt(15)*yz/(2*sqrt(pi)) + data_out(6) = (T)(0.94617469575755997f*z2 - 0.31539156525251999f); // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + data_out(7) = (T)(-1.0925484305920792f*xz); // -sqrt(15)*xz/(2*sqrt(pi)) + data_out(8) = (T)(0.54627421529603959f*x2 - 0.54627421529603959f*y2); // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (degree <= 3) { return; } + data_out(9) = (T)(0.59004358992664352f*y*(-3.0f*x2 + y2)); // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + data_out(10) = (T)(2.8906114426405538f*xy*z); // sqrt(105)*xy*z/(2*sqrt(pi)) + data_out(11) = (T)(0.45704579946446572f*y*(1.0f - 5.0f*z2)); // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + data_out(12) = (T)(0.3731763325901154f*z*(5.0f*z2 - 3.0f)); // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + data_out(13) = (T)(0.45704579946446572f*x*(1.0f - 5.0f*z2)); // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + data_out(14) = (T)(1.4453057213202769f*z*(x2 - y2)); // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + data_out(15) = (T)(0.59004358992664352f*x*(-x2 + 3.0f*y2)); // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (degree <= 4) { return; } + data_out(16) = (T)(2.5033429417967046f*xy*(x2 - y2)); // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + data_out(17) = (T)(1.7701307697799304f*yz*(-3.0f*x2 + y2)); // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + data_out(18) = (T)(0.94617469575756008f*xy*(7.0f*z2 - 1.0f)); // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + data_out(19) = (T)(0.66904654355728921f*yz*(3.0f - 7.0f*z2)); // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + data_out(20) = (T)(-3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f); // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + data_out(21) = (T)(0.66904654355728921f*xz*(3.0f - 7.0f*z2)); // 3*sqrt(10)*xz*(3 - 
7*z2)/(8*sqrt(pi)) + data_out(22) = (T)(0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f)); // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + data_out(23) = (T)(1.7701307697799304f*xz*(-x2 + 3.0f*y2)); // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + data_out(24) = (T)(-3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4); // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (degree <= 5) { return; } + data_out(25) = (T)(0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4)); // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + data_out(26) = (T)(8.3026492595241645f*xy*z*(x2 - y2)); // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + data_out(27) = (T)(-0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f)); // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + data_out(28) = (T)(4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f)); // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + data_out(29) = (T)(0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f)); // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + data_out(30) = (T)(0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f)); // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + data_out(31) = (T)(0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f)); // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + data_out(32) = (T)(2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f)); // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + data_out(33) = (T)(-0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f)); // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + data_out(34) = (T)(2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4)); // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + data_out(35) = (T)(0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4)); // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (degree <= 6) { return; } + data_out(36) = (T)(1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4)); // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + data_out(37) = (T)(2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4)); // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + data_out(38) = (T)(2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f)); // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + data_out(39) = (T)(-0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f)); // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + data_out(40) = (T)(0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f)); // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + data_out(41) = (T)(0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f)); // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + data_out(42) = (T)(6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f); // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + data_out(43) = (T)(0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f)); // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + data_out(44) = (T)(0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f)); // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + data_out(45) = (T)(-0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f)); // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + data_out(46) = (T)(0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4)); // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + data_out(47) = (T)(2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4)); // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + data_out(48) = (T)(10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 
0.6831841051919143f*x6 - 0.6831841051919143f*y6); // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (degree <= 7) { return; } + data_out(49) = (T)(0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6)); // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + data_out(50) = (T)(5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4)); // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + data_out(51) = (T)(-0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4)); // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + data_out(52) = (T)(4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f)); // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + data_out(53) = (T)(-0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f)); // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + data_out(54) = (T)(0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f)); // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + data_out(55) = (T)(0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f)); // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + data_out(56) = (T)(0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f)); // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + data_out(57) = (T)(0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f)); // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + data_out(58) = (T)(0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f)); // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + data_out(59) = (T)(-0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f)); // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + data_out(60) = (T)(1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4)); // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + data_out(61) = (T)(-0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4)); // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + data_out(62) = (T)(2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6)); // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + data_out(63) = (T)(0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6)); // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) +} + +template +__device__ vec3 sh_enc_grad(uint32_t degree, float x, float y, float z, const ARRAY_T& dL_dy) { + // Let compiler figure out how to sequence/reorder these calculations w.r.t. 
branches + float xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z; + float x4=x2*x2, y4=y2*y2, z4=z2*z2; + float x6=x4*x2, y6=y4*y2, z6=z4*z2; + + vec3 d(0.0f); + + // d.x += (float)dL_dy(0) * (0); // 0 + // d.y += (float)dL_dy(0) * (0); // 0 + // d.z += (float)dL_dy(0) * (0); // 0 + if (degree <= 1) { return d; } + // d.x += (float)dL_dy(1) * (0); // 0 + d.y += (float)dL_dy(1) * (-0.48860251190291992); // -sqrt(3)/(2*sqrt(pi)) + // d.z += (float)dL_dy(1) * (0); // 0 + // d.x += (float)dL_dy(2) * (0); // 0 + // d.y += (float)dL_dy(2) * (0); // 0 + d.z += (float)dL_dy(2) * (0.48860251190291992); // sqrt(3)/(2*sqrt(pi)) + d.x += (float)dL_dy(3) * (-0.48860251190291992); // -sqrt(3)/(2*sqrt(pi)) + // d.y += (float)dL_dy(3) * (0); // 0 + // d.z += (float)dL_dy(3) * (0); // 0 + if (degree <= 2) { return d; } + d.x += (float)dL_dy(4) * (1.0925484305920792*y); // sqrt(15)*y/(2*sqrt(pi)) + d.y += (float)dL_dy(4) * (1.0925484305920792*x); // sqrt(15)*x/(2*sqrt(pi)) + // d.z += (float)dL_dy(4) * (0); // 0 + // d.x += (float)dL_dy(5) * (0); // 0 + d.y += (float)dL_dy(5) * (-1.0925484305920792*z); // -sqrt(15)*z/(2*sqrt(pi)) + d.z += (float)dL_dy(5) * (-1.0925484305920792*y); // -sqrt(15)*y/(2*sqrt(pi)) + // d.x += (float)dL_dy(6) * (0); // 0 + // d.y += (float)dL_dy(6) * (0); // 0 + d.z += (float)dL_dy(6) * (1.8923493915151202*z); // 3*sqrt(5)*z/(2*sqrt(pi)) + d.x += (float)dL_dy(7) * (-1.0925484305920792*z); // -sqrt(15)*z/(2*sqrt(pi)) + // d.y += (float)dL_dy(7) * (0); // 0 + d.z += (float)dL_dy(7) * (-1.0925484305920792*x); // -sqrt(15)*x/(2*sqrt(pi)) + d.x += (float)dL_dy(8) * (1.0925484305920792*x); // sqrt(15)*x/(2*sqrt(pi)) + d.y += (float)dL_dy(8) * (-1.0925484305920792*y); // -sqrt(15)*y/(2*sqrt(pi)) + // d.z += (float)dL_dy(8) * (0); // 0 + if (degree <= 3) { return d; } + d.x += (float)dL_dy(9) * (-3.5402615395598609*xy); // -3*sqrt(70)*xy/(4*sqrt(pi)) + d.y += (float)dL_dy(9) * (-1.7701307697799304*x2 + 1.7701307697799304*y2); // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + // d.z += (float)dL_dy(9) * (0); // 0 + d.x += (float)dL_dy(10) * (2.8906114426405538*yz); // sqrt(105)*yz/(2*sqrt(pi)) + d.y += (float)dL_dy(10) * (2.8906114426405538*xz); // sqrt(105)*xz/(2*sqrt(pi)) + d.z += (float)dL_dy(10) * (2.8906114426405538*xy); // sqrt(105)*xy/(2*sqrt(pi)) + // d.x += (float)dL_dy(11) * (0); // 0 + d.y += (float)dL_dy(11) * (0.45704579946446572 - 2.2852289973223288*z2); // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + d.z += (float)dL_dy(11) * (-4.5704579946446566*yz); // -5*sqrt(42)*yz/(4*sqrt(pi)) + // d.x += (float)dL_dy(12) * (0); // 0 + // d.y += (float)dL_dy(12) * (0); // 0 + d.z += (float)dL_dy(12) * (5.597644988851731*z2 - 1.1195289977703462); // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) + d.x += (float)dL_dy(13) * (0.45704579946446572 - 2.2852289973223288*z2); // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + // d.y += (float)dL_dy(13) * (0); // 0 + d.z += (float)dL_dy(13) * (-4.5704579946446566*xz); // -5*sqrt(42)*xz/(4*sqrt(pi)) + d.x += (float)dL_dy(14) * (2.8906114426405538*xz); // sqrt(105)*xz/(2*sqrt(pi)) + d.y += (float)dL_dy(14) * (-2.8906114426405538*yz); // -sqrt(105)*yz/(2*sqrt(pi)) + d.z += (float)dL_dy(14) * (1.4453057213202769*x2 - 1.4453057213202769*y2); // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + d.x += (float)dL_dy(15) * (-1.7701307697799304*x2 + 1.7701307697799304*y2); // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + d.y += (float)dL_dy(15) * (3.5402615395598609*xy); // 3*sqrt(70)*xy/(4*sqrt(pi)) + // d.z += (float)dL_dy(15) * (0); // 0 + if (degree <= 4) { return d; } + d.x += (float)dL_dy(16) * 
(2.5033429417967046*y*(3.0*x2 - y2)); // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + d.y += (float)dL_dy(16) * (2.5033429417967046*x*(x2 - 3.0*y2)); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + // d.z += (float)dL_dy(16) * (0); // 0 + d.x += (float)dL_dy(17) * (-10.620784618679583*xy*z); // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + d.y += (float)dL_dy(17) * (5.3103923093397913*z*(-x2 + y2)); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + d.z += (float)dL_dy(17) * (1.7701307697799304*y*(-3.0*x2 + y2)); // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + d.x += (float)dL_dy(18) * (0.94617469575756008*y*(7.0*z2 - 1.0)); // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + d.y += (float)dL_dy(18) * (0.94617469575756008*x*(7.0*z2 - 1.0)); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + d.z += (float)dL_dy(18) * (13.246445740605839*xy*z); // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + // d.x += (float)dL_dy(19) * (0); // 0 + d.y += (float)dL_dy(19) * (0.66904654355728921*z*(3.0 - 7.0*z2)); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + d.z += (float)dL_dy(19) * (2.0071396306718676*y*(1.0 - 7.0*z2)); // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + // d.x += (float)dL_dy(20) * (0); // 0 + // d.y += (float)dL_dy(20) * (0); // 0 + d.z += (float)dL_dy(20) * (14.809976568128603*z*z2 - 6.3471328149122579*z); // (105*z**3 - 45*z)/(4*sqrt(pi)) + d.x += (float)dL_dy(21) * (0.66904654355728921*z*(3.0 - 7.0*z2)); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + // d.y += (float)dL_dy(21) * (0); // 0 + d.z += (float)dL_dy(21) * (2.0071396306718676*x*(1.0 - 7.0*z2)); // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + d.x += (float)dL_dy(22) * (0.94617469575756008*x*(7.0*z2 - 1.0)); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + d.y += (float)dL_dy(22) * (0.94617469575756008*y*(1.0 - 7.0*z2)); // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + d.z += (float)dL_dy(22) * (6.6232228703029197*z*(x2 - y2)); // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + d.x += (float)dL_dy(23) * (5.3103923093397913*z*(-x2 + y2)); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + d.y += (float)dL_dy(23) * (10.620784618679583*xy*z); // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + d.z += (float)dL_dy(23) * (1.7701307697799304*x*(-x2 + 3.0*y2)); // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + d.x += (float)dL_dy(24) * (2.5033429417967046*x*(x2 - 3.0*y2)); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + d.y += (float)dL_dy(24) * (2.5033429417967046*y*(-3.0*x2 + y2)); // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + // d.z += (float)dL_dy(24) * (0); // 0 + if (degree <= 5) { return d; } + d.x += (float)dL_dy(25) * (13.127641136803401*xy*(-x2 + y2)); // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + d.y += (float)dL_dy(25) * (19.6914617052051*x2*y2 - 3.2819102842008503*x4 - 3.2819102842008503*y4); // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + // d.z += (float)dL_dy(25) * (0); // 0 + d.x += (float)dL_dy(26) * (8.3026492595241645*yz*(3.0*x2 - y2)); // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + d.y += (float)dL_dy(26) * (8.3026492595241645*xz*(x2 - 3.0*y2)); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + d.z += (float)dL_dy(26) * (8.3026492595241645*xy*(x2 - y2)); // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + d.x += (float)dL_dy(27) * (2.9354297966115022*xy*(1.0 - 9.0*z2)); // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + d.y += (float)dL_dy(27) * (-1.4677148983057511*(x2 - y2)*(9.0*z2 - 1.0)); // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + d.z += (float)dL_dy(27) * (8.8062893898345074*yz*(-3.0*x2 + y2)); // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + d.x += (float)dL_dy(28) * (4.7935367849733241*yz*(3.0*z2 - 1.0)); // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + d.y += (float)dL_dy(28) * 
(4.7935367849733241*xz*(3.0*z2 - 1.0)); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + d.z += (float)dL_dy(28) * (4.7935367849733241*xy*(9.0*z2 - 1.0)); // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + // d.x += (float)dL_dy(29) * (0); // 0 + d.y += (float)dL_dy(29) * (6.3412531167397574*z2 - 9.5118796751096362*z4 - 0.45294665119569694); // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + d.z += (float)dL_dy(29) * (12.682506233479513*yz*(1.0 - 3.0*z2)); // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) + // d.x += (float)dL_dy(30) * (0); // 0 + // d.y += (float)dL_dy(30) * (0); // 0 + d.z += (float)dL_dy(30) * (-24.559567715218954*z2 + 36.839351572828434*z4 + 1.754254836801354); // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + d.x += (float)dL_dy(31) * (6.3412531167397574*z2 - 9.5118796751096362*z4 - 0.45294665119569694); // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + // d.y += (float)dL_dy(31) * (0); // 0 + d.z += (float)dL_dy(31) * (12.682506233479513*xz*(1.0 - 3.0*z2)); // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + d.x += (float)dL_dy(32) * (4.7935367849733241*xz*(3.0*z2 - 1.0)); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + d.y += (float)dL_dy(32) * (4.7935367849733241*yz*(1.0 - 3.0*z2)); // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + d.z += (float)dL_dy(32) * (2.3967683924866621*(x2 - y2)*(9.0*z2 - 1.0)); // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + d.x += (float)dL_dy(33) * (-13.209434084751759*x2*z2 + 1.4677148983057511*x2 + 13.209434084751759*y2*z2 - 1.4677148983057511*y2); // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + d.y += (float)dL_dy(33) * (2.9354297966115022*xy*(9.0*z2 - 1.0)); // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + d.z += (float)dL_dy(33) * (8.8062893898345074*xz*(-x2 + 3.0*y2)); // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + d.x += (float)dL_dy(34) * (8.3026492595241645*xz*(x2 - 3.0*y2)); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + d.y += (float)dL_dy(34) * (8.3026492595241645*yz*(-3.0*x2 + y2)); // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + d.z += (float)dL_dy(34) * (-12.453973889286246*x2*y2 + 2.0756623148810411*x4 + 2.0756623148810411*y4); // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + d.x += (float)dL_dy(35) * (19.6914617052051*x2*y2 - 3.2819102842008503*x4 - 3.2819102842008503*y4); // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(35) * (13.127641136803401*xy*(x2 - y2)); // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + // d.z += (float)dL_dy(35) * (0); // 0 + if (degree <= 6) { return d; } + d.x += (float)dL_dy(36) * (4.0991046311514854*y*(-10.0*x2*y2 + 5.0*x4 + y4)); // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(36) * (4.0991046311514854*x*(-10.0*x2*y2 + x4 + 5.0*y4)); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + // d.z += (float)dL_dy(36) * (0); // 0 + d.x += (float)dL_dy(37) * (47.332383244635047*xy*z*(-x2 + y2)); // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + d.y += (float)dL_dy(37) * (11.833095811158762*z*(6.0*x2*y2 - x4 - y4)); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + d.z += (float)dL_dy(37) * (2.3666191622317521*y*(10.0*x2*y2 - 5.0*x4 - y4)); // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(38) * (2.0182596029148963*y*(3.0*x2 - y2)*(11.0*z2 - 1.0)); // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + d.y += (float)dL_dy(38) * (2.0182596029148963*x*(x2 - 3.0*y2)*(11.0*z2 - 1.0)); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + d.z += (float)dL_dy(38) * (44.401711264127719*xy*z*(x2 - y2)); // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + d.x += 
(float)dL_dy(39) * (5.5272315570895412*xy*z*(3.0 - 11.0*z2)); // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + d.y += (float)dL_dy(39) * (-2.7636157785447706*z*(x2 - y2)*(11.0*z2 - 3.0)); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + d.z += (float)dL_dy(39) * (-2.7636157785447706*y*(3.0*x2 - y2)*(11.0*z2 - 1.0)); // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + d.x += (float)dL_dy(40) * (0.92120525951492349*y*(-18.0*z2 + 33.0*z4 + 1.0)); // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + d.y += (float)dL_dy(40) * (0.92120525951492349*x*(-18.0*z2 + 33.0*z4 + 1.0)); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + d.z += (float)dL_dy(40) * (11.054463114179082*xy*z*(11.0*z2 - 3.0)); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + // d.x += (float)dL_dy(41) * (0); // 0 + d.y += (float)dL_dy(41) * (0.58262136251873131*z*(30.0*z2 - 33.0*z4 - 5.0)); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + d.z += (float)dL_dy(41) * (2.9131068125936568*y*(18.0*z2 - 33.0*z4 - 1.0)); // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + // d.x += (float)dL_dy(42) * (0); // 0 + // d.y += (float)dL_dy(42) * (0); // 0 + d.z += (float)dL_dy(42) * (2.6699064952403937*z*(-30.0*z2 + 33.0*z4 + 5.0)); // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + d.x += (float)dL_dy(43) * (0.58262136251873131*z*(30.0*z2 - 33.0*z4 - 5.0)); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + // d.y += (float)dL_dy(43) * (0); // 0 + d.z += (float)dL_dy(43) * (2.9131068125936568*x*(18.0*z2 - 33.0*z4 - 1.0)); // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + d.x += (float)dL_dy(44) * (0.92120525951492349*x*(-18.0*z2 + 33.0*z4 + 1.0)); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + d.y += (float)dL_dy(44) * (0.92120525951492349*y*(18.0*z2 - 33.0*z4 - 1.0)); // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + d.z += (float)dL_dy(44) * (5.5272315570895412*z*(x2 - y2)*(11.0*z2 - 3.0)); // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + d.x += (float)dL_dy(45) * (-2.7636157785447706*z*(x2 - y2)*(11.0*z2 - 3.0)); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + d.y += (float)dL_dy(45) * (5.5272315570895412*xy*z*(11.0*z2 - 3.0)); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + d.z += (float)dL_dy(45) * (-2.7636157785447706*x*(x2 - 3.0*y2)*(11.0*z2 - 1.0)); // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + d.x += (float)dL_dy(46) * (2.0182596029148963*x*(x2 - 3.0*y2)*(11.0*z2 - 1.0)); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + d.y += (float)dL_dy(46) * (-2.0182596029148963*y*(3.0*x2 - y2)*(11.0*z2 - 1.0)); // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + d.z += (float)dL_dy(46) * (11.10042781603193*z*(-6.0*x2*y2 + x4 + y4)); // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + d.x += (float)dL_dy(47) * (11.833095811158762*z*(6.0*x2*y2 - x4 - y4)); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(47) * (47.332383244635047*xy*z*(x2 - y2)); // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + d.z += (float)dL_dy(47) * (2.3666191622317521*x*(10.0*x2*y2 - x4 - 5.0*y4)); // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(48) * (4.0991046311514854*x*(-10.0*x2*y2 + x4 + 5.0*y4)); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(48) * (4.0991046311514854*y*(10.0*x2*y2 - 5.0*x4 - y4)); // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + // d.z += (float)dL_dy(48) * (0); // 0 + if (degree <= 7) { return d; } + d.x += (float)dL_dy(49) * (9.9002782553443485*xy*(10.0*x2*y2 - 3.0*x4 - 3.0*y4)); // 
21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(49) * (-74.252086915082614*x2*y4 + 74.252086915082614*x4*y2 - 4.9501391276721742*x6 + 4.9501391276721742*y6); // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + // d.z += (float)dL_dy(49) * (0); // 0 + d.x += (float)dL_dy(50) * (15.875763970811402*yz*(-10.0*x2*y2 + 5.0*x4 + y4)); // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(50) * (15.875763970811402*xz*(-10.0*x2*y2 + x4 + 5.0*y4)); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + d.z += (float)dL_dy(50) * (5.2919213236038001*xy*(-10.0*x2*y2 + 3.0*x4 + 3.0*y4)); // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(51) * (-10.378311574405206*xy*(x2 - y2)*(13.0*z2 - 1.0)); // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + d.y += (float)dL_dy(51) * (0.51891557872026028*(13.0*z2 - 1.0)*(10.0*x2*y2 - 5.0*x4 + 4.0*y2*(5.0*x2 - y2) - y4)); // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + d.z += (float)dL_dy(51) * (13.491805046726766*yz*(10.0*x2*y2 - 5.0*x4 - y4)); // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(52) * (4.1513246297620823*yz*(3.0*x2 - y2)*(13.0*z2 - 3.0)); // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + d.y += (float)dL_dy(52) * (4.1513246297620823*xz*(x2 - 3.0*y2)*(13.0*z2 - 3.0)); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + d.z += (float)dL_dy(52) * (12.453973889286248*xy*(x2 - y2)*(13.0*z2 - 1.0)); // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + d.x += (float)dL_dy(53) * (0.93875360317376422*xy*(66.0*z2 - 143.0*z4 - 3.0)); // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + d.y += (float)dL_dy(53) * (-0.46937680158688211*(x2 - y2)*(13.0*z2*(11.0*z2 - 3.0) - 27.0*z2 + 3.0)); // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + d.z += (float)dL_dy(53) * (-6.8841930899409371*yz*(3.0*x2 - y2)*(13.0*z2 - 3.0)); // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + d.x += (float)dL_dy(54) * (0.44253269244498261*yz*(-110.0*z2 + 143.0*z4 + 15.0)); // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + d.y += (float)dL_dy(54) * (0.44253269244498261*xz*(-110.0*z2 + 143.0*z4 + 15.0)); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + d.z += (float)dL_dy(54) * (2.2126634622249131*xy*(-66.0*z2 + 143.0*z4 + 3.0)); // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + // d.x += (float)dL_dy(55) * (0); // 0 + d.y += (float)dL_dy(55) * (-12.194767023639836*z2 + 44.714145753346067*z4 - 38.752259652899923*z6 + 0.45165803791258652); // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + d.z += (float)dL_dy(55) * (1.6259689364853116*yz*(110.0*z2 - 143.0*z4 - 15.0)); // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + // d.x += (float)dL_dy(56) * (0); // 0 + // d.y += (float)dL_dy(56) * (0); // 0 + d.z += (float)dL_dy(56) * (64.528641681844675*z2 - 236.60501950009714*z4 + 205.05768356675085*z6 - 2.3899496919201733); // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + d.x += (float)dL_dy(57) * (-12.194767023639836*z2 + 44.714145753346067*z4 - 38.752259652899923*z6 + 0.45165803791258652); // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + // d.y += (float)dL_dy(57) * (0); // 0 + d.z += (float)dL_dy(57) * (1.6259689364853116*xz*(110.0*z2 - 143.0*z4 - 15.0)); // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + d.x += (float)dL_dy(58) * (0.44253269244498261*xz*(-110.0*z2 + 143.0*z4 + 15.0)); // 3*sqrt(70)*xz*(-110*z2 + 
143*z4 + 15)/(32*sqrt(pi)) + d.y += (float)dL_dy(58) * (0.44253269244498261*yz*(110.0*z2 - 143.0*z4 - 15.0)); // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + d.z += (float)dL_dy(58) * (0.07375544874083044*(x2 - y2)*(143.0*z2*(3.0*z2 - 1.0) + 132.0*z2*(13.0*z2 - 5.0) - 187.0*z2 + 45.0)); // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + d.x += (float)dL_dy(59) * (30.97886890473422*x2*z2 - 67.120882626924143*x2*z4 - 1.4081304047606462*x2 - 30.97886890473422*y2*z2 + 67.120882626924143*y2*z4 + 1.4081304047606462*y2); // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + d.y += (float)dL_dy(59) * (0.93875360317376422*xy*(-66.0*z2 + 143.0*z4 + 3.0)); // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + d.z += (float)dL_dy(59) * (-6.8841930899409371*xz*(x2 - 3.0*y2)*(13.0*z2 - 3.0)); // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + d.x += (float)dL_dy(60) * (4.1513246297620823*xz*(x2 - 3.0*y2)*(13.0*z2 - 3.0)); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + d.y += (float)dL_dy(60) * (-4.1513246297620823*yz*(3.0*x2 - y2)*(13.0*z2 - 3.0)); // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + d.z += (float)dL_dy(60) * (3.1134934723215619*(13.0*z2 - 1.0)*(-6.0*x2*y2 + x4 + y4)); // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(61) * (-0.51891557872026028*(13.0*z2 - 1.0)*(-10.0*x2*y2 + 4.0*x2*(x2 - 5.0*y2) + x4 + 5.0*y4)); // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + d.y += (float)dL_dy(61) * (10.378311574405206*xy*(x2 - y2)*(13.0*z2 - 1.0)); // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + d.z += (float)dL_dy(61) * (13.491805046726766*xz*(10.0*x2*y2 - x4 - 5.0*y4)); // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + d.x += (float)dL_dy(62) * (15.875763970811402*xz*(-10.0*x2*y2 + x4 + 5.0*y4)); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + d.y += (float)dL_dy(62) * (15.875763970811402*yz*(10.0*x2*y2 - 5.0*x4 - y4)); // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + d.z += (float)dL_dy(62) * (39.6894099270285*x2*y4 - 39.6894099270285*x4*y2 + 2.6459606618019*x6 - 2.6459606618019*y6); // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + d.x += (float)dL_dy(63) * (-74.252086915082614*x2*y4 + 74.252086915082614*x4*y2 - 4.9501391276721742*x6 + 4.9501391276721742*y6); // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + d.y += (float)dL_dy(63) * (9.9002782553443485*xy*(-10.0*x2*y2 + 3.0*x4 + 3.0*y4)); // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + // d.z += (float)dL_dy(63) * (0); // 0 + return d; +} + +template +__device__ uint32_t lcg_hash(const uvec& pos_grid, const uint32_t primes[N_PRIMES]) { + static_assert(N_DIMS <= N_PRIMES, "lcg_hash can only hash up to N_PRIMES dimensions."); + + uint32_t result = 0; + + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N_DIMS; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + +template +__device__ uint32_t prime_hash(const uvec& pos_grid) { + constexpr uint32_t factors[7] = { 1958374283u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + return lcg_hash(pos_grid, factors); +} + +template +__device__ uint32_t coherent_prime_hash(const uvec& pos_grid) { + constexpr uint32_t factors[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + return lcg_hash(pos_grid, factors); +} + +template +__device__ uint32_t 
reversed_prime_hash(const uvec& pos_grid) { + constexpr uint32_t factors[7] = { 2165219737u, 1434869437u, 2097192037u, 3674653429u, 805459861u, 2654435761u, 1958374283u }; + return lcg_hash(pos_grid, factors); +} + +template +__device__ uint32_t base_convert_hash(const uvec& pos_grid) { + // [Allows for arbitary N_DIMS] A simple base conversion hash, used in permuto-encoding + uint32_t k = 0; + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim < N_DIMS; ++dim) { + k += pos_grid[dim]; + k *= 2531011; + } + return k; +} + +template +__device__ uint32_t rng_hash(const uvec& pos_grid, const uint32_t seed = 1337) { + constexpr uint32_t N_BITS_PER_DIM = 64 / N_DIMS; + uint64_t step = 0; + + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N_DIMS; ++i) { + step ^= (uint64_t)pos_grid[i] << (i * N_BITS_PER_DIM); + } + + default_rng_t rng{seed}; + rng.advance((int64_t)step); + return rng.next_uint(); +} + +template +__device__ +typename std::enable_if::type +grid_hash(const uvec& pos_grid) { + switch (HASH_TYPE) { + case HashType::Prime: return prime_hash(pos_grid); + case HashType::CoherentPrime: return coherent_prime_hash(pos_grid); + case HashType::ReversedPrime: return reversed_prime_hash(pos_grid); + case HashType::Rng: return rng_hash(pos_grid); + } + + return 0; +} + +template +__device__ +typename std::enable_if::type // Use template partial specialization to prevent static assertion on N_DIMS +grid_hash(const uvec& pos_grid) { + return base_convert_hash(pos_grid); +} + +template +__device__ uint32_t grid_index(const GridType grid_type, const uint32_t hashmap_size, const uint32_t grid_resolution, const uvec& pos_grid) { + uint32_t stride = 1; + uint32_t index = 0; + + // The second part of the loop condition is needed to avoid integer overflows in finer levels. + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim < N_DIMS && stride <= hashmap_size; ++dim) { + index += pos_grid[dim] * stride; + stride *= grid_resolution; + } + + if (grid_type == GridType::Hash && hashmap_size < stride) { + index = grid_hash(pos_grid); + } + + return index % hashmap_size; +} + +__host__ __device__ inline float grid_scale(uint32_t level, float log2_per_level_scale, uint32_t base_resolution) { + // The -1 means that `base_resolution` refers to the number of grid _vertices_ rather + // than the number of cells. This is slightly different from the notation in the paper, + // but results in nice, power-of-2-scaled parameter grids that fit better into cache lines. + return exp2f(level * log2_per_level_scale) * base_resolution - 1.0f; +} + +__host__ __device__ inline uint32_t grid_resolution(float scale) { + return (uint32_t)ceilf(scale) + 1; +} + +//TCNN_INTERNAL_BEGIN +template +__device__ __forceinline__ uint32_t permuto_index( + const uvec& key, + const uint32_t hashmap_size +) { + return (base_convert_hash(key) % hashmap_size); +} + +template +__device__ __forceinline__ void permuto_elevate( + const float* __restrict__ pos, // [N_DIMS] + const float* __restrict__ scales_per_dim, // [N_DIMS] + const float* __restrict__ shifts_per_dim, // [N_DIMS] + float* __restrict__ elevated // [N_DIMS+1] +) { + // Elevate d-dimension vector to (d+1)-dimension homogeneous vector on hyperplane H_d + // a) The sum of the components of `elevated` is zero, ensuring it within hyperplane H_d + // b) The magnitudes of the components of `elevated` are similar to each other. 
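+	//
+	// Sketch of what the loop below computes (illustrative): writing
+	//     c_k = (pos[k-1] + shifts_per_dim[k-1]) * scales_per_dim[k-1],   k = 1..N_DIMS,
+	// it produces
+	//     elevated[d] = (c_{d+1} + ... + c_{N_DIMS}) - d * c_d   for d = 1..N_DIMS,
+	//     elevated[0] = c_1 + ... + c_{N_DIMS}.
+	// E.g. for N_DIMS == 2: elevated = { c_1 + c_2, c_2 - c_1, -2*c_2 }, which sums to zero as required by (a).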
+ float sum = 0; + TCNN_PRAGMA_UNROLL + for (int dim = N_DIMS; dim > 0; dim--) { + float cf = (pos[dim-1] + shifts_per_dim[dim-1]) * scales_per_dim[dim-1]; + elevated[dim] = sum - (float)dim * cf; + sum += cf; + } + elevated[0] = sum; +} + +template +__device__ __forceinline__ void permuto_find_rem0( + const float* __restrict__ elevated, // [N_DIMS+1] + int* __restrict__ rem0, // [N_DIMS+1] + int* __restrict__ rank // [N_DIMS+1] +) { + // Find the closest remainder-0 point through rounding + int sum = 0; + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim <= N_DIMS; ++dim) { + // NOTE by Jianfei Guo: + // For N=N_DIMS+1 that is not a power of 2, using xxx*(1.0f/N) is significantly faster than xxx/N when --fast_math is off. + // This is because: + // a) xxx*(1.0f/N) compiles into a single floating-point multiplication instruction (mul.f32) in both PTX and SASS codes, + // b) xxx/N compiles into a floating-point division instruction (div.rn.f32) in PTX, which further converts to multiple instructions in SASS + // and largely hinders the performance. + float v = elevated[dim] * (1.0f / (N_DIMS+1)); + float up = ceil(v) * (N_DIMS + 1); + float down = floor(v) * (N_DIMS + 1); + if (up - elevated[dim] < elevated[dim] - down) { + rem0[dim] = (int)up; + } else { + rem0[dim] = (int)down; + } + sum += rem0[dim]; + } + sum /= (int)(N_DIMS+1); + + // Find the simplex we are in and store it in rank + // (where rank describes what position coordinate i has in the sorted order of the features values) + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim < N_DIMS; ++dim) { + float di = elevated[dim] - rem0[dim]; + for (uint32_t other_dim = dim + 1; other_dim <= N_DIMS; ++other_dim) { + if (di < elevated[other_dim] - rem0[other_dim]) { + rank[dim]++; + } else { + rank[other_dim]++; + } + } + } + + // If the point doesn't lie on the plane (sum != 0) bring it back + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim <= N_DIMS; ++dim) { + rank[dim] += sum; + if (rank[dim] < 0) { + rank[dim] += (int)(N_DIMS+1); + rem0[dim] += (int)(N_DIMS+1); + } else if (rank[dim] > (int)N_DIMS) { + rank[dim] -= (int)(N_DIMS+1); + rem0[dim] -= (int)(N_DIMS+1); + } + } +} + +template +__device__ __forceinline__ void permuto_barycentric( + const float* __restrict__ elevated, // [N_DIMS+1] + const int* __restrict__ rem0, // [N_DIMS+1] + const int* __restrict__ rank, // [N_DIMS+1] + float* __restrict__ barycentric // [N_DIMS+2] +) { + // Compute the barycentric coordinates (p.10 in [Adams etal 2010]) + TCNN_PRAGMA_UNROLL + for (uint32_t dim = 0; dim <= N_DIMS; ++dim) { + float delta = (elevated[dim] - rem0[dim]) * (1.0f / (N_DIMS+1)); + barycentric[(int)N_DIMS - rank[dim]] += delta; + barycentric[(int)(N_DIMS + 1) - rank[dim]] -= delta; + } + // Wrap around + barycentric[0] += 1.0f + barycentric[N_DIMS + 1]; +} +//TCNN_INTERNAL_END + +template +__global__ void kernel_activation(const uint32_t num_elements, const Activation act, const T* in, T* out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + auto frag = ((vector_fragment_t*)in)[i]; + warp_activation(act, frag, frag); + ((vector_fragment_t*)out)[i] = frag; +} + +// Transfer functions corresponding to activations; version without biases +template +__global__ void kernel_activation_backward(const uint32_t num_elements, const Activation act, const T* __restrict__ values, const T* gradients_out, T* gradients_in) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + auto frag_forward_in = 
((vector_fragment_t*)values)[i]; + auto frag = ((vector_fragment_t*)gradients_out)[i]; + warp_activation_backward_in(act, frag, frag_forward_in, frag); + + ((vector_fragment_t*)gradients_in)[i] = frag; +} + +// Transfer functions corresponding to activations, given _output_ values. Only works if the activation is invertible +template +__global__ void kernel_activation_backward_output(const uint32_t num_elements, const Activation act, const T* __restrict__ output_values, const T* gradients_out, T* gradients_in) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + auto frag_forward_out = ((vector_fragment_t*)output_values)[i]; + auto frag = ((vector_fragment_t*)gradients_out)[i]; + warp_activation_backward(act, frag, frag_forward_out, frag); + + ((vector_fragment_t*)gradients_in)[i] = frag; +} + +// Expands a 10-bit integer into 30 bits +// by inserting 2 zeros after each bit. +__host__ __device__ inline uint32_t expand_bits(uint32_t v) { + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +// Calculates a 30-bit Morton code for the +// given 3D point located within the unit cube [0,1]. +__host__ __device__ inline uint32_t morton3D(uint32_t x, uint32_t y, uint32_t z) { + uint32_t xx = expand_bits(x); + uint32_t yy = expand_bits(y); + uint32_t zz = expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +__host__ __device__ inline uint32_t morton3D_invert(uint32_t x) { + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + +__host__ __device__ inline uint64_t expand_bits(uint64_t w) { + w &= 0x1fffff; + w = (w | w << 32) & 0x1f00000000ffff; + w = (w | w << 16) & 0x1f0000ff0000ff; + w = (w | w << 8) & 0x100f00f00f00f00f; + w = (w | w << 4) & 0x10c30c30c30c30c3; + w = (w | w << 2) & 0x1249249249249249; + return w; +} + +__host__ __device__ inline uint64_t morton3D_64bit(const ivec3& p) { + return ((expand_bits((uint64_t)p.x)) | (expand_bits((uint64_t)p.y) << 1) | (expand_bits((uint64_t)p.z) << 2)); +} + +__device__ inline float smoothstep(float val) { + return val*val*(3.0f - 2.0f * val); +} + +__device__ inline float smoothstep_derivative(float val) { + return 6*val*(1.0f - val); +} + +__device__ inline float smoothstep_2nd_derivative(float val) { + return 6.0f - 12.0f * val; +} + +__device__ inline float identity_fun(float val) { + return val; +} + +__device__ inline float identity_derivative(float val) { + return 1.0f; +} + +__device__ inline float identity_2nd_derivative(float val) { + return 0.0f; +} + +template +__device__ inline void pos_fract(const float input, float* pos, float* pos_derivative, float* pos_2nd_derivative, uint32_t* pos_grid, float scale, F interpolation_fun, FPRIME interpolation_fun_derivative, FPRIMEPRIME interpolation_fun_2nd_derivative) { + // The offset of 0.5 causes different scales to be staggered with respect to each other, thus + // preventing spurious alignment of fractional coordinates upon integer scales (or powers thereof). + // This is mentioned in Appendix A of the "Instant Neural Graphics Primitives" paper. + // The offset can cause wraparound indexing in dense grids, which didn't negatively impact + // the approximation quality in any of our tests. 
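+	//
+	// Worked example (illustrative): with scale = 16 and input = 0.3, fmaf yields
+	// 16*0.3 + 0.5 = 5.3, so *pos_grid becomes 5 and the fractional part is ~0.3
+	// (up to floating-point rounding) before the interpolation function and its
+	// derivatives are evaluated on it.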
+ *pos = fmaf(scale, input, 0.5f); + float tmp = floorf(*pos); + *pos_grid = (uint32_t)(int)tmp; + *pos -= (float)tmp; + *pos_2nd_derivative = interpolation_fun_2nd_derivative(*pos); + *pos_derivative = interpolation_fun_derivative(*pos); + *pos = interpolation_fun(*pos); +} + +template +__device__ inline void pos_fract(const float input, float* pos, float* pos_derivative, uint32_t* pos_grid, float scale, F interpolation_fun, FPRIME interpolation_fun_derivative) { + // The offset of 0.5 causes different scales to be staggered with respect to each other, thus + // preventing spurious alignment of fractional coordinates upon integer scales (or powers thereof). + // This is mentioned in Appendix A of the "Instant Neural Graphics Primitives" paper. + // The offset can cause wraparound indexing in dense grids, which didn't negatively impact + // the approximation quality in any of our tests. + *pos = fmaf(scale, input, 0.5f); + float tmp = floorf(*pos); + *pos_grid = (uint32_t)(int)tmp; + *pos -= tmp; + *pos_derivative = interpolation_fun_derivative(*pos); + *pos = interpolation_fun(*pos); +} + +template +__device__ inline void pos_fract(const float input, float* pos, uint32_t* pos_grid, float scale, F interpolation_fun) { + // The offset of 0.5 causes different scales to be staggered with respect to each other, thus + // preventing spurious alignment of fractional coordinates upon integer scales (or powers thereof). + // This is mentioned in Appendix A of the "Instant Neural Graphics Primitives" paper. + // The offset can cause wraparound indexing in dense grids, which didn't negatively impact + // the approximation quality in any of our tests. + *pos = fmaf(scale, input, 0.5f); + float tmp = floorf(*pos); + *pos_grid = (uint32_t)(int)tmp; + *pos -= tmp; + *pos = interpolation_fun(*pos); +} + +__device__ inline float weight_decay(float relative_weight_decay, float absolute_weight_decay, float weight) { + // Relative weight decay is closely related to l2 regularization, whereas absolute weight decay corresponds to l1 regularization + return (1 - relative_weight_decay) * weight - copysignf(absolute_weight_decay, weight); +} + +__device__ inline float gaussian_cdf(const float x, const float inv_radius) { + return normcdff(x * inv_radius); +} + +__device__ inline float gaussian_cdf_approx(const float x, const float inv_radius) { + static constexpr float MAGIC_SIGMOID_FACTOR = 1.12f / SQRT2; + return logistic(MAGIC_SIGMOID_FACTOR * x * inv_radius); +} + +__device__ inline float gaussian_cdf_approx_derivative(const float result, const float inv_radius) { + static constexpr float MAGIC_SIGMOID_FACTOR = 1.12f / SQRT2; + return result * (1 - result) * MAGIC_SIGMOID_FACTOR * inv_radius; +} + +__device__ inline float gaussian_pdf(const float x, const float inv_radius) { + return inv_radius * rsqrtf(2.0f * PI()) * expf(-0.5f * (x * x * inv_radius * inv_radius)); +} + +__device__ inline float gaussian_pdf_max_1(const float x, const float inv_radius) { + return expf(-0.5f * (x * x * inv_radius * inv_radius)); +} + +__device__ inline float tent(const float x, const float inv_radius) { + return fmaxf(1.0f - fabsf(x * inv_radius), 0.0f); +} + +__device__ inline float tent_cdf(const float x, const float inv_radius) { + return fmaxf(0.0f, fminf(1.0f, x * inv_radius + 0.5f)); +} + +__host__ __device__ inline float quartic(const float x, const float inv_radius) { + const float u = x * inv_radius; + const float tmp = fmaxf(1 - u*u, 0.0f); + return ((float)15 / 16) * tmp * tmp; +} + +__host__ __device__ inline float 
quartic_cdf_deriv(const float x, const float inv_radius) { + return quartic(x, inv_radius) * inv_radius; +} + +__host__ __device__ inline float quartic_cdf(const float x, const float inv_radius) { + const float u = x * inv_radius; + const float u2 = u * u; + const float u4 = u2 * u2; + return fmaxf(0.0f, fminf(1.0f, ((float)15 / 16) * u * (1 - ((float)2 / 3) * u2 + ((float)1 / 5) * u4) + 0.5f)); +} + +__host__ __device__ inline uint32_t permute(uint32_t num, uint32_t size) { + const uint32_t A = 1434869437; // Large prime number + const uint32_t B = 2097192037; + return ((uint64_t)num * A + B) % size; +} + +template +__global__ void shuffle(const uint32_t n_elements, const uint32_t stride, const uint32_t seed, const T* __restrict__ in, T* __restrict__ out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= n_elements * stride) return; + + const uint32_t elem_id = i / stride; + const uint32_t member_id = i % stride; + + out[i] = in[permute(elem_id + seed, n_elements) * stride + member_id]; +} + +template +__global__ void fill_rollover(const uint32_t n_elements, const uint32_t stride, const uint32_t* n_input_elements_ptr, T* inout) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t n_input_elements = *n_input_elements_ptr; + + if (i < (n_input_elements * stride) || i >= (n_elements * stride) || n_input_elements == 0) return; + + T result = inout[i % (n_input_elements * stride)]; + inout[i] = result; +} + +template +__global__ void fill_rollover_and_rescale(const uint32_t n_elements, const uint32_t stride, const uint32_t* n_input_elements_ptr, T* inout) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t n_input_elements = *n_input_elements_ptr; + + if (i < (n_input_elements * stride) || i >= (n_elements * stride) || n_input_elements == 0) return; + + T result = inout[i % (n_input_elements * stride)]; + result = (T)((float)result * n_input_elements / n_elements); + inout[i] = result; +} + +template +__global__ void add(const uint32_t num_elements, const T1* data_in_1, const T2* data_in_2, T3* data_out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + data_out[i] = (T3)((float)data_in_1[i] + (float)data_in_2[i]); +} + +template +__global__ void add(const uint32_t num_elements, const T* __restrict__ data_in, T* __restrict__ data_in_out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + data_in_out[i] = data_in[i] + data_in_out[i]; +} + +template +__global__ void trim(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, T* __restrict__ data_out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + uint32_t idx = i % dims; + uint32_t elem = i / dims; + + data_out[i] = data_in[elem * stride + idx]; +} + +template +__global__ void trim_and_cast(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, float* __restrict__ data_out) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + uint32_t idx = i % dims; + uint32_t elem = i / dims; + + data_out[i] = (float)data_in[elem * stride + idx]; +} + +template +__global__ void cast(const uint32_t num_elements, const float* __restrict__ full_precision, T* __restrict__ target) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + target[i] = 
(T)full_precision[i]; +} + +template +__global__ void cast_from(const uint32_t num_elements, const T* __restrict__ precision, float* __restrict__ full_precision) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + full_precision[i] = (float)precision[i]; +} + +template +__global__ void extract_dimension_pos_neg_kernel(const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* __restrict__ encoded, const MatrixLayout layout, float* __restrict__ output) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + const uint32_t elem_idx = i / fan_out; + const uint32_t dim_idx = i % fan_out; + + const uint32_t encoded_idx = layout == MatrixLayout::AoS ? (elem_idx * fan_in + dim) : (elem_idx + dim * num_elements / fan_out); + + if (fan_out == 1) { + output[i] = (float)encoded[encoded_idx]; + return; + } + + if (dim_idx == 0) { + output[i] = fmaxf(-(float)encoded[encoded_idx], 0.0f); + } else if (dim_idx == 1) { + output[i] = fmaxf((float)encoded[encoded_idx], 0.0f); + } else if (dim_idx == 2) { + output[i] = 0; + } else { + output[i] = 1; + } +} + +template +__global__ void mult_scalar_kernel(const uint32_t num_elements, T* __restrict__ inout, float factor) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + inout[i] = (T)((float)inout[i] * factor); +} + +template +__global__ void mult_kernel(const uint32_t num_elements, const T* factor1, const T* factor2, T* result) { + const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= num_elements) return; + + result[i] = factor1[i] * factor2[i]; +} + +} diff --git a/gui/include/tiny-cuda-nn/common_host.h b/gui/include/tiny-cuda-nn/common_host.h new file mode 100644 index 0000000000000000000000000000000000000000..90fbb707f1d55bafd2451c11d03fec097ef17a59 --- /dev/null +++ b/gui/include/tiny-cuda-nn/common_host.h @@ -0,0 +1,502 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common_host.h + * @author Thomas Müller and Nikolaus Binder, NVIDIA + * @brief Common utilities that are needed by pretty much every component of this framework. + */ + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace tcnn { + +using namespace fmt::literals; + +enum class LogSeverity { + Info, + Debug, + Warning, + Error, + Success, +}; + +const std::function& log_callback(); +void set_log_callback(const std::function& callback); + +template +void log(LogSeverity severity, const std::string& msg, Ts&&... args) { + log_callback()(severity, fmt::format(msg, std::forward(args)...)); +} + +template void log_info(const std::string& msg, Ts&&... args) { log(LogSeverity::Info, msg, std::forward(args)...); } +template void log_debug(const std::string& msg, Ts&&... 
args) { log(LogSeverity::Debug, msg, std::forward(args)...); } +template void log_warning(const std::string& msg, Ts&&... args) { log(LogSeverity::Warning, msg, std::forward(args)...); } +template void log_error(const std::string& msg, Ts&&... args) { log(LogSeverity::Error, msg, std::forward(args)...); } +template void log_success(const std::string& msg, Ts&&... args) { log(LogSeverity::Success, msg, std::forward(args)...); } + +bool verbose(); +void set_verbose(bool verbose); + +#define CHECK_THROW(x) \ + do { if (!(x)) throw std::runtime_error{FILE_LINE " check failed: " #x}; } while(0) + +/// Checks the result of a cuXXXXXX call and throws an error on failure +#define CU_CHECK_THROW(x) \ + do { \ + CUresult _result = x; \ + if (_result != CUDA_SUCCESS) { \ + const char *msg; \ + cuGetErrorName(_result, &msg); \ + throw std::runtime_error{fmt::format(FILE_LINE " " #x " failed: {}", msg)}; \ + } \ + } while(0) + +/// Checks the result of a cuXXXXXX call and prints an error on failure +#define CU_CHECK_PRINT(x) \ + do { \ + CUresult _result = x; \ + if (_result != CUDA_SUCCESS) { \ + const char *msg; \ + cuGetErrorName(_result, &msg); \ + log_error(FILE_LINE " " #x " failed: {}", msg); \ + } \ + } while(0) + +/// Checks the result of a cudaXXXXXX call and throws an error on failure +#define CUDA_CHECK_THROW(x) \ + do { \ + cudaError_t _result = x; \ + if (_result != cudaSuccess) \ + throw std::runtime_error{fmt::format(FILE_LINE " " #x " failed: {}", cudaGetErrorString(_result))}; \ + } while(0) + +/// Checks the result of a cudaXXXXXX call and prints an error on failure +#define CUDA_CHECK_PRINT(x) \ + do { \ + cudaError_t _result = x; \ + if (_result != cudaSuccess) \ + log_error(FILE_LINE " " #x " failed: {}", cudaGetErrorString(_result)); \ + } while(0) + +/// Checks the result of optixXXXXXX call and throws an error on failure +#define OPTIX_CHECK_THROW(x) \ + do { \ + OptixResult _result = x; \ + if (_result != OPTIX_SUCCESS) { \ + throw std::runtime_error(std::string("Optix call '" #x "' failed.")); \ + } \ + } while(0) + +/// Checks the result of a optixXXXXXX call and throws an error with a log message on failure +#define OPTIX_CHECK_THROW_LOG(x) \ + do { \ + OptixResult _result = x; \ + const size_t sizeof_log_returned = sizeof_log; \ + sizeof_log = sizeof( log ); /* reset sizeof_log for future calls */ \ + if (_result != OPTIX_SUCCESS) { \ + throw std::runtime_error(std::string("Optix call '" #x "' failed. Log:\n") + log + (sizeof_log_returned == sizeof_log ? 
"" : "")); \ + } \ + } while(0) + +////////////////////////////// +// Enum<->string conversion // +////////////////////////////// + +Activation string_to_activation(const std::string& activation_name); +std::string to_string(Activation activation); + +GridType string_to_grid_type(const std::string& grid_type); +std::string to_string(GridType grid_type); + +HashType string_to_hash_type(const std::string& hash_type); +std::string to_string(HashType hash_type); + +InterpolationType string_to_interpolation_type(const std::string& interpolation_type); +std::string to_string(InterpolationType interpolation_type); + +ReductionType string_to_reduction_type(const std::string& reduction_type); +std::string to_string(ReductionType reduction_type); + +////////////////// +// Misc helpers // +////////////////// + +int cuda_runtime_version(); +inline std::string cuda_runtime_version_string() { + int v = cuda_runtime_version(); + return fmt::format("{}.{}", v / 1000, (v % 100) / 10); +} + +int cuda_device(); +void set_cuda_device(int device); +int cuda_device_count(); + +bool cuda_supports_virtual_memory(int device); +inline bool cuda_supports_virtual_memory() { + return cuda_supports_virtual_memory(cuda_device()); +} + +std::string cuda_device_name(int device); +inline std::string cuda_device_name() { + return cuda_device_name(cuda_device()); +} + +uint32_t cuda_compute_capability(int device); +inline uint32_t cuda_compute_capability() { + return cuda_compute_capability(cuda_device()); +} + +uint32_t cuda_max_supported_compute_capability(); +uint32_t cuda_supported_compute_capability(int device); +inline uint32_t cuda_supported_compute_capability() { + return cuda_supported_compute_capability(cuda_device()); +} + +size_t cuda_max_shmem(int device); +inline size_t cuda_max_shmem() { + return cuda_max_shmem(cuda_device()); +} + +uint32_t cuda_max_registers(int device); +inline uint32_t cuda_max_registers() { + return cuda_max_registers(cuda_device()); +} + +size_t cuda_memory_granularity(int device); +inline size_t cuda_memory_granularity() { + return cuda_memory_granularity(cuda_device()); +} + +struct MemoryInfo { + size_t total; + size_t free; + size_t used; +}; + +MemoryInfo cuda_memory_info(); + +// Hash helpers taken from https://stackoverflow.com/a/50978188 +template +T xorshift(T n, int i) { + return n ^ (n >> i); +} + +inline uint32_t distribute(uint32_t n) { + uint32_t p = 0x55555555ul; // pattern of alternating 0 and 1 + uint32_t c = 3423571495ul; // random uneven integer constant; + return c * xorshift(p * xorshift(n, 16), 16); +} + +inline uint64_t distribute(uint64_t n) { + uint64_t p = 0x5555555555555555ull; // pattern of alternating 0 and 1 + uint64_t c = 17316035218449499591ull;// random uneven integer constant; + return c * xorshift(p * xorshift(n, 32), 32); +} + +template +constexpr typename std::enable_if::value, T>::type rotl(const T n, const S i) { + const T m = (std::numeric_limits::digits - 1); + const T c = i & m; + return (n << c) | (n >> (((T)0 - c) & m)); // this is usually recognized by the compiler to mean rotation +} + +template +size_t hash_combine(std::size_t seed, const T& v) { + return rotl(seed, std::numeric_limits::digits / 3) ^ distribute(std::hash{}(v)); +} + +std::string generate_device_code_preamble(); + +std::string to_snake_case(const std::string& str); + +std::vector split(const std::string& text, const std::string& delim); + +template +std::string join(const T& components, const std::string& delim) { + std::ostringstream s; + for (const auto& component : 
components) { + if (&components[0] != &component) { + s << delim; + } + s << component; + } + + return s.str(); +} + +template +std::string dfmt(uint32_t indent, const std::string& format, Ts&&... args) { + // Trim empty lines at the beginning and end of format string. + // Also re-indent the format string `indent` deep. + uint32_t input_indent = std::numeric_limits::max(); + uint32_t n_empty_leading = 0, n_empty_trailing = 0; + bool leading = true; + + std::vector lines = split(format, "\n"); + for (const auto& line : lines) { + bool empty = true; + uint32_t line_indent = 0; + for (uint32_t i = 0; i < line.length(); ++i) { + if (empty && line[i] == '\t') { + line_indent = i+1; + } else { + empty = false; + break; + } + } + + if (empty) { + if (leading) { + ++n_empty_leading; + } + ++n_empty_trailing; + continue; + } + + n_empty_trailing = 0; + + leading = false; + input_indent = std::min(input_indent, line_indent); + } + + if (input_indent == std::numeric_limits::max()) { + return ""; + } + + lines.erase(lines.end() - n_empty_trailing, lines.end()); + lines.erase(lines.begin(), lines.begin() + n_empty_leading); + + for (auto& line : lines) { + if (line.length() >= input_indent) { + line = line.substr(input_indent); + line = line.insert(0, indent, '\t'); + } + } + + return fmt::format(join(lines, "\n"), std::forward(args)...); +} + +std::string to_lower(std::string str); +std::string to_upper(std::string str); +inline bool equals_case_insensitive(const std::string& str1, const std::string& str2) { + return to_lower(str1) == to_lower(str2); +} + +struct CaseInsensitiveHash { size_t operator()(const std::string& v) const { return std::hash{}(to_lower(v)); }}; +struct CaseInsensitiveEqual { bool operator()(const std::string& l, const std::string& r) const { return equals_case_insensitive(l, r); }}; + +template +using ci_hashmap = std::unordered_map; + + +template +std::string type_to_string(); + +template +std::string to_string(const tvec& v) { + return fmt::format("tvec<{}, {}, {}>({})", type_to_string(), N, A, join(v, ", ")); +} + +inline std::string bytes_to_string(size_t bytes) { + std::array suffixes = {{ "B", "KB", "MB", "GB", "TB", "PB", "EB" }}; + + double count = (double)bytes; + uint32_t i = 0; + for (; i < suffixes.size() && count >= 1024; ++i) { + count /= 1024; + } + + std::ostringstream oss; + oss.precision(3); + oss << count << " " << suffixes[i]; + return oss.str(); +} + +inline bool is_pot(uint32_t num, uint32_t* log2 = nullptr) { + if (log2) *log2 = 0; + if (num > 0) { + while (num % 2 == 0) { + num /= 2; + if (log2) ++*log2; + } + if (num == 1) { + return true; + } + } + + return false; +} + +inline uint32_t powi(uint32_t base, uint32_t exponent) { + uint32_t result = 1; + for (uint32_t i = 0; i < exponent; ++i) { + result *= base; + } + + return result; +} + +class ScopeGuard { +public: + ScopeGuard() = default; + ScopeGuard(const std::function& callback) : m_callback{callback} {} + ScopeGuard(std::function&& callback) : m_callback{std::move(callback)} {} + ScopeGuard& operator=(const ScopeGuard& other) = delete; + ScopeGuard(const ScopeGuard& other) = delete; + ScopeGuard& operator=(ScopeGuard&& other) { std::swap(m_callback, other.m_callback); return *this; } + ScopeGuard(ScopeGuard&& other) { *this = std::move(other); } + ~ScopeGuard() { if (m_callback) { m_callback(); } } + + void disarm() { + m_callback = {}; + } +private: + std::function m_callback; +}; + +template +class Lazy { +public: + template + T& get(F&& generator) { + if (!m_val) { + m_val = generator(); + } 
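+		// Note: `!m_val` requires T to be contextually convertible to bool (e.g. a
+		// (smart) pointer type or a type with an explicit operator bool), which is
+		// how "not yet generated" is detected here.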
+ + return m_val; + } + +private: + T m_val; +}; + +#if defined(__CUDACC__) || (defined(__clang__) && defined(__CUDA__)) +template +inline void linear_kernel(K kernel, uint32_t shmem_size, cudaStream_t stream, T n_elements, Types ... args) { + if (n_elements <= 0) { + return; + } + kernel<<>>(n_elements, args...); +} + +template +__global__ void parallel_for_kernel(const size_t n_elements, F fun) { + const size_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= n_elements) return; + + fun(i); +} + +template +inline void parallel_for_gpu(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, F&& fun) { + if (n_elements <= 0) { + return; + } + parallel_for_kernel<<>>(n_elements, fun); +} + +template +inline void parallel_for_gpu(cudaStream_t stream, size_t n_elements, F&& fun) { + parallel_for_gpu(0, stream, n_elements, std::forward(fun)); +} + +template +inline void parallel_for_gpu(size_t n_elements, F&& fun) { + parallel_for_gpu(nullptr, n_elements, std::forward(fun)); +} + +template +__global__ void parallel_for_aos_kernel(const size_t n_elements, const uint32_t n_dims, F fun) { + const size_t dim = threadIdx.x; + const size_t elem = threadIdx.y + blockIdx.x * blockDim.y; + if (dim >= n_dims) return; + if (elem >= n_elements) return; + + fun(elem, dim); +} + +template +inline void parallel_for_gpu_aos(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) { + if (n_elements <= 0 || n_dims <= 0) { + return; + } + + const dim3 threads = { n_dims, div_round_up(N_THREADS_LINEAR, n_dims), 1 }; + const size_t n_threads = threads.x * threads.y; + const dim3 blocks = { (uint32_t)div_round_up(n_elements * n_dims, n_threads), 1, 1 }; + + parallel_for_aos_kernel<<>>( + n_elements, n_dims, fun + ); +} + +template +inline void parallel_for_gpu_aos(cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) { + parallel_for_gpu_aos(0, stream, n_elements, n_dims, std::forward(fun)); +} + +template +inline void parallel_for_gpu_aos(size_t n_elements, uint32_t n_dims, F&& fun) { + parallel_for_gpu_aos(nullptr, n_elements, n_dims, std::forward(fun)); +} + +template +__global__ void parallel_for_soa_kernel(const size_t n_elements, const uint32_t n_dims, F fun) { + const size_t elem = threadIdx.x + blockIdx.x * blockDim.x; + const size_t dim = blockIdx.y; + if (elem >= n_elements) return; + if (dim >= n_dims) return; + + fun(elem, dim); +} + +template +inline void parallel_for_gpu_soa(uint32_t shmem_size, cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) { + if (n_elements <= 0 || n_dims <= 0) { + return; + } + + const dim3 blocks = { n_blocks_linear(n_elements), n_dims, 1 }; + + parallel_for_soa_kernel<<>>( + n_elements, n_dims, fun + ); +} + +template +inline void parallel_for_gpu_soa(cudaStream_t stream, size_t n_elements, uint32_t n_dims, F&& fun) { + parallel_for_gpu_soa(0, stream, n_elements, n_dims, std::forward(fun)); +} + +template +inline void parallel_for_gpu_soa(size_t n_elements, uint32_t n_dims, F&& fun) { + parallel_for_gpu_soa(nullptr, n_elements, n_dims, std::forward(fun)); +} +#endif + +} diff --git a/gui/include/tiny-cuda-nn/cuda_graph.h b/gui/include/tiny-cuda-nn/cuda_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..0cc3932fc9dc8ae6e73533b2878038e0a231eb8a --- /dev/null +++ b/gui/include/tiny-cuda-nn/cuda_graph.h @@ -0,0 +1,173 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file cuda_graph.h + * @author Thomas Müller, NVIDIA + * @brief Implementation of a CUDA graph capture/update with subsequent execution + */ + +#pragma once + +#include + +#include + +#include +#include + +namespace tcnn { + +class CudaGraph; + +inline std::deque& current_captures() { + static thread_local std::deque s_current_captures; + return s_current_captures; +} + +inline CudaGraph* current_capture() { + return current_captures().empty() ? nullptr : current_captures().front(); +} + +class CudaGraph { +public: + ~CudaGraph() { + try { + reset(); + } catch (const std::runtime_error& error) { + // Don't need to report on destruction problems when the driver is shutting down. + if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { + log_warning("Could not destroy cuda graph: {}", error.what()); + } + } + } + + ScopeGuard capture_guard(cudaStream_t stream) { + // Can't capture on the global stream + if (stream == nullptr || stream == cudaStreamLegacy) { + return {}; + } + + // If the caller is already capturing, no need for a nested capture. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK_THROW(cudaStreamIsCapturing(stream, &capture_status)); + if (capture_status != cudaStreamCaptureStatusNone) { + return {}; + } + + cudaError_t capture_result = cudaStreamIsCapturing(cudaStreamLegacy, &capture_status); + if (capture_result == cudaErrorStreamCaptureImplicit) { + return {}; + } + + CUDA_CHECK_THROW(capture_result); + if (capture_status != cudaStreamCaptureStatusNone) { + return {}; + } + + // Start capturing + if (m_graph) { + CUDA_CHECK_THROW(cudaGraphDestroy(m_graph)); + m_graph = nullptr; + } + + CUDA_CHECK_THROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); + current_captures().push_back(this); + + // Stop capturing again once the returned object goes out of scope + return ScopeGuard{[this, stream]() { + CUDA_CHECK_THROW(cudaStreamEndCapture(stream, &m_graph)); + + if (current_captures().back() != this) { + throw std::runtime_error{"CudaGraph: must end captures in reverse order of creation."}; + } + current_captures().pop_back(); + + if (m_synchronize_when_capture_done) { + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + m_synchronize_when_capture_done = false; + } + + // Capture failed for some reason. Reset state and don't execute anything. + // A corresponding exception is likely already in flight. + if (!m_graph) { + if (m_graph_instance) { + CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); + } + + m_graph = nullptr; + m_graph_instance = nullptr; + return; + } + + // If we previously created a graph instance, try to update it with the newly captured graph. + // This is cheaper than creating a new instance from scratch (and may involve just updating + // pointers rather than changing the topology of the graph.) 
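+			// (If the update cannot be applied, e.g. because the captured topology changed,
+			// the existing instance is destroyed below and a fresh one is instantiated.)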
+ if (m_graph_instance) { +#if CUDA_VERSION >= 12000 + cudaGraphExecUpdateResultInfo update_result; + CUDA_CHECK_THROW(cudaGraphExecUpdate(m_graph_instance, m_graph, &update_result)); + + // If the update failed, reset graph instance. We will create a new one next. + if (update_result.result != cudaGraphExecUpdateSuccess) { + CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); + m_graph_instance = nullptr; + } +#else + cudaGraphExecUpdateResult update_result; + cudaGraphNode_t error_node; + CUDA_CHECK_THROW(cudaGraphExecUpdate(m_graph_instance, m_graph, &error_node, &update_result)); + + // If the update failed, reset graph instance. We will create a new one next. + if (update_result != cudaGraphExecUpdateSuccess) { + CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); + m_graph_instance = nullptr; + } +#endif + } + + if (!m_graph_instance) { + CUDA_CHECK_THROW(cudaGraphInstantiate(&m_graph_instance, m_graph, NULL, NULL, 0)); + } + + CUDA_CHECK_THROW(cudaGraphLaunch(m_graph_instance, stream)); + }}; + } + + void reset() { + if (m_graph) { + CUDA_CHECK_THROW(cudaGraphDestroy(m_graph)); + m_graph = nullptr; + } + + if (m_graph_instance) { + CUDA_CHECK_THROW(cudaGraphExecDestroy(m_graph_instance)); + m_graph_instance = nullptr; + } + } + + void schedule_synchronize() { + m_synchronize_when_capture_done = true; + } + +private: + cudaGraph_t m_graph = nullptr; + cudaGraphExec_t m_graph_instance = nullptr; + + bool m_synchronize_when_capture_done = false; +}; + +} diff --git a/gui/include/tiny-cuda-nn/gpu_matrix.h b/gui/include/tiny-cuda-nn/gpu_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..881c833180a97784dde0db1919bc2f6621e39bcb --- /dev/null +++ b/gui/include/tiny-cuda-nn/gpu_matrix.h @@ -0,0 +1,521 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file gpu_matrix.h + * @author Thomas Müller, NVIDIA + * @brief Matrix whose data resides in GPU (CUDA) memory + */ + +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include + +namespace tcnn { + +template +class GPUMatrixDynamic; + +template +class GPUMatrix; + +class GPUMatrixBase { +public: + virtual ~GPUMatrixBase() {} + + virtual size_t n_bytes() const = 0; + virtual void set_data_unsafe(void* data) = 0; + + static void allocate_shared_memory(GPUMemory& memory, const std::vector& matrices) { + size_t total_n_bytes = 0; + for (auto* matrix : matrices) { + total_n_bytes += matrix->n_bytes(); + } + + if (memory.bytes() < total_n_bytes) { + log_debug("GPUMatrix: allocating {} shared among {} matrices.", bytes_to_string(total_n_bytes), matrices.size()); + memory.resize(total_n_bytes); + } + + size_t offset = 0; + for (auto* matrix : matrices) { + matrix->set_data_unsafe(memory.data() + offset); + offset += matrix->n_bytes(); + } + } + + template + static void allocate_shared_memory(GPUMemory& memory, std::vector>& matrices); + + template + static void allocate_shared_memory(GPUMemory& memory, std::vector>& matrices); + + static GPUMemoryArena::Allocation allocate_shared_memory(cudaStream_t stream, const std::vector& matrices) { + size_t total_n_bytes = 0; + for (auto* matrix : matrices) { + total_n_bytes += matrix->n_bytes(); + } + + auto alloc = allocate_workspace(stream, total_n_bytes); + + size_t offset = 0; + for (auto* matrix : matrices) { + matrix->set_data_unsafe(alloc.data() + offset); + offset += matrix->n_bytes(); + } + + return alloc; + } + + template + static GPUMemoryArena::Allocation allocate_shared_memory(cudaStream_t stream, std::vector>& matrices); + + template + static GPUMemoryArena::Allocation allocate_shared_memory(cudaStream_t stream, std::vector>& matrices); +}; + +template +class GPUMatrixDynamic : public GPUMatrixBase { +public: + using Type = T; + using View = MatrixView; + using ConstView = MatrixView; + + // Owning its memory as a GPUMemory + GPUMatrixDynamic(uint32_t m, uint32_t n, MatrixLayout layout = CM) + : m_rows{m}, m_cols{n}, m_layout{layout} { + m_malloc_allocation = std::make_shared>(m * n * sizeof(T)); + m_data = (T*)m_malloc_allocation->data(); + set_stride_contiguous(); + } + + // Owning its memory as an allocation from a stream's memory arena + GPUMatrixDynamic(uint32_t m, uint32_t n, cudaStream_t stream, MatrixLayout layout = CM) + : m_rows{m}, m_cols{n}, m_layout{layout} { + m_arena_allocation = std::make_shared(allocate_workspace(stream, m * n * sizeof(T))); + m_data = (T*)m_arena_allocation->data(); + set_stride_contiguous(); + } + + // Pointing to external memory + explicit GPUMatrixDynamic(T* data, uint32_t m, uint32_t n, MatrixLayout layout = CM, uint32_t stride = 0, std::shared_ptr> malloc_allocation = nullptr, std::shared_ptr arena_allocation = nullptr) + : m_data{data}, m_layout{layout}, m_malloc_allocation{malloc_allocation}, m_arena_allocation{arena_allocation} { + set(data, m, n, stride); + } + + GPUMatrixDynamic() : GPUMatrixDynamic{nullptr, 0, 0} {} + + GPUMatrixDynamic& operator=(GPUMatrixDynamic&& other) { + std::swap(m_data, other.m_data); + std::swap(m_rows, other.m_rows); + std::swap(m_cols, other.m_cols); + std::swap(m_stride, other.m_stride); + std::swap(m_layout, other.m_layout); + std::swap(m_malloc_allocation, other.m_malloc_allocation); + std::swap(m_arena_allocation, other.m_arena_allocation); + return *this; + } + + GPUMatrixDynamic(GPUMatrixDynamic&& other) { + *this = 
std::move(other); + } + + GPUMatrixDynamic(const GPUMatrixDynamic& other) = delete; + GPUMatrixDynamic& operator=(const GPUMatrixDynamic& other) = delete; + + virtual ~GPUMatrixDynamic() {} + + void set_data_unsafe(void* data) override { m_data = (T*)data; } + void set_size_unsafe(uint32_t rows, uint32_t cols, uint32_t stride = 0) { + m_rows = rows; + m_cols = cols; + + if (stride == 0) { + set_stride_contiguous(); + } else { + m_stride = stride; + } + } + + void set(T* data, uint32_t rows, uint32_t cols, uint32_t stride = 0) { + set_data_unsafe(data); + set_size_unsafe(rows, cols, stride); + } + + void resize(uint32_t rows, uint32_t cols) { + if (m_arena_allocation) { + cudaStream_t stream = m_arena_allocation->stream(); + m_arena_allocation.reset(); // reset is called explicitly to ensure memory is freed before being allocated + m_arena_allocation = std::make_shared(allocate_workspace(stream, rows * cols * sizeof(T))); + m_data = (T*)m_arena_allocation->data(); + } else if (m_malloc_allocation || !data()) { + m_malloc_allocation.reset(); // reset is called explicitly to ensure memory is freed before being allocated + m_malloc_allocation = std::make_shared>(rows * cols * sizeof(T)); + m_data = (T*)m_malloc_allocation->data(); + } else { + throw std::runtime_error{"GPUMatrix::resize is not permitted when the underlying memory is not owned. Use GPUMatrix::set instead."}; + } + + set_size_unsafe(rows, cols); + } + + uint32_t stride_contiguous() const { + return m_layout == CM ? m() : n(); + } + + bool is_contiguous() const { + return m_stride == stride_contiguous(); + } + + void set_stride_contiguous() { + m_stride = stride_contiguous(); + } + + GPUMatrixDynamic slice(uint32_t offset_rows, uint32_t new_rows, uint32_t offset_cols, uint32_t new_cols) const { + return GPUMatrixDynamic{ + data() + (layout() == CM ? (offset_rows + offset_cols * stride()) : (offset_cols + offset_rows * stride())), + new_rows, + new_cols, + layout(), + stride(), + m_malloc_allocation, + m_arena_allocation, + }; + } + + GPUMatrixDynamic slice_rows(uint32_t offset, uint32_t size) const { + return slice(offset, size, 0, cols()); + } + + GPUMatrixDynamic slice_cols(uint32_t offset, uint32_t size) const { + return slice(0, rows(), offset, size); + } + + GPUMatrixDynamic alias() const { + return slice(0, rows(), 0, cols()); + } + + View view() const { + return {data(), layout() == CM ? 1u : stride(), layout() == CM ? stride() : 1u}; + } + + ConstView const_view() const { + return view(); + } + + uint32_t rows() const { return m_rows; } + uint32_t fan_out() const { return m_rows; } + uint32_t m() const { return m_rows; } + + uint32_t cols() const { return m_cols; } + uint32_t fan_in() const { return m_cols; } + uint32_t n() const { return m_cols; } + + uint32_t stride() const { return m_stride; } + PitchedPtr pitched_ptr() { return {data(), stride()}; } + PitchedPtr pitched_ptr() const { return {data(), stride()}; } + + uint32_t n_elements() const { return m_rows * m_cols; } + size_t n_bytes() const override { return n_elements() * sizeof(T); } + + MatrixLayout layout() const { return m_layout; } + MatrixLayout transposed_layout() const { return m_layout == RM ? 
CM : RM; } + + T* data() const { return m_data; } + + void memset(int value) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + CUDA_CHECK_THROW(cudaMemset(data(), value, n_bytes())); + } + + void memset_async(cudaStream_t stream, int value) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + CUDA_CHECK_THROW(cudaMemsetAsync(data(), value, n_bytes(), stream)); + } + + std::vector to_cpu_vector() { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + std::vector v(n_elements()); + CUDA_CHECK_THROW(cudaMemcpy(v.data(), data(), n_bytes(), cudaMemcpyDeviceToHost)); + return v; + } + + // Various initializations + void initialize_uniform(pcg32& rnd, float low, float high) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + float scale = high - low; + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(low + rnd.next_float() * scale); + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_xavier_uniform(pcg32& rnd, float scale = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + scale *= std::sqrt(6.0f / (float)(fan_in() + fan_out())); + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(rnd.next_float() * 2.0f * scale - scale); + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_fa_uniform_forward(pcg32& rnd, float scale = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + scale *= std::sqrt(1.0f / (float)fan_in()); + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(rnd.next_float() * 2.0f * scale - scale); + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_fa_uniform_backward(pcg32& rnd, float scale = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + scale *= std::sqrt(1.0f / (float)fan_out()); + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(rnd.next_float() * 2.0f * scale - scale); + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_siren_uniform(pcg32& rnd, float scale = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + scale *= std::sqrt(6.0f / (float)fan_in()); + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(rnd.next_float() * 2.0f * scale - scale); + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_siren_uniform_first(pcg32& rnd, float scale = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + // Define probability distribution + + // The 30 in the first layer comes from https://vsitzmann.github.io/siren/ + scale *= 30.0f / (float)fan_in(); + + // Sample initialized values + std::vector new_data(n_elements()); + + for (size_t i = 0; i < new_data.size(); ++i) { + new_data[i] = (T)(rnd.next_float() * 2.0f * scale - scale); + } 
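+		// (Equivalently: weights are uniform in [-scale, scale) after the 30/fan_in
+		// rescaling above, assuming next_float() returns values in [0, 1).)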
+ + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_constant(float val) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + + std::vector new_data(n_elements(), (T)val); + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + void initialize_diagonal(float val = 1) { + CHECK_THROW(data()); + CHECK_THROW(is_contiguous()); + CHECK_THROW(n() == m()); // Must be square for diagonal init to make sense + + std::vector new_data(n_elements(), (T)0); + for (uint32_t i = 0; i < n(); ++i) { + new_data[i + i*n()] = (T)val; + } + + CUDA_CHECK_THROW(cudaMemcpy(data(), new_data.data(), n_bytes(), cudaMemcpyHostToDevice)); + } + + GPUMatrixDynamic transposed() const { + return GPUMatrixDynamic(data(), n(), m(), transposed_layout(), stride(), m_malloc_allocation, m_arena_allocation); + } + + GPUMatrix rm() const { + CHECK_THROW(m_layout == RM); + return GPUMatrix(data(), m(), n(), stride(), m_malloc_allocation, m_arena_allocation); + } + + GPUMatrix cm() const { + CHECK_THROW(m_layout == CM); + return GPUMatrix(data(), m(), n(), stride(), m_malloc_allocation, m_arena_allocation); + } + +private: + T* m_data; + uint32_t m_rows, m_cols, m_stride; + MatrixLayout m_layout; + + // References to corresponding memory allocations. These ensure that + // m_data does not accidentally become dangling. + std::shared_ptr> m_malloc_allocation; + std::shared_ptr m_arena_allocation; +}; + +template +class GPUMatrix : public GPUMatrixDynamic { +public: + static const MatrixLayout static_layout = _layout; + static const MatrixLayout static_transposed_layout = _layout == RM ? CM : RM; + + // Owning its memory as a GPUMemory + GPUMatrix(uint32_t m, uint32_t n) + : GPUMatrixDynamic{m, n, static_layout} { } + + // Owning its memory as an allocation from a stream's memory arena + GPUMatrix(uint32_t m, uint32_t n, cudaStream_t stream) + : GPUMatrixDynamic{m, n, stream, static_layout} { } + + // Pointing to external memory + explicit GPUMatrix(T* data, uint32_t m, uint32_t n, uint32_t stride = 0, std::shared_ptr> malloc_allocation = nullptr, std::shared_ptr arena_allocation = nullptr) + : GPUMatrixDynamic{data, m, n, static_layout, stride, malloc_allocation, arena_allocation} { } + + GPUMatrix() : GPUMatrix{nullptr, 0, 0} {} + + GPUMatrix& operator=(GPUMatrixDynamic&& other) { + *((GPUMatrixDynamic*)this) = std::move(other); + if (static_layout != this->layout()) { + throw std::runtime_error{"GPUMatrix must be constructed from a GPUMatrixDynamic with matching layout."}; + } + return *this; + } + + GPUMatrix(GPUMatrixDynamic&& other) noexcept { + *this = std::move(other); + } + + GPUMatrix& operator=(GPUMatrix&& other) noexcept { + *((GPUMatrixDynamic*)this) = std::move(other); + return *this; + } + + GPUMatrix(GPUMatrix&& other) noexcept { + *this = std::move(other); + } + + GPUMatrix(const GPUMatrixDynamic& other) = delete; + GPUMatrix& operator=(const GPUMatrixDynamic& other) = delete; + + virtual ~GPUMatrix() {} + + GPUMatrix slice(uint32_t offset_rows, uint32_t new_rows, uint32_t offset_cols, uint32_t new_cols) const { + return ((GPUMatrixDynamic*)this)->slice(offset_rows, new_rows, offset_cols, new_cols); + } + + GPUMatrix slice_rows(uint32_t offset, uint32_t size) const { + return ((GPUMatrixDynamic*)this)->slice_rows(offset, size); + } + + GPUMatrix slice_cols(uint32_t offset, uint32_t size) const { + return ((GPUMatrixDynamic*)this)->slice_cols(offset, size); + } + + GPUMatrix alias() const { + return 
((GPUMatrixDynamic*)this)->alias(); + } + + GPUMatrix transposed() const { + return ((GPUMatrixDynamic*)this)->transposed(); + } +}; + +template +void GPUMatrixBase::allocate_shared_memory(GPUMemory& memory, std::vector>& matrices) { + std::vector matrix_pointers; + for (auto& matrix : matrices) { + matrix_pointers.emplace_back(&matrix); + } + allocate_shared_memory(memory, matrix_pointers); +} + +template +void GPUMatrixBase::allocate_shared_memory(GPUMemory& memory, std::vector>& matrices) { + std::vector matrix_pointers; + for (auto& matrix : matrices) { + matrix_pointers.emplace_back(&matrix); + } + allocate_shared_memory(memory, matrix_pointers); +} + +template +GPUMemoryArena::Allocation GPUMatrixBase::allocate_shared_memory(cudaStream_t stream, std::vector>& matrices) { + std::vector matrix_pointers; + for (auto& matrix : matrices) { + matrix_pointers.emplace_back(&matrix); + } + return allocate_shared_memory(stream, matrix_pointers); +} + +template +GPUMemoryArena::Allocation GPUMatrixBase::allocate_shared_memory(cudaStream_t stream, std::vector>& matrices) { + std::vector matrix_pointers; + for (auto& matrix : matrices) { + matrix_pointers.emplace_back(&matrix); + } + return allocate_shared_memory(stream, matrix_pointers); +} + +} diff --git a/gui/include/tiny-cuda-nn/gpu_memory.h b/gui/include/tiny-cuda-nn/gpu_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..70129eec33efd27b5a72c0101bc236587f2efec7 --- /dev/null +++ b/gui/include/tiny-cuda-nn/gpu_memory.h @@ -0,0 +1,728 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file gpu_memory.h + * @author Thomas Müller and Nikolaus Binder, NVIDIA + * @brief Managed memory on the GPU. Like a std::vector, memory is allocated either explicitly (resize/enlarge) + * or implicitly (resize_and_copy_from_host etc). Memory is always and automatically released in the destructor. + * Also contains a GPU memory arena for light-weight stream-ordered allocations of temporary memory. The + * memory arena makes use of virtual memory when available to avoid re-allocations during progressive growing. 
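+ *
+ *          A minimal usage sketch (illustrative only; the buffer name and sizes below are made up):
+ *
+ *              GPUMemory<float> buf(1024);           // allocate 1024 floats on the device
+ *              std::vector<float> host(1024, 0.0f);
+ *              buf.copy_from_host(host);             // upload the whole buffer
+ *              buf.copy_to_host(host);               // download it again
+ *              buf.resize(2048);                     // frees the old allocation, then allocates anew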
+ */ + +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tcnn { + +#define DEBUG_GUARD_SIZE 0 + +inline std::atomic& total_n_bytes_allocated() { + static std::atomic s_total_n_bytes_allocated{0}; + return s_total_n_bytes_allocated; +} + +/// Managed memory on the Device +template +class GPUMemory { +private: + T* m_data = nullptr; + size_t m_size = 0; // Number of elements + bool m_managed = false; + +public: + using Type = T; + using View = T*; + using ConstView = const T*; + + GPUMemory() {} + GPUMemory(size_t size, bool managed = false) : m_managed{managed} { + resize(size); + } + + GPUMemory& operator=(GPUMemory&& other) { + std::swap(m_data, other.m_data); + std::swap(m_size, other.m_size); + std::swap(m_managed, other.m_managed); + return *this; + } + + GPUMemory(GPUMemory&& other) { + *this = std::move(other); + } + + // Don't permit copy assignment to prevent performance accidents. + // Copy is permitted through an explicit copy constructor. + GPUMemory& operator=(const GPUMemory& other) = delete; + explicit GPUMemory(const GPUMemory& other) { + m_managed = other.managed(); + copy_from_device(other); + } + + void check_guards() const { +#if DEBUG_GUARD_SIZE > 0 + if (!m_data) + return; + uint8_t buf[DEBUG_GUARD_SIZE]; + const uint8_t *rawptr=(const uint8_t *)m_data; + cudaMemcpy(buf, rawptr-DEBUG_GUARD_SIZE, DEBUG_GUARD_SIZE, cudaMemcpyDeviceToHost); + for (int i=0;i 0 + CUDA_CHECK_THROW(cudaMemset(rawptr, 0xff, DEBUG_GUARD_SIZE)); + CUDA_CHECK_THROW(cudaMemset(rawptr + n_bytes + DEBUG_GUARD_SIZE, 0xfe, DEBUG_GUARD_SIZE)); +#endif + if (rawptr) rawptr += DEBUG_GUARD_SIZE; + m_data = (T*)(rawptr); + total_n_bytes_allocated() += n_bytes; + } + + void free_memory() { + if (!m_data) { + return; + } + + uint8_t *rawptr = (uint8_t*)m_data; + if (rawptr) rawptr -= DEBUG_GUARD_SIZE; + CUDA_CHECK_THROW(cudaFree(rawptr)); + + total_n_bytes_allocated() -= get_bytes(); + + m_data = nullptr; + m_size = 0; + } + + /// Frees memory again + TCNN_HOST_DEVICE ~GPUMemory() { +#ifndef __CUDA_ARCH__ + try { + if (m_data) { + free_memory(); + m_size = 0; + } + } catch (const std::runtime_error& error) { + // Don't need to report on memory-free problems when the driver is shutting down. 
+ if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { + log_warning("Could not free memory: {}", error.what()); + } + } +#endif + } + + /** @name Resizing/enlargement + * @{ + */ + /// Resizes the array to the exact new size, even if it is already larger + void resize(const size_t size) { + if (m_size != size) { + if (m_size) { + try { + free_memory(); + } catch (const std::runtime_error& error) { + throw std::runtime_error{fmt::format("Could not free memory: {}", error.what())}; + } + } + + if (size > 0) { + try { + allocate_memory(size * sizeof(T)); + } catch (const std::runtime_error& error) { + throw std::runtime_error{fmt::format("Could not allocate memory: {}", error.what())}; + } + } + + m_size = size; + } + } + + /// Enlarges the array if its size is smaller + void enlarge(const size_t size) { + if (size > m_size) { + resize(size); + } + } + /** @} */ + + /** @name Memset + * @{ + */ + /// Sets the memory of the first num_elements to value + void memset(const int value, const size_t num_elements, const size_t offset = 0) { + if (num_elements + offset > m_size) { + throw std::runtime_error{fmt::format("Could not set memory: Number of elements {}+{} larger than allocated memory {}.", num_elements, offset, m_size)}; + } + + CUDA_CHECK_THROW(cudaMemset(m_data + offset, value, num_elements * sizeof(T))); + } + + /// Sets the memory of the all elements to value + void memset(const int value) { + memset(value, m_size); + } + /** @} */ + + /** @name Copy operations + * @{ + */ + /// Copy data of num_elements from the raw pointer on the host + void copy_from_host(const T* host_data, const size_t num_elements) { + CUDA_CHECK_THROW(cudaMemcpy(data(), host_data, num_elements * sizeof(T), cudaMemcpyHostToDevice)); + } + + /// Copy num_elements from the host vector + void copy_from_host(const std::vector& data, const size_t num_elements) { + if (data.size() < num_elements) { + throw std::runtime_error{fmt::format("Trying to copy {} elements, but vector size is only {}.", num_elements, data.size())}; + } + copy_from_host(data.data(), num_elements); + } + + /// Copies data from the raw host pointer to fill the entire array + void copy_from_host(const T* data) { + copy_from_host(data, m_size); + } + + /// Copies num_elements of data from the raw host pointer after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const T* data, const size_t num_elements) { + enlarge(num_elements); + copy_from_host(data, num_elements); + } + + /// Copies num_elements from the host vector after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const std::vector& data, const size_t num_elements) { + enlarge_and_copy_from_host(data.data(), num_elements); + } + + /// Copies the entire host vector after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const std::vector& data) { + enlarge_and_copy_from_host(data.data(), data.size()); + } + + /// Copies num_elements of data from the raw host pointer after resizing the array + void resize_and_copy_from_host(const T* data, const size_t num_elements) { + resize(num_elements); + copy_from_host(data, num_elements); + } + + /// Copies num_elements from the host vector after resizing the array + void resize_and_copy_from_host(const std::vector& data, const size_t num_elements) { + resize_and_copy_from_host(data.data(), num_elements); + } + + /// Copies the entire host vector after resizing the array + void resize_and_copy_from_host(const std::vector& data) 
{ + resize_and_copy_from_host(data.data(), data.size()); + } + + /// Copies the entire host vector to the device. Fails if there is not enough space available. + void copy_from_host(const std::vector& data) { + if (data.size() < m_size) { + throw std::runtime_error{fmt::format("Trying to copy {} elements, but vector size is only {}.", m_size, data.size())}; + } + copy_from_host(data.data(), m_size); + } + + /// Copies num_elements of data from the raw host pointer to the device. Fails if there is not enough space available. + void copy_to_host(T* host_data, const size_t num_elements) const { + if (num_elements > m_size) { + throw std::runtime_error{fmt::format("Trying to copy {} elements, but memory size is only {}.", num_elements, m_size)}; + } + + CUDA_CHECK_THROW(cudaMemcpy(host_data, data(), num_elements * sizeof(T), cudaMemcpyDeviceToHost)); + } + + /// Copies num_elements from the device to a vector on the host + void copy_to_host(std::vector& data, const size_t num_elements) const { + if (data.size() < num_elements) { + throw std::runtime_error{fmt::format("Trying to copy {} elements, but vector size is only {}.", num_elements, data.size())}; + } + + copy_to_host(data.data(), num_elements); + } + + /// Copies num_elements from the device to a raw pointer on the host + void copy_to_host(T* data) const { + copy_to_host(data, m_size); + } + + /// Copies all elements from the device to a vector on the host + void copy_to_host(std::vector& data) const { + if (data.size() < m_size) { + throw std::runtime_error{fmt::format("Trying to copy {} elements, but vector size is only {}", m_size, data.size())}; + } + + copy_to_host(data.data(), m_size); + } + + /// Copies size elements from another device array to this one, automatically resizing it + void copy_from_device(const GPUMemory& other, const size_t size) { + if (size == 0) { + return; + } + + if (m_size < size) { + resize(size); + } + + CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, size * sizeof(T), cudaMemcpyDeviceToDevice)); + } + + /// Copies data from another device array to this one, automatically resizing it + void copy_from_device(const GPUMemory &other) { + copy_from_device(other, other.m_size); + } + + // Created an (owned) copy of the data + GPUMemory copy(size_t size) const { + GPUMemory result{size}; + result.copy_from_device(*this); + return result; + } + + GPUMemory copy() const { + return copy(m_size); + } + + T* data() const { + check_guards(); + return m_data; + } + + View view() const { return data(); } + ConstView const_view() const { return view(); } + + bool managed() const { + return m_managed; + } + + T& at(size_t idx) const { + if (!m_managed) { + throw std::runtime_error{fmt::format("GPUMemory::at() not permitted if not managed.")}; + } + + if (idx > m_size) { + throw std::runtime_error{fmt::format("GPUMemory out of bounds: idx={} size={}", idx, m_size)}; + } + + return m_data[idx]; + } + + TCNN_HOST_DEVICE T& operator[](size_t idx) const { +#ifdef DEBUG_BUFFER_OVERRUN + if (idx > m_size) { + printf("WARNING: buffer overrun of %p at idx %zu\n", idx); + } +#endif + return m_data[idx]; + } + + TCNN_HOST_DEVICE T& operator[](uint32_t idx) const { +#ifdef DEBUG_BUFFER_OVERRUN + if (idx > m_size) { + printf("WARNING: buffer overrun of %p at idx %u\n", idx); + } +#endif + return m_data[idx]; + } + + size_t get_num_elements() const { + return m_size; + } + + size_t size() const { + return get_num_elements(); + } + + size_t get_bytes() const { + return m_size * sizeof(T); + } + + size_t n_bytes() const { + return 
get_bytes(); + } + + size_t bytes() const { + return get_bytes(); + } +}; + +class GPUMemoryArena { +public: + GPUMemoryArena() { + m_device = cuda_device(); + + // Align memory at least by a cache line (128 bytes). + m_alignment = (size_t)128; + m_max_size = previous_multiple(cuda_memory_info().total, cuda_memory_granularity()); + + m_free_intervals = {{0, m_max_size}}; + + // Reserve an address range that would be sufficient for housing the entire + // available GPU RAM (if nothing else was using the GPU). This is unlikely + // to exhaust all available addresses (even if multiple GPUMemoryArenas are + // used simultaneously), while also ensuring that we never exhaust the + // reserved address range without running out of physical memory beforehand. + if (cuda_supports_virtual_memory() && cuMemAddressReserve(&m_base_address, m_max_size, 0, 0, 0) == CUDA_SUCCESS) { + return; + } + + // Use regular memory as fallback + m_fallback_memory = std::make_shared>(); + + static bool printed_warning = false; + if (!printed_warning) { + printed_warning = true; + log_warning( + "GPUMemoryArena: GPU {} does not support virtual memory. " + "Falling back to regular allocations, which will be larger and can cause occasional stutter.", + m_device + ); + } + } + + GPUMemoryArena(GPUMemoryArena&& other) = default; + GPUMemoryArena(const GPUMemoryArena& other) = delete; + GPUMemoryArena& operator=(GPUMemoryArena&& other) = delete; + GPUMemoryArena& operator=(const GPUMemoryArena& other) = delete; + + ~GPUMemoryArena() { + if (in_use()) { + log_warning("Attempting to free memory arena while it is still in use."); + } + + try { + // Make sure we're clearing the GPU memory arena on the correct device. + int previous_device = cuda_device(); + set_cuda_device(m_device); + ScopeGuard revert_device = {[&]() { set_cuda_device(previous_device); }}; + + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + if (m_base_address) { + total_n_bytes_allocated() -= m_size; + + CU_CHECK_THROW(cuMemUnmap(m_base_address, m_size)); + + for (const auto& handle : m_handles) { + CU_CHECK_THROW(cuMemRelease(handle)); + } + + CU_CHECK_THROW(cuMemAddressFree(m_base_address, m_max_size)); + } + } catch (const std::runtime_error& error) { + // Don't need to report on memory-free problems when the driver is shutting down. + if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { + log_warning("Could not free memory arena: {}", error.what()); + } + } + } + + uint8_t* data() { + return m_fallback_memory ? m_fallback_memory->data() : (uint8_t*)m_base_address; + } + + std::shared_ptr> backing_memory() { + return m_fallback_memory; + } + + // Finds the smallest interval of free memory in the GPUMemoryArena that's + // large enough to hold the requested number of bytes. Then allocates + // that memory. + size_t allocate(size_t n_bytes) { + // Permitting zero-sized allocations is error prone + if (n_bytes == 0) { + n_bytes = m_alignment; + } + + // Align allocations with the nearest cache line (at least the granularity of the memory allocations) + n_bytes = next_multiple(n_bytes, m_alignment); + + Interval* best_candidate = &m_free_intervals.back(); + for (auto& f : m_free_intervals) { + if (f.size() >= n_bytes && f.size() < best_candidate->size()) { + best_candidate = &f; + } + } + + size_t start = best_candidate->start; + + // Note: the += operator can turn `best_candidate` into an empty interval, which is fine because it will + // be absorbed into adjacent free intervals in later calls to `merge_adjacent_intervals`. 
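+		// The allocation is recorded as the half-open interval [start, start + n_bytes):
+		// the map stores the interval's end keyed by its old start, while the chosen free
+		// interval is shrunk from the front by the same amount.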
+ m_allocated_intervals[start] = best_candidate->start += n_bytes; + + enlarge(size()); + + return start; + } + + void free(size_t start) { + if (m_allocated_intervals.count(start) == 0) { + throw std::runtime_error{"Attempted to free arena memory that was not allocated."}; + } + + Interval interval = {start, m_allocated_intervals[start]}; + m_allocated_intervals.erase(start); + + m_free_intervals.insert( + std::upper_bound(std::begin(m_free_intervals), std::end(m_free_intervals), interval), + interval + ); + + merge_adjacent_intervals(); + } + + void enlarge(size_t n_bytes) { + if (n_bytes <= m_size) { + return; + } + + if (cuda_device() != m_device) { + throw std::runtime_error{fmt::format("Attempted to use a GPUMemoryArena of device {} from the wrong device {}.", m_device, cuda_device())}; + } + + log_debug("GPUMemoryArena: enlarging from {} to {}", bytes_to_string(m_size), bytes_to_string(n_bytes)); + + if (m_fallback_memory) { + static const double GROWTH_FACTOR = 1.5; + + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + m_size = next_multiple((size_t)(n_bytes * GROWTH_FACTOR), cuda_memory_granularity()); + m_fallback_memory = std::make_shared>(m_fallback_memory->copy(m_size)); + + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + return; + } + + size_t n_bytes_to_allocate = n_bytes - m_size; + n_bytes_to_allocate = next_multiple(n_bytes_to_allocate, cuda_memory_granularity()); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = m_device; + + m_handles.emplace_back(); + CU_CHECK_THROW(cuMemCreate(&m_handles.back(), n_bytes_to_allocate, &prop, 0)); + + CUmemAccessDesc access_desc = {}; + access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = prop.location.id; + access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CU_CHECK_THROW(cuMemMap(m_base_address + m_size, n_bytes_to_allocate, 0, m_handles.back(), 0)); + CU_CHECK_THROW(cuMemSetAccess(m_base_address + m_size, n_bytes_to_allocate, &access_desc, 1)); + m_size += n_bytes_to_allocate; + + total_n_bytes_allocated() += n_bytes_to_allocate; + + // Need to synchronize the device to make sure memory is available to all streams. 
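+		// If a CUDA graph capture is in progress, calling cudaDeviceSynchronize() here would
+		// be invalid, so the synchronization is deferred to the capture object instead.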
+ if (current_capture()) { + current_capture()->schedule_synchronize(); + } else { + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + } + } + + size_t size() const { + return m_free_intervals.back().start; + } + + bool in_use() const { + return m_free_intervals.size() != 1 || m_free_intervals.front().size() != m_max_size; + } + + class Allocation { + public: + Allocation() = default; + Allocation(cudaStream_t stream, size_t offset, const std::shared_ptr& workspace) + : m_stream{stream}, m_data{workspace->data() + offset}, m_offset{offset}, m_workspace{workspace}, m_backing_memory{workspace->backing_memory()} + {} + + ~Allocation() { + if (m_workspace) { + m_workspace->free(m_offset); + } + } + + Allocation(const Allocation& other) = delete; + + Allocation& operator=(Allocation&& other) { + std::swap(m_stream, other.m_stream); + std::swap(m_data, other.m_data); + std::swap(m_offset, other.m_offset); + std::swap(m_workspace, other.m_workspace); + std::swap(m_backing_memory, other.m_backing_memory); + return *this; + } + + Allocation(Allocation&& other) { + *this = std::move(other); + } + + uint8_t* data() { + return m_data; + } + + const uint8_t* data() const { + return m_data; + } + + cudaStream_t stream() const { + return m_stream; + } + + private: + cudaStream_t m_stream = nullptr; + uint8_t* m_data = nullptr; + size_t m_offset = 0; + std::shared_ptr m_workspace = nullptr; + + // Backing GPUMemory (if backed by a GPUMemory). Ensures that + // the backing memory is only freed once all allocations that + // use it were destroyed. + std::shared_ptr> m_backing_memory = nullptr; + }; + +private: + void merge_adjacent_intervals() { + size_t j = 0; + for (size_t i = 1; i < m_free_intervals.size(); ++i) { + Interval& prev = m_free_intervals[j]; + Interval& cur = m_free_intervals[i]; + + if (prev.end == cur.start) { + prev.end = cur.end; + } else { + ++j; + m_free_intervals[j] = m_free_intervals[i]; + } + } + m_free_intervals.resize(j+1); + } + + std::vector> m_free_intervals; + std::unordered_map m_allocated_intervals; + + int m_device = 0; + CUdeviceptr m_base_address = {}; + size_t m_size = 0; + + std::vector m_handles; + + // Used then virtual memory isn't supported. + // Requires more storage + memcpy, but is more portable. + std::shared_ptr> m_fallback_memory = nullptr; + + size_t m_alignment; + size_t m_max_size; +}; + +inline std::unordered_map>& stream_gpu_memory_arenas() { + static auto* stream_gpu_memory_arenas = new std::unordered_map>{}; + return *stream_gpu_memory_arenas; +} + +inline std::unordered_map>& global_gpu_memory_arenas() { + static auto* global_gpu_memory_arenas = new std::unordered_map>{}; + return *global_gpu_memory_arenas; +} + +inline GPUMemoryArena::Allocation allocate_workspace(cudaStream_t stream, size_t n_bytes) { + if (n_bytes == 0) { + // Return a null allocation if no bytes were requested. + return {}; + } + + auto& arena = stream ? 
stream_gpu_memory_arenas()[stream] : global_gpu_memory_arenas()[cuda_device()]; + if (!arena) { + arena = std::make_shared(); + } + return GPUMemoryArena::Allocation{stream, arena->allocate(n_bytes), arena}; +} + +inline size_t align_to_cacheline(size_t bytes) { + return next_multiple(bytes, (size_t)128); +} + +template +std::tuple allocate_workspace_and_distribute(cudaStream_t stream, GPUMemoryArena::Allocation* alloc, size_t offset, FirstSize first_size) { + *alloc = allocate_workspace(stream, offset + align_to_cacheline(first_size * sizeof(First))); + return std::make_tuple((First*)(alloc->data() + offset)); +} + +template = 0> +std::tuple allocate_workspace_and_distribute(cudaStream_t stream, GPUMemoryArena::Allocation* alloc, size_t offset, FirstSize first_size, Sizes... sizes) { + auto nested = allocate_workspace_and_distribute(stream, alloc, offset + align_to_cacheline(first_size * sizeof(First)), sizes...); + return std::tuple_cat(std::make_tuple((First*)(alloc->data() + offset)), nested); +} + +template = 0> +std::tuple allocate_workspace_and_distribute(cudaStream_t stream, GPUMemoryArena::Allocation* alloc, Sizes... sizes) { + return allocate_workspace_and_distribute(stream, alloc, (size_t)0, sizes...); +} + +inline void free_gpu_memory_arena(cudaStream_t stream) { + if (stream) { + stream_gpu_memory_arenas().erase(stream); + } else { + global_gpu_memory_arenas().erase(cuda_device()); + } +} + +inline void free_all_gpu_memory_arenas() { + stream_gpu_memory_arenas().clear(); + global_gpu_memory_arenas().clear(); +} + +} diff --git a/gui/include/tiny-cuda-nn/gpu_memory_json.h b/gui/include/tiny-cuda-nn/gpu_memory_json.h new file mode 100644 index 0000000000000000000000000000000000000000..9f58369c47c279844a2c5e3518e8da5686a62cd0 --- /dev/null +++ b/gui/include/tiny-cuda-nn/gpu_memory_json.h @@ -0,0 +1,66 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file gpu_memory_json.h + * @author Nikolaus Binder and Thomas Müller, NVIDIA + * @brief binding between GPUMemory and JSON librariy + */ + +#pragma once + +#include + +namespace tcnn { + +inline nlohmann::json::binary_t gpu_memory_to_json_binary(const void* gpu_data, size_t n_bytes) { + nlohmann::json::binary_t data_cpu; + data_cpu.resize(n_bytes); + CUDA_CHECK_THROW(cudaMemcpy(data_cpu.data(), gpu_data, n_bytes, cudaMemcpyDeviceToHost)); + return data_cpu; +} + +inline void json_binary_to_gpu_memory(const nlohmann::json::binary_t& cpu_data, void* gpu_data, size_t n_bytes) { + CUDA_CHECK_THROW(cudaMemcpy(gpu_data, cpu_data.data(), n_bytes, cudaMemcpyHostToDevice)); +} + +template +inline void to_json(nlohmann::json& j, const GPUMemory& gpu_data) { + j = gpu_memory_to_json_binary(gpu_data.data(), gpu_data.get_bytes()); +} + +template +inline void from_json(const nlohmann::json& j, GPUMemory& gpu_data) { + if (j.is_binary()) { + const nlohmann::json::binary_t& cpu_data = j.get_binary(); + gpu_data.resize(cpu_data.size()/sizeof(T)); + json_binary_to_gpu_memory(cpu_data, gpu_data.data(), gpu_data.get_bytes()); + } else if (j.is_object()) { + // https://json.nlohmann.me/features/binary_values/#json + json::array_t arr = j["bytes"]; + nlohmann::json::binary_t cpu_data; + cpu_data.resize(arr.size()); + for(size_t i = 0; i < arr.size(); ++i) { + cpu_data[i] = (uint8_t)arr[i]; + } + gpu_data.resize(cpu_data.size()/sizeof(T)); + json_binary_to_gpu_memory(cpu_data, gpu_data.data(), gpu_data.get_bytes()); + } else { + throw std::runtime_error("Invalid json type: must be either binary or object"); + } +} + +} diff --git a/gui/include/tiny-cuda-nn/multi_stream.h b/gui/include/tiny-cuda-nn/multi_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5ee7450751d067c630f5c4d3290da46faa8ace --- /dev/null +++ b/gui/include/tiny-cuda-nn/multi_stream.h @@ -0,0 +1,247 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file multi_stream.h + * @author Thomas Müller, NVIDIA + * @brief Helper class for parallelizing workload across multiple streams. + */ + +#pragma once + +#include + +#include + +namespace tcnn { + +void free_multi_streams(cudaStream_t parent_stream); + +// Synchronization helpers +struct StreamAndEvent { +public: + StreamAndEvent() { + CUDA_CHECK_THROW(cudaStreamCreate(&m_stream)); + CUDA_CHECK_THROW(cudaEventCreate(&m_event)); + } + + ~StreamAndEvent() { + if (m_stream) { + free_multi_streams(m_stream); + free_gpu_memory_arena(m_stream); + cudaStreamDestroy(m_stream); + } + + if (m_event) { + cudaEventDestroy(m_event); + } + } + + // Only allow moving of these guys. No copying. 
+ StreamAndEvent& operator=(const StreamAndEvent&) = delete; + StreamAndEvent(const StreamAndEvent&) = delete; + StreamAndEvent& operator=(StreamAndEvent&& other) { + std::swap(m_stream, other.m_stream); + std::swap(m_event, other.m_event); + return *this; + } + + StreamAndEvent(StreamAndEvent&& other) { + *this = std::move(other); + } + + void wait_for(cudaEvent_t event) { + CUDA_CHECK_THROW(cudaStreamWaitEvent(m_stream, event, 0)); + } + + void wait_for(cudaStream_t stream) { + CUDA_CHECK_THROW(cudaEventRecord(m_event, stream)); + wait_for(m_event); + } + + void signal(cudaStream_t stream) { + CUDA_CHECK_THROW(cudaEventRecord(m_event, m_stream)); + CUDA_CHECK_THROW(cudaStreamWaitEvent(stream, m_event, 0)); + } + + cudaStream_t get() { + return m_stream; + } + +private: + cudaStream_t m_stream = {}; + cudaEvent_t m_event = {}; +}; + +struct MultiStream { +public: + MultiStream() { + CUDA_CHECK_THROW(cudaEventCreate(&m_event)); + } + + ~MultiStream() { + cudaEventDestroy(m_event); + } + + MultiStream& operator=(const MultiStream&) = delete; + MultiStream(const MultiStream&) = delete; + MultiStream& operator=(MultiStream&&) = delete; + MultiStream(MultiStream&&) = delete; + + void signal(cudaStream_t outer_stream) { + for (size_t i = 0; i < m_n_streams; ++i) { + m_streams[i].signal(outer_stream); + } + } + + void wait_for(cudaStream_t stream) { + if (m_n_streams == 0) { + return; + } + + CUDA_CHECK_THROW(cudaEventRecord(m_event, stream)); + for (size_t i = 0; i < m_n_streams; ++i) { + m_streams[i].wait_for(m_event); + } + } + + void resize(size_t n_streams) { + if (n_streams > m_streams.size()) { + m_streams.resize(n_streams); + } + m_n_streams = n_streams; + } + + cudaStream_t get(size_t idx) { + if (idx >= m_n_streams) { + throw std::runtime_error{fmt::format("MultiStream: invalid stream index requested: {}/{}", idx, m_n_streams)}; + } + return m_streams.at(idx).get(); + } + +private: + std::vector m_streams; + // May be less than m_streams.size()! + // The user may only need to sync fewer than that. + size_t m_n_streams = 0; + cudaEvent_t m_event; +}; + +inline std::unordered_map>>& stream_multi_streams() { + static auto* stream_multi_streams = new std::unordered_map>>{}; + return *stream_multi_streams; +} + +inline std::unordered_map>>& global_multi_streams() { + static auto* global_multi_streams = new std::unordered_map>>{}; + return *global_multi_streams; +} + +inline std::stack>& get_multi_stream_stack(cudaStream_t parent_stream) { + return parent_stream ? stream_multi_streams()[parent_stream] : global_multi_streams()[cuda_device()]; +} + +inline void free_multi_streams(cudaStream_t parent_stream) { + CHECK_THROW(parent_stream); + + // Copy the multi stream shared_ptr's into a separate variable, + // such that their destruction happens after unordered_map::erase(...) + // is already finished. This alleviates potential non-reentrancy problems. + auto multi_streams = stream_multi_streams()[parent_stream]; + stream_multi_streams().erase(parent_stream); +} + +inline std::shared_ptr reserve_multi_stream(cudaStream_t parent_stream, size_t n_streams) { + auto& stack = get_multi_stream_stack(parent_stream); + if (stack.empty()) { + stack.push(std::make_shared()); + } + auto result = stack.top(); + stack.pop(); + + result->resize(n_streams); + return result; +} + +inline void return_multi_stream(cudaStream_t parent_stream, std::shared_ptr multi_stream) { + if (parent_stream ? 
(stream_multi_streams().count(parent_stream) == 0) : (global_multi_streams().count(cuda_device()) == 0)) { + throw std::runtime_error{"Attempted to return multi stream to the wrong parent stream."}; + } + + auto& stack = get_multi_stream_stack(parent_stream); + stack.push(multi_stream); +} + +// RAII wrapper around MultiStream +struct SyncedMultiStream { +public: + SyncedMultiStream() = default; + SyncedMultiStream(cudaStream_t stream, size_t n_streams) : m_main_stream{stream}, m_n_streams{n_streams} { + if (m_n_streams == 0) { + throw std::runtime_error{"SyncedMultiStream: must request at least one stream"}; + } else if (m_n_streams == 1) { + return; + } + + m_multi_stream = reserve_multi_stream(m_main_stream, m_n_streams-1); + m_multi_stream->wait_for(m_main_stream); + } + + ~SyncedMultiStream() { + if (m_multi_stream) { + m_multi_stream->signal(m_main_stream); + return_multi_stream(m_main_stream, m_multi_stream); + } + } + + // Only allow moving of these guys. No copying. + SyncedMultiStream& operator=(const SyncedMultiStream& other) = delete; + SyncedMultiStream(const SyncedMultiStream&) = delete; + + SyncedMultiStream& operator=(SyncedMultiStream&& other) { + std::swap(m_multi_stream, other.m_multi_stream); + std::swap(m_main_stream, other.m_main_stream); + std::swap(m_n_streams, other.m_n_streams); + return *this; + } + + SyncedMultiStream(SyncedMultiStream&& other) { + *this = std::move(other); + } + + cudaStream_t get(size_t idx) { + if (m_n_streams == 0) { + throw std::runtime_error{"SyncedMultiStream: must have at least one stream"}; + } + + if (idx == 0) { + return m_main_stream; + } else { + if (!m_multi_stream) { + throw std::runtime_error{"SyncedMultiStream: invalid multistream"}; + } + + return m_multi_stream->get(idx-1); + } + } + +private: + std::shared_ptr m_multi_stream = nullptr; + cudaStream_t m_main_stream = nullptr; + size_t m_n_streams = 0; +}; + +} diff --git a/gui/include/tiny-cuda-nn/random.h b/gui/include/tiny-cuda-nn/random.h new file mode 100644 index 0000000000000000000000000000000000000000..eb81947fec0761509a803a9c7bc59732369a5cda --- /dev/null +++ b/gui/include/tiny-cuda-nn/random.h @@ -0,0 +1,80 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file random.h + * @author Thomas Müller, NVIDIA + * @brief Collection of CUDA kernels related to random numbers + */ + +#pragma once + +#include +#include + +#include + +namespace tcnn { + +template +__global__ void generate_random_kernel(const size_t n_elements, RNG rng, T* __restrict__ out, const F transform) { + const size_t i = threadIdx.x + blockIdx.x * blockDim.x; + const size_t n_threads = blockDim.x * gridDim.x; + + rng.advance(i*N_TO_GENERATE); + + TCNN_PRAGMA_UNROLL + for (size_t j = 0; j < N_TO_GENERATE; ++j) { + const size_t idx = i + n_threads * j; + if (idx >= n_elements) { + return; + } + + out[idx] = transform((T)rng.next_float()); + } +} + +template +void generate_random(cudaStream_t stream, RNG& rng, size_t n_elements, T* out, F&& transform) { + static constexpr size_t N_TO_GENERATE = 4; + + size_t n_threads = div_round_up(n_elements, N_TO_GENERATE); + generate_random_kernel<<>>(n_elements, rng, out, transform); + + rng.advance(n_elements); +} + +template +void generate_random_uniform(cudaStream_t stream, RNG& rng, size_t n_elements, T* out, const T lower = (T)0.0, const T upper = (T)1.0) { + generate_random(stream, rng, n_elements, out, [upper, lower] __device__ (T val) { return val * (upper - lower) + lower; }); +} + +template +void generate_random_uniform(RNG& rng, size_t n_elements, T* out, const T lower = (T)0.0, const T upper = (T)1.0) { + generate_random_uniform(nullptr, rng, n_elements, out, lower, upper); +} + +template +void generate_random_logistic(cudaStream_t stream, RNG& rng, size_t n_elements, T* out, const T mean = (T)0.0, const T stddev = (T)1.0) { + generate_random(stream, rng, n_elements, out, [mean, stddev] __device__ (T val) { return (T)logit(val) * stddev * 0.551328895f + mean; }); +} + +template +void generate_random_logistic(RNG& rng, size_t n_elements, T* out, const T mean = (T)0.0, const T stddev = (T)1.0) { + generate_random_logistic(nullptr, rng, n_elements, out, mean, stddev); +} + +} diff --git a/gui/include/tiny-cuda-nn/vec.h b/gui/include/tiny-cuda-nn/vec.h new file mode 100644 index 0000000000000000000000000000000000000000..3614a453d2486a813820ffe754a4d846847e6447 --- /dev/null +++ b/gui/include/tiny-cuda-nn/vec.h @@ -0,0 +1,1203 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file vec.h + * @author Thomas Müller, NVIDIA + * @brief Tiny vector / matrix / quaternion implementation. 
+ */ + +#pragma once + +#include + +#include +#include +#include + +namespace tcnn { + +template struct conjunction : std::true_type {}; +template struct conjunction : B1 {}; +template struct conjunction : std::conditional_t, B1> {}; + +template +using enable_if_size_and_type_match_t = std::enable_if_t...>::value>; + +#define TVEC_BODY \ + using underlying_type = T; \ + \ + tvec() = default; \ + \ + TCNN_HOST_DEVICE tvec(T scalar) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + (*this)[i] = scalar; \ + } \ + } \ + \ + TCNN_HOST_DEVICE static constexpr tvec ones() { return tvec((T)1); } \ + TCNN_HOST_DEVICE static constexpr tvec zero() { return tvec((T)0); } \ + \ + TCNN_HOST_DEVICE tvec(const T* coeffs) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + (*this)[i] = coeffs[i]; \ + } \ + } \ + \ + template \ + TCNN_HOST_DEVICE tvec(const tvec& other) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + (*this)[i] = i < M ? (T)other[i] : (T)0; \ + } \ + } \ + \ + TCNN_HOST_DEVICE void to_array(T* coeffs) const { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + coeffs[i] = (*this)[i]; \ + } \ + } \ + \ + TCNN_HOST_DEVICE T* data() { return (T*)this; } \ + TCNN_HOST_DEVICE const T* data() const { return (const T*)this; } \ + \ + TCNN_HOST_DEVICE T& operator[](uint32_t idx) { return ((T*)this)[idx]; } \ + TCNN_HOST_DEVICE const T& operator[](uint32_t idx) const { return ((T*)this)[idx]; } \ + TCNN_HOST_DEVICE T& operator()(uint32_t idx) { return ((T*)this)[idx]; } \ + TCNN_HOST_DEVICE const T& operator()(uint32_t idx) const { return ((T*)this)[idx]; } \ + \ + template \ + TCNN_HOST_DEVICE tvec& slice() { \ + static_assert(OFFSET + M <= N, "Slice must be part of the vector."); \ + return *(tvec*)(data() + OFFSET); \ + } \ + \ + template \ + TCNN_HOST_DEVICE const tvec& slice() const { \ + static_assert(OFFSET + M <= N, "Slice must be part of the vector."); \ + return *(tvec*)(data() + OFFSET); \ + } \ + \ + TCNN_HOST_DEVICE tvec& xy() { return slice<0, 2>(); } \ + TCNN_HOST_DEVICE const tvec& xy() const { return slice<0, 2>(); } \ + TCNN_HOST_DEVICE tvec& yz() { return slice<1, 2>(); } \ + TCNN_HOST_DEVICE const tvec& yz() const { return slice<1, 2>(); } \ + TCNN_HOST_DEVICE tvec& xyz() { return slice<0, 3>(); } \ + TCNN_HOST_DEVICE const tvec& xyz() const { return slice<0, 3>(); } \ + TCNN_HOST_DEVICE tvec& rgb() { return slice<0, 3>(); } \ + TCNN_HOST_DEVICE const tvec& rgb() const { return slice<0, 3>(); } \ + TCNN_HOST_DEVICE tvec& xyzw() { return slice<0, 4>(); } \ + TCNN_HOST_DEVICE const tvec& rgba() const { return slice<0, 4>(); } \ + \ + TCNN_HOST_DEVICE static constexpr uint32_t size() { return N; } + +template +struct alignas(ALIGNMENT) tvec { + TVEC_BODY + T elems[N]; + + template > + TCNN_HOST_DEVICE tvec(Ts... 
coeffs) : elems{coeffs...} {} +}; + +template +struct alignas(ALIGNMENT) tvec { + static constexpr uint32_t N = 1; + TVEC_BODY + union { T x, r; }; +}; + +template +struct alignas(ALIGNMENT) tvec { + static constexpr uint32_t N = 2; + TVEC_BODY + union { T x, r; }; + union { T y, g; }; + + TCNN_HOST_DEVICE tvec(T a, T b) : x{a}, y{b} {} +}; + +template +struct alignas(ALIGNMENT) tvec { + static constexpr uint32_t N = 3; + TVEC_BODY + union { T x, r; }; + union { T y, g; }; + union { T z, b; }; + + TCNN_HOST_DEVICE tvec(T a, T b, T c) : x{a}, y{b}, z{c} {} + template TCNN_HOST_DEVICE tvec(const tvec& a, T b) : x{a.x}, y{a.y}, z{b} {} + template TCNN_HOST_DEVICE tvec(T a, const tvec& b) : x{a}, y{b.x}, z{b.y} {} +}; + +template +struct alignas(ALIGNMENT) tvec { + static constexpr uint32_t N = 4; + TVEC_BODY + union { T x, r; }; + union { T y, g; }; + union { T z, b; }; + union { T w, a; }; + + TCNN_HOST_DEVICE tvec(T a, T b, T c, T d) : x{a}, y{b}, z{c}, w{d} {} + template TCNN_HOST_DEVICE tvec(const tvec& a, T b) : x{a.x}, y{a.y}, z{a.z}, w{b} {} + template TCNN_HOST_DEVICE tvec(const tvec& a, const tvec& b) : x{a.x}, y{a.y}, z{b.x}, w{b.y} {} + template TCNN_HOST_DEVICE tvec(const tvec& a, T b, T c) : x{a.x}, y{a.y}, z{b}, w{c} {} + template TCNN_HOST_DEVICE tvec(T a, const tvec& b, T c) : x{a}, y{b.x}, z{b.y}, w{c} {} + template TCNN_HOST_DEVICE tvec(T a, T b, const tvec& c) : x{a}, y{b}, z{c.x}, w{c.y} {} + template TCNN_HOST_DEVICE tvec(T a, const tvec& b) : x{a}, y{b.x}, z{b.y}, w{b.z} {} +}; + +#undef TVEC_BODY + +// Import external cwise functions into ngp namespace to avoid +// name resolution problems related to the vector-values versions defined below. +template TCNN_HOST_DEVICE T min(T a, T b) { return std::min(a, b); } +template TCNN_HOST_DEVICE T max(T a, T b) { return std::max(a, b); } +template TCNN_HOST_DEVICE T clamp(T a, T b, T c) { return a < b ? b : (c < a ? 
c : a); } +template TCNN_HOST_DEVICE T copysign(T a, T b) { return std::copysign(a, b); } +template TCNN_HOST_DEVICE T sign(T a) { return std::copysign((T)1, a); } +template TCNN_HOST_DEVICE T mix(T a, T b, T c) { return a * ((T)1 - c) + b * c; } +template TCNN_HOST_DEVICE T floor(T a) { return std::floor(a); } +template TCNN_HOST_DEVICE T round(T a) { return std::round(a); } +template TCNN_HOST_DEVICE T ceil(T a) { return std::ceil(a); } +template TCNN_HOST_DEVICE T abs(T a) { return std::abs(a); } +template TCNN_HOST_DEVICE T distance(T a, T b) { return std::abs(a - b); } +template TCNN_HOST_DEVICE T sin(T a) { return std::sin(a); } +template TCNN_HOST_DEVICE T asin(T a) { return std::asin(a); } +template TCNN_HOST_DEVICE T cos(T a) { return std::cos(a); } +template TCNN_HOST_DEVICE T acos(T a) { return std::acos(a); } +template TCNN_HOST_DEVICE T tan(T a) { return std::tan(a); } +template TCNN_HOST_DEVICE T atan(T a) { return std::atan(a); } +template TCNN_HOST_DEVICE T sqrt(T a) { return std::sqrt(a); } +template TCNN_HOST_DEVICE T exp(T a) { return std::exp(a); } +template TCNN_HOST_DEVICE T log(T a) { return std::log(a); } +template TCNN_HOST_DEVICE T exp2(T a) { return std::exp2(a); } +template TCNN_HOST_DEVICE T log2(T a) { return std::log2(a); } +template TCNN_HOST_DEVICE T pow(T a, T b) { return std::pow(a, b); } +template TCNN_HOST_DEVICE T isfinite(T a) { +#if defined(__CUDA_ARCH__) + return ::isfinite(a); +#else + return std::isfinite(a); +#endif +} + +inline TCNN_HOST_DEVICE float fma(float a, float b, float c) { return fmaf(a, b, c); } +#ifdef __CUDACC__ +inline TCNN_DEVICE __half fma(__half a, __half b, __half c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + return __hfma(a, b, c); +#else + return fmaf(a, b, c); +#endif +} +#endif + +#define TVEC tvec +#define BVEC tvec + +#define CWISE_OP(operation, type_result, expr, ...) 
\ +template \ +TCNN_HOST_DEVICE type_result operation(__VA_ARGS__) { \ + type_result result; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + result[i] = expr; \ + } \ + return result; \ +} + +CWISE_OP(operator+, TVEC, a[i] + b[i], const TVEC& a, const TVEC& b) +CWISE_OP(operator+, TVEC, a + b[i], T a, const TVEC& b) +CWISE_OP(operator+, TVEC, a[i] + b, const TVEC& a, T b) + +CWISE_OP(operator-, TVEC, a[i] - b[i], const TVEC& a, const TVEC& b) +CWISE_OP(operator-, TVEC, a - b[i], T a, const TVEC& b) +CWISE_OP(operator-, TVEC, a[i] - b, const TVEC& a, T b) + +CWISE_OP(operator*, TVEC, a[i] * b[i], const TVEC& a, const TVEC& b) +CWISE_OP(operator*, TVEC, a * b[i], T a, const TVEC& b) +CWISE_OP(operator*, TVEC, a[i] * b, const TVEC& a, T b) + +CWISE_OP(operator/, TVEC, a[i] / b[i], const TVEC& a, const TVEC& b) +CWISE_OP(operator/, TVEC, a / b[i], T a, const TVEC& b) +CWISE_OP(operator/, TVEC, a[i] / b, const TVEC& a, T b) + +CWISE_OP(fma, TVEC, fma(a[i], b[i], c[i]), const TVEC& a, const TVEC& b, const TVEC& c) +CWISE_OP(fma, TVEC, fma(a[i], b[i], c), const TVEC& a, const TVEC& b, T c) +CWISE_OP(fma, TVEC, fma(a[i], b, c[i]), const TVEC& a, T b, const TVEC& c) +CWISE_OP(fma, TVEC, fma(a[i], b, c), const TVEC& a, T b, T c) +CWISE_OP(fma, TVEC, fma(a, b[i], c[i]), T a, const TVEC& b, const TVEC& c) +CWISE_OP(fma, TVEC, fma(a, b[i], c), T a, const TVEC& b, T c) +CWISE_OP(fma, TVEC, fma(a, b, c[i]), T a, T b, const TVEC& c) + +CWISE_OP(min, TVEC, min(a[i], b[i]), const TVEC& a, const TVEC& b) +CWISE_OP(min, TVEC, min(a[i], b), const TVEC& a, T b) +CWISE_OP(min, TVEC, min(a, b[i]), T a, const TVEC& b) + +CWISE_OP(max, TVEC, max(a[i], b[i]), const TVEC& a, const TVEC& b) +CWISE_OP(max, TVEC, max(a[i], b), const TVEC& a, T b) +CWISE_OP(max, TVEC, max(a, b[i]), T a, const TVEC& b) + +CWISE_OP(clamp, TVEC, clamp(a[i], b[i], c[i]), const TVEC& a, const TVEC& b, const TVEC& c) +CWISE_OP(clamp, TVEC, clamp(a[i], b[i], c), const TVEC& a, const TVEC& b, T c) +CWISE_OP(clamp, TVEC, clamp(a[i], b, c[i]), const TVEC& a, T b, const TVEC& c) +CWISE_OP(clamp, TVEC, clamp(a[i], b, c), const TVEC& a, T b, T c) + +CWISE_OP(copysign, TVEC, copysign(a[i], b[i]), const TVEC& a, const TVEC& b) +CWISE_OP(copysign, TVEC, copysign(a[i], b), const TVEC& a, T b) +CWISE_OP(copysign, TVEC, copysign(a, b[i]), T a, const TVEC& b) + +CWISE_OP(sign, TVEC, sign(a[i]), const TVEC& a) + +CWISE_OP(mix, TVEC, a[i] * ((T)1 - c[i]) + b[i] * c[i], const TVEC& a, const TVEC& b, const TVEC& c) +CWISE_OP(mix, TVEC, a[i] * ((T)1 - c) + b[i] * c, const TVEC& a, const TVEC& b, T c) + +CWISE_OP(operator-, TVEC, -a[i], const TVEC& a) +CWISE_OP(floor, TVEC, floor(a[i]), const TVEC& a) +CWISE_OP(round, TVEC, round(a[i]), const TVEC& a) +CWISE_OP(ceil, TVEC, ceil(a[i]), const TVEC& a) +CWISE_OP(abs, TVEC, abs(a[i]), const TVEC& a) +CWISE_OP(sin, TVEC, sin(a[i]), const TVEC& a) +CWISE_OP(asin, TVEC, asin(a[i]), const TVEC& a) +CWISE_OP(cos, TVEC, cos(a[i]), const TVEC& a) +CWISE_OP(acos, TVEC, acos(a[i]), const TVEC& a) +CWISE_OP(tan, TVEC, tan(a[i]), const TVEC& a) +CWISE_OP(atan, TVEC, atan(a[i]), const TVEC& a) +CWISE_OP(sqrt, TVEC, sqrt(a[i]), const TVEC& a) +CWISE_OP(exp, TVEC, exp(a[i]), const TVEC& a) +CWISE_OP(log, TVEC, log(a[i]), const TVEC& a) +CWISE_OP(exp2, TVEC, exp2(a[i]), const TVEC& a) +CWISE_OP(log2, TVEC, log2(a[i]), const TVEC& a) +CWISE_OP(pow, TVEC, pow(a[i], b), const TVEC& a, T b) +CWISE_OP(pow, TVEC, pow(a[i], b[i]), const TVEC& a, const TVEC& b) + +CWISE_OP(isfinite, BVEC, isfinite(a[i]), const TVEC& a) + 
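+// Illustrative sketch (not part of the original header; the helper name and template
+// parameter names are assumptions): each CWISE_OP line above stamps out one element-wise
+// overload, so vector expressions compose exactly like their scalar counterparts. For
+// example, a component-wise smoothstep can be written purely in terms of the overloads
+// generated above:
+//
+//   template <typename T, uint32_t N, size_t A>
+//   TCNN_HOST_DEVICE tvec<T, N, A> smoothstep_cwise(const tvec<T, N, A>& e0, const tvec<T, N, A>& e1, const tvec<T, N, A>& x) {
+//       auto t = clamp((x - e0) / (e1 - e0), (T)0, (T)1); // clamp(vec, scalar, scalar) overload
+//       return t * t * ((T)3 - (T)2 * t);                 // operator*, operator- overloads
+//   }
+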
+#if defined(__CUDACC__) +inline TCNN_DEVICE void atomic_add_gmem_float(float* addr, float in) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + int in_int = *((int*)&in); + asm ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "r"(in_int)); +#else + atomicAdd(addr, in); +#endif +} + +template +TCNN_DEVICE void atomic_add(T* dst, const tvec& a) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + atomicAdd(dst + i, a[i]); + } +} + +template +TCNN_DEVICE void atomic_add_gmem(float* dst, const tvec& a) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + atomic_add_gmem_float(dst + i, a[i]); + } +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 // atomicAdd(__half2) is only supported with compute capability 60 and above +inline TCNN_DEVICE void atomic_add_gmem_h2(half2* addr, half2 in) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + int in_int = *((int*)&in); + asm ("red.relaxed.gpu.global.add.noftz.f16x2 [%0], %1;" :: "l"(addr), "r"(in_int)); +#else + atomicAdd(addr, in); +#endif +} + +template > +TCNN_DEVICE void atomic_add(__half* dst, const tvec<__half, N, A>& a) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; i += 2) { + atomicAdd((__half2*)(dst + i), __half2(a[i], a[i+1])); + } +} + +template > +TCNN_DEVICE void atomic_add_gmem(__half* dst, const tvec<__half, N, A>& a) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; i += 2) { + atomic_add_gmem_h2((__half2*)(dst + i), __half2(a[i], a[i+1])); + } +} +#endif +#endif + +#undef CWISE_OP + +// __half2 specializations for aligned vectors with 2*N fp16 coefficients. +#if defined(__CUDACC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + +#define HVEC tvec<__half, N, A> +#define HALF_CWISE_OP(operation, type_result, expr, ...) \ +template > \ +TCNN_DEVICE type_result operation(__VA_ARGS__) { \ + type_result result; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; i += 2) { \ + *(__half2*)&result[i] = expr; \ + } \ + return result; \ +} + +HALF_CWISE_OP(fma, HVEC, __hfma2(*(__half2*)&a[i], *(__half2*)&b[i], *(__half2*)&c[i]), const HVEC& a, const HVEC& b, const HVEC& c) +HALF_CWISE_OP(fma, HVEC, __hfma2(*(__half2*)&a[i], *(__half2*)&b[i], __half2half2(c)), const HVEC& a, const HVEC& b, __half c) +HALF_CWISE_OP(fma, HVEC, __hfma2(*(__half2*)&a[i], __half2half2(b), *(__half2*)&c[i]), const HVEC& a, __half b, const HVEC& c) +HALF_CWISE_OP(fma, HVEC, __hfma2(*(__half2*)&a[i], __half2half2(b), __half2half2(c)), const HVEC& a, __half b, __half c) +HALF_CWISE_OP(fma, HVEC, __hfma2(__half2half2(a), *(__half2*)&b[i], *(__half2*)&c[i]), __half a, const HVEC& b, const HVEC& c) +HALF_CWISE_OP(fma, HVEC, __hfma2(__half2half2(a), *(__half2*)&b[i], __half2half2(c)), __half a, const HVEC& b, __half c) +HALF_CWISE_OP(fma, HVEC, __hfma2(__half2half2(a), __half2half2(b), *(__half2*)&c[i]), __half a, __half b, const HVEC& c) + +HALF_CWISE_OP(operator+, HVEC, __hadd2(*(__half2*)&a[i], *(__half2*)&b[i]), const HVEC& a, const HVEC& b) +HALF_CWISE_OP(operator+, HVEC, __hadd2(__half2half2(a), *(__half2*)&b[i]), __half a, const HVEC& b) +HALF_CWISE_OP(operator+, HVEC, __hadd2(*(__half2*)&a[i], __half2half2(b)), const HVEC& a, __half b) + +HALF_CWISE_OP(operator-, HVEC, __hsub2(*(__half2*)&a[i], *(__half2*)&b[i]), const HVEC& a, const HVEC& b) +HALF_CWISE_OP(operator-, HVEC, __hsub2(__half2half2(a), *(__half2*)&b[i]), __half a, const HVEC& b) +HALF_CWISE_OP(operator-, HVEC, __hsub2(*(__half2*)&a[i], __half2half2(b)), const HVEC& a, __half b) + +HALF_CWISE_OP(operator*, HVEC, __hmul2(*(__half2*)&a[i], 
*(__half2*)&b[i]), const HVEC& a, const HVEC& b) +HALF_CWISE_OP(operator*, HVEC, __hmul2(__half2half2(a), *(__half2*)&b[i]), __half a, const HVEC& b) +HALF_CWISE_OP(operator*, HVEC, __hmul2(*(__half2*)&a[i], __half2half2(b)), const HVEC& a, __half b) + +HALF_CWISE_OP(operator/, HVEC, __h2div(*(__half2*)&a[i], *(__half2*)&b[i]), const HVEC& a, const HVEC& b) +HALF_CWISE_OP(operator/, HVEC, __h2div(*(__half2*)&a[i], __half2half2(b)), const HVEC& a, __half b) + +#endif + +#define INPLACE_OP(operation, type_b, expr) \ +template \ +TCNN_HOST_DEVICE TVEC& operation(TVEC& a, type_b b) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + expr; \ + } \ + return a; \ +} + +INPLACE_OP(operator*=, const TVEC&, a[i] *= b[i]) +INPLACE_OP(operator/=, const TVEC&, a[i] /= b[i]) +INPLACE_OP(operator+=, const TVEC&, a[i] += b[i]) +INPLACE_OP(operator-=, const TVEC&, a[i] -= b[i]) + +INPLACE_OP(operator*=, T, a[i] *= b) +INPLACE_OP(operator/=, T, a[i] /= b) + +#undef INPLACE_OP + +#define REDUCTION_OP(operation, type_result, init, expr, ...) \ +template \ +TCNN_HOST_DEVICE type_result operation(__VA_ARGS__) { \ + type_result result = init; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + expr; \ + } \ + return result; \ +} + +REDUCTION_OP(dot, T, (T)0, result += a[i] * b[i], const TVEC& a, const TVEC& b) +REDUCTION_OP(sum, T, (T)0, result += a[i], const TVEC& a) +REDUCTION_OP(mean, T, (T)0, result += a[i] / (T)N, const TVEC& a) +REDUCTION_OP(product, T, (T)1, result *= a[i], const TVEC& a) +REDUCTION_OP(min, T, (T)std::numeric_limits::infinity(), result = min(result, a[i]), const TVEC& a) +REDUCTION_OP(max, T, (T)-std::numeric_limits::infinity(), result = max(result, a[i]), const TVEC& a) +REDUCTION_OP(length2, T, (T)0, result += a[i] * a[i], const TVEC& a) + +REDUCTION_OP(operator==, bool, true, result &= a[i] == b[i], const TVEC& a, const TVEC& b) +REDUCTION_OP(operator!=, bool, false, result |= a[i] != b[i], const TVEC& a, const TVEC& b) + +#undef REDUCTION_OP + +#define BOOL_REDUCTION_OP(operation, type_result, init, expr, ...) 
\ +template \ +TCNN_HOST_DEVICE type_result operation(__VA_ARGS__) { \ + type_result result = init; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + expr; \ + } \ + return result; \ +} + +BOOL_REDUCTION_OP(all, bool, true, result &= a[i], const BVEC& a) +BOOL_REDUCTION_OP(any, bool, false, result |= a[i], const BVEC& a) + +#undef BOOL_REDUCTION_OP + +template +TCNN_HOST_DEVICE T length(const TVEC& a) { + return std::sqrt(length2(a)); +} + +template +TCNN_HOST_DEVICE T distance(const TVEC& a, const TVEC& b) { + return length(a - b); +} + +template +TCNN_HOST_DEVICE TVEC normalize(const TVEC& v) { + T len = length(v); + if (len <= (T)0) { + TVEC result{(T)0}; + result[0] = (T)1; + return result; + } + return v / len; +} + +template +TCNN_HOST_DEVICE TVEC cross(const TVEC& a, const TVEC& b) { + return { + a.y * b.z - a.z * b.y, + a.z * b.x - a.x * b.z, + a.x * b.y - a.y * b.x, + }; +} + +template +TCNN_HOST_DEVICE TVEC faceforward(const TVEC& n, const TVEC& i, const TVEC& nref) { + return n * -copysign((T)1, dot(i, nref)); +} + +#undef TVEC +#undef BVEC + +#define DEF_NON_TEMPLATED_VECTOR_TYPES(name, T) \ +template using name = tvec; \ +template using a##name = tvec; \ +using name##1 = name<1>; \ +using name##2 = name<2>; \ +using name##3 = name<3>; \ +using name##4 = name<4>; + +DEF_NON_TEMPLATED_VECTOR_TYPES(bvec, bool) +DEF_NON_TEMPLATED_VECTOR_TYPES(vec, float) +DEF_NON_TEMPLATED_VECTOR_TYPES(dvec, double) +DEF_NON_TEMPLATED_VECTOR_TYPES(ivec, int) +DEF_NON_TEMPLATED_VECTOR_TYPES(uvec, unsigned int) +DEF_NON_TEMPLATED_VECTOR_TYPES(i32vec, int32_t) +DEF_NON_TEMPLATED_VECTOR_TYPES(u32vec, uint32_t) +DEF_NON_TEMPLATED_VECTOR_TYPES(i16vec, int16_t) +DEF_NON_TEMPLATED_VECTOR_TYPES(u16vec, uint16_t) +DEF_NON_TEMPLATED_VECTOR_TYPES(i8vec, int8_t) +DEF_NON_TEMPLATED_VECTOR_TYPES(u8vec, uint8_t) +#if defined(__CUDACC__) +DEF_NON_TEMPLATED_VECTOR_TYPES(hvec, __half) +#endif + +#if defined(__CUDACC__) +inline TCNN_HOST_DEVICE float4 to_float4(const vec4& x) { return {x.x, x.y, x.z, x.w}; } +inline TCNN_HOST_DEVICE float3 to_float3(const vec3& x) { return {x.x, x.y, x.z}; } +inline TCNN_HOST_DEVICE float2 to_float2(const vec2& x) { return {x.x, x.y}; } +inline TCNN_HOST_DEVICE vec4 to_vec4(const float4& x) { return {x.x, x.y, x.z, x.w}; } +inline TCNN_HOST_DEVICE vec3 to_vec3(const float3& x) { return {x.x, x.y, x.z}; } +inline TCNN_HOST_DEVICE vec2 to_vec2(const float2& x) { return {x.x, x.y}; } +#endif + +template +struct tmat { + tmat() = default; + + TCNN_HOST_DEVICE tmat(T scalar) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + TCNN_PRAGMA_UNROLL + for (uint32_t j = 0; j < M; ++j) { + m[i][j] = i == j ? scalar : (T)0; + } + } + } + + TCNN_HOST_DEVICE static constexpr tmat identity() { return tmat((T)1); } + TCNN_HOST_DEVICE static constexpr tmat zero() { return tmat((T)0); } + + template > + TCNN_HOST_DEVICE tmat(Ts... 
coeffs) : d{coeffs...} {} + + TCNN_HOST_DEVICE tmat(const T* coeffs) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + TCNN_PRAGMA_UNROLL + for (uint32_t j = 0; j < M; ++j) { + m[i][j] = *(coeffs++); + } + } + } + + template + TCNN_HOST_DEVICE tmat(const tvec& a) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + m[i] = a; + } + } + + template + TCNN_HOST_DEVICE tmat(const tvec& a, const tvec& b) { + static_assert(N == 2, "Matrix must have 2 columns."); + m[0] = a; m[1] = b; + } + + template + TCNN_HOST_DEVICE tmat(const tvec& a, const tvec& b, const tvec& c) { + static_assert(N == 3, "Matrix must have 3 columns."); + m[0] = a; m[1] = b; m[2] = c; + } + + template + TCNN_HOST_DEVICE tmat(const tvec& a, const tvec& b, const tvec& c, const tvec& d) { + static_assert(N == 4, "Matrix must have 4 columns."); + m[0] = a; m[1] = b; m[2] = c; m[3] = d; + } + + template + TCNN_HOST_DEVICE tmat(const tmat& other) { + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + TCNN_PRAGMA_UNROLL + for (uint32_t j = 0; j < M; ++j) { + m[i][j] = i < P && j < O ? other[i][j] : (i == j ? (T)1 : (T)0); + } + } + } + + template + TCNN_HOST_DEVICE tvec operator*(const tvec& v) const { + tvec result((T)0); + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + TCNN_PRAGMA_UNROLL + for (uint32_t j = 0; j < M; ++j) { + result[j] += m[i][j] * v[i]; + } + } + return result; + } + + template + TCNN_HOST_DEVICE tmat operator*(const tmat& other) const { + tmat result; + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < K; ++i) { + result[i] = (*this) * other[i]; + } + return result; + } + + TCNN_HOST_DEVICE tvec& at(uint32_t idx) { return m[idx]; } + TCNN_HOST_DEVICE tvec at(uint32_t idx) const { return m[idx]; } + + TCNN_HOST_DEVICE tvec& operator[](uint32_t idx) { return m[idx]; } + TCNN_HOST_DEVICE tvec operator[](uint32_t idx) const { return m[idx]; } + + TCNN_HOST_DEVICE T* data() { return d; } + TCNN_HOST_DEVICE const T* data() const { return d; } + + union { + tvec m[N]; + T d[M*N]; + }; +}; + +template +TCNN_HOST_DEVICE tmat& operator*=(tmat& m, const tmat& other) { + m = m * other; + return m; +} + +template +TCNN_HOST_DEVICE T frobenius_norm(const tmat& m) { + T result = (T)0; + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + result += length2(m[i]); + } + return sqrt(result); +} + +template +TCNN_HOST_DEVICE tmat transpose(const tmat& m) { + tmat result; + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + TCNN_PRAGMA_UNROLL + for (uint32_t j = 0; j < M; ++j) { + result[j][i] = m[i][j]; + } + } + return result; +} + +template +TCNN_HOST_DEVICE tvec row(const tmat& m, int r) { + tvec result; + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + result[i] = m[i][r]; + } + return result; +} + +template +TCNN_HOST_DEVICE tmat row(const tmat& m, int r, const tvec& v) { + tmat result = m; + TCNN_PRAGMA_UNROLL + for (uint32_t i = 0; i < N; ++i) { + result[i][r] = v[i]; + } + return result; +} + +#define TMAT tmat + +#define CWISE_OP(operation, type_a, type_b, expr) \ +template \ +TCNN_HOST_DEVICE TMAT operation(type_a a, type_b b) { \ + TMAT result; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t j = 0; j < M; ++j) { \ + result[i][j] = expr; \ + } \ + } \ + return result; \ +} + +CWISE_OP(operator+, const TMAT&, const TMAT&, a[i][j] + b[i][j]) +CWISE_OP(operator-, const TMAT&, const TMAT&, a[i][j] - b[i][j]) + +CWISE_OP(operator*, T, const TMAT&, a * b[i][j]) +CWISE_OP(operator*, const TMAT&, T, a[i][j] * b) 
+CWISE_OP(operator/, const TMAT&, T, a[i][j] / b) + +#undef CWISE_OP + +#define INPLACE_OP(operation, type_b, expr) \ +template \ +TCNN_HOST_DEVICE TMAT& operation(TMAT& a, type_b b) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t j = 0; j < M; ++j) { \ + expr; \ + } \ + } \ + return a; \ +} + +INPLACE_OP(operator+=, const TMAT&, a[i][j] += b[i][j]) +INPLACE_OP(operator-=, const TMAT&, a[i][j] -= b[i][j]) + +INPLACE_OP(operator*=, T, a[i][j] *= b) +INPLACE_OP(operator/=, T, a[i][j] /= b) + +#undef INPLACE_OP + +#define REDUCTION_OP(operation, type_result, init, expr, ...) \ +template \ +TCNN_HOST_DEVICE type_result operation(__VA_ARGS__) { \ + type_result result = init; \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t i = 0; i < N; ++i) { \ + TCNN_PRAGMA_UNROLL \ + for (uint32_t j = 0; j < M; ++j) { \ + expr; \ + } \ + } \ + return result; \ +} + +REDUCTION_OP(operator==, bool, true, result &= a[i][j] == b[i][j], const TMAT& a, const TMAT& b) +REDUCTION_OP(operator!=, bool, false, result |= a[i][j] != b[i][j], const TMAT& a, const TMAT& b) +REDUCTION_OP(isfinite, bool, true, result &= isfinite(a[i][j]), const TMAT& a) + +#undef REDUCTION_OP + +// The following implementations of determinants, adjoints, inverses, and quaternions +// (and only those) were adapted from glm per the MIT license, which is included below in full. +// ================================================================================ +// The MIT License +// -------------------------------------------------------------------------------- +// Copyright (c) 2005 - G-Truc Creation + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
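+
+// Worked example (illustrative, not part of the original header): for the 2x2 case the
+// routines below reduce to the textbook closed form. With columns m[0] = {a, c} and
+// m[1] = {b, d}, i.e. the matrix
+//
+//   M = [[a, b],
+//        [c, d]]
+//
+// determinant(M) = a*d - b*c, adjoint(M) = [[d, -b], [-c, a]], and
+// inverse(M) = adjoint(M) / determinant(M), which is well-defined only when determinant(M) != 0.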
+ +template +TCNN_HOST_DEVICE T determinant(const tmat& m) { + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; +} + +template +TCNN_HOST_DEVICE T determinant(const tmat& m) { + return + m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + + -m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]) + ; +} + +template +TCNN_HOST_DEVICE T determinant(const tmat& m) { + T s0 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + T s1 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + T s2 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + T s3 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + T s4 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + T s5 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + + tvec coeff{ + (m[1][1] * s0 - m[1][2] * s1 + m[1][3] * s2), + -(m[1][0] * s0 - m[1][2] * s3 + m[1][3] * s4), + (m[1][0] * s1 - m[1][1] * s3 + m[1][3] * s5), + -(m[1][0] * s2 - m[1][1] * s4 + m[1][2] * s5), + }; + + return + m[0][0] * coeff[0] + m[0][1] * coeff[1] + + m[0][2] * coeff[2] + m[0][3] * coeff[3] + ; +} + +template +TCNN_HOST_DEVICE tmat adjoint(const tmat& m) { + return { + m[1][1], -m[0][1], + -m[1][0], m[0][0], + }; +} + +template +TCNN_HOST_DEVICE tmat adjoint(const tmat& m) { + const T m00 = determinant(tmat{m[1][1], m[2][1], m[1][2], m[2][2]}); + const T m01 = determinant(tmat{m[0][1], m[2][1], m[0][2], m[2][2]}); + const T m02 = determinant(tmat{m[0][1], m[1][1], m[0][2], m[1][2]}); + + const T m10 = determinant(tmat{m[1][0], m[2][0], m[1][2], m[2][2]}); + const T m11 = determinant(tmat{m[0][0], m[2][0], m[0][2], m[2][2]}); + const T m12 = determinant(tmat{m[0][0], m[1][0], m[0][2], m[1][2]}); + + const T m20 = determinant(tmat{m[1][0], m[2][0], m[1][1], m[2][1]}); + const T m21 = determinant(tmat{m[0][0], m[2][0], m[0][1], m[2][1]}); + const T m22 = determinant(tmat{m[0][0], m[1][0], m[0][1], m[1][1]}); + + return { + m00, -m01, m02, + -m10, m11, -m12, + m20, -m21, m22, + }; +} + +template +TCNN_HOST_DEVICE tmat adjoint(const tmat& m) { + const T m00 = determinant(tmat{m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], m[3][3]}); + const T m01 = determinant(tmat{m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], m[3][3]}); + const T m02 = determinant(tmat{m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], m[3][3]}); + const T m03 = determinant(tmat{m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], m[3][2]}); + + const T m10 = determinant(tmat{m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], m[3][3]}); + const T m11 = determinant(tmat{m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], m[3][3]}); + const T m12 = determinant(tmat{m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], m[3][3]}); + const T m13 = determinant(tmat{m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], m[3][2]}); + + const T m20 = determinant(tmat{m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], m[3][3]}); + const T m21 = determinant(tmat{m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], m[3][3]}); + const T m22 = determinant(tmat{m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], m[3][3]}); + const T m23 = determinant(tmat{m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], m[3][2]}); + + const T m30 = determinant(tmat{m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3]}); + const T m31 = determinant(tmat{m[0][0], m[0][2], m[0][3], 
m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3]}); + const T m32 = determinant(tmat{m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3]}); + const T m33 = determinant(tmat{m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2]}); + + return { + m00, -m10, m20, -m30, + -m01, m11, -m21, m31, + m02, -m12, m22, -m32, + -m03, m13, -m23, m33, + }; +} + +template +TCNN_HOST_DEVICE tmat inverse(const tmat& m) { + return adjoint(m) / determinant(m); +} + +template +TCNN_HOST_DEVICE tmat rotmat(T angle, const tvec& axis) { + T s, c; + sincos(angle, &s, &c); + T oc = (T)1 - c; + + return { + oc * axis.x * axis.x + c, oc * axis.x * axis.y + axis.z * s, oc * axis.z * axis.x - axis.y * s, + oc * axis.x * axis.y - axis.z * s, oc * axis.y * axis.y + c, oc * axis.y * axis.z + axis.x * s, + oc * axis.z * axis.x + axis.y * s, oc * axis.y * axis.z - axis.x * s, oc * axis.z * axis.z + c, + }; +} + +template +TCNN_HOST_DEVICE tmat rotmat(const tvec& v) { + T angle = length(v); + if (angle == (T)0) { + return tmat::identity(); + } + + return rotmat(angle, v / angle); +} + +template +TCNN_HOST_DEVICE tmat mat_sqrt(const tmat& m, T eps = (T)1e-10f) { + tmat X = m, Y = tmat::identity(); + for (uint32_t i = 0; i < 32; ++i) { + if (frobenius_norm(X * X - m) < eps) { + return X; + } + + tmat iX = inverse(X); + X = (T)0.5f * (X + inverse(Y)); + Y = (T)0.5f * (Y + iX); + } + + return X; +} + +template +TCNN_HOST_DEVICE tmat mat_log_hawkins(const tmat& m, T eps = (T)1e-10f) { + tmat A = m - tmat::identity(), Z = A, X = A; + for (uint32_t i = 2; i < 32; ++i) { + if (frobenius_norm(Z) < eps) { + return X; + } + + Z = Z * A; + X += ((T)1 / (T)i) * Z; + } + + return X; +} + +template +TCNN_HOST_DEVICE tmat mat_exp_pade(const tmat& m) { + // Pade approximation with scaling; same as Matlab. + // Pseudocode translated from Hawkins and Grimm [2007] + tmat mX = tmat::identity(), mD = tmat::identity(), mN = tmat::identity(); + T c = (T)1; + constexpr uint32_t q = 6; // Matlab's default when using this algorithm + + T s = -(T)1; + for (uint32_t k = 1; k <= q; ++k) { + c = c * (q - k + 1) / (k * (2 * q - k + 1)); + mX = m * mX; + auto cmX = c * mX; + mN = mN + cmX; + mD = mD + s * cmX; + s = -s; + } + + return inverse(mD) * mN; +} + +template +TCNN_HOST_DEVICE tmat mat_log(const tmat& m) { + tmat result(m); + + uint32_t j = 0; + for (; j < 32; ++j) { + if (frobenius_norm(result - tmat::identity()) < (T)1e-5f) { + break; + } + + result = mat_sqrt(result); + } + + result = mat_log_hawkins(result); + return (T)scalbnf(1.0f, j) * result; +} + +template +TCNN_HOST_DEVICE tmat mat_exp(const tmat& m) { + uint32_t N_SQUARING = max(0, 1 + (int)floor(log2(frobenius_norm(m)))); + + tmat result = (T)scalbnf(1.0f, -N_SQUARING) * m; + result = mat_exp_pade(result); + + for (uint32_t i = 0; i < N_SQUARING; ++i) { + result *= result; + } + + return result; +} + +template +TCNN_HOST_DEVICE tmat orthogonalize(const tmat& m) { + // Iteration to bring an almost orthogonal matrix nearer to its closest + // orthogonal matrix. This can be run multiple times until convergence + // is measured or, alternatively, once per frame on something like a + // camera matrix to ensure it does not degenerate over time. 
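+	// Single step of that iteration: X <- 1.5*X - 0.5*(X * X^T * X).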
+ return (T)1.5f * m - (T)0.5f * (m * transpose(m) * m); +} + +template +TCNN_HOST_DEVICE tmat orthogonalize(const tmat& m) { + auto rot = orthogonalize(tmat{m}); + return tmat{rot[0], rot[1], rot[2], m[3]}; +} + +template +TCNN_HOST_DEVICE tmat so3_log(const tmat& m) { + T tr = clamp(m[0][0] + m[1][1] + m[2][2], -(T)1 + std::numeric_limits::epsilon(), (T)1); + T radians = acosf((tr - (T)1) / (T)2); + return radians / sqrt(((T)1 + tr) * ((T)3 - tr)) * (m - transpose(m)); +} + +template +TCNN_HOST_DEVICE tmat so3_exp(const tmat& m) { + tvec axis = {-m[2][1], m[2][0], -m[1][0]}; + T radians_sq = length2(axis); + if (radians_sq == (T)0) { + return tmat::identity(); + } + + T radians = sqrt(radians_sq); + return tmat::identity() + (sin(radians) / radians) * m + (((T)1 - cos(radians)) / radians_sq) * (m * m); +} + +template +TCNN_HOST_DEVICE tmat se3_log(const tmat& m) { + auto omega = so3_log(tmat(m)); + tvec axis = {-omega[2][1], omega[2][0], -omega[1][0]}; + T radians_sq = length2(axis); + auto inv_trans = tmat::identity(); + if (radians_sq > (T)0) { + T radians = sqrt(radians_sq); + inv_trans += -(T)0.5 * omega + (((T)1 - (T)0.5 * radians * cos((T)0.5 * radians) / sin((T)0.5 * radians)) / radians_sq) * (omega * omega); + } + + return {omega[0], omega[1], omega[2], inv_trans * m[3]}; +} + +template +TCNN_HOST_DEVICE tmat se3_exp(const tmat& m) { + tmat omega = m; + tvec axis = {-omega[2][1], omega[2][0], -omega[1][0]}; + T radians_sq = length2(axis); + auto trans = tmat::identity(); + if (radians_sq > (T)0) { + T radians = sqrt(radians_sq); + trans += (((T)1 - cos(radians)) / radians_sq) * omega + ((radians - sin(radians)) / (radians * radians_sq)) * (omega * omega); + } + + auto rot = so3_exp(omega); + return {rot[0], rot[1], rot[2], trans * m[3]}; +} + +template +TCNN_HOST_DEVICE tmat se3_log(const tmat& m) { + auto result = tmat(se3_log(tmat(m))); + result[3][3] = (T)0; + return result; +} + +template +TCNN_HOST_DEVICE tmat se3_exp(const tmat& m) { + return tmat(se3_exp(tmat(m))); +} + +#define DEF_NON_TEMPLATED_MATRIX_TYPES(name, T) \ +template \ +using name = tmat; \ +using name##4x4 = name<4, 4>; \ +using name##4x3 = name<4, 3>; \ +using name##4x2 = name<4, 2>; \ +using name##3x4 = name<3, 4>; \ +using name##3x3 = name<3, 3>; \ +using name##3x2 = name<3, 2>; \ +using name##2x4 = name<2, 4>; \ +using name##2x3 = name<2, 3>; \ +using name##2x2 = name<2, 2>; \ +using name##4 = name##4x4; \ +using name##3 = name##3x3; \ +using name##2 = name##2x2; + +DEF_NON_TEMPLATED_MATRIX_TYPES(mat, float) +DEF_NON_TEMPLATED_MATRIX_TYPES(dmat, double) +#if defined(__CUDACC__) +DEF_NON_TEMPLATED_MATRIX_TYPES(hmat, __half) +#endif + +template +struct tquat { + tquat() = default; + TCNN_HOST_DEVICE tquat(T w, T x, T y, T z) : w{w}, x{x}, y{y}, z{z} {} + TCNN_HOST_DEVICE tquat(const tmat& m) { + // Code adapted from https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/ + T tr = m[0][0] + m[1][1] + m[2][2]; + + if (tr > (T)0) { + T S = sqrt(tr + (T)1) * (T)2; // S=4*qw + w = (T)0.25 * S; + x = (m[1][2] - m[2][1]) / S; + y = (m[2][0] - m[0][2]) / S; + z = (m[0][1] - m[1][0]) / S; + } else if (m[0][0] > m[1][1] && m[0][0] > m[2][2]) { + T S = sqrt((T)1 + m[0][0] - m[1][1] - m[2][2]) * (T)2; // S=4*x + w = (m[1][2] - m[2][1]) / S; + x = (T)0.25 * S; + y = (m[1][0] + m[0][1]) / S; + z = (m[2][0] + m[0][2]) / S; + } else if (m[1][1] > m[2][2]) { + T S = sqrt((T)1 + m[1][1] - m[0][0] - m[2][2]) * (T)2; // S=4*y + w = (m[2][0] - m[0][2]) / S; + x = (m[1][0] + m[0][1]) / S; + y = 
(T)0.25 * S; + z = (m[2][1] + m[1][2]) / S; + } else { + T S = sqrt((T)1 + m[2][2] - m[0][0] - m[1][1]) * (T)2; // S=4*z + w = (m[0][1] - m[1][0]) / S; + x = (m[2][0] + m[0][2]) / S; + y = (m[2][1] + m[1][2]) / S; + z = (T)0.25 * S; + } + } + + T w, x, y, z; +}; + +template TCNN_HOST_DEVICE tquat operator-(const tquat& a) { return {-a.w, -a.x, -a.y, -a.z}; } +template TCNN_HOST_DEVICE tquat operator+(const tquat& a, const tquat& b) { return {a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z}; } +template TCNN_HOST_DEVICE tquat operator-(const tquat& a, const tquat& b) { return {a.w - b.w, a.x - b.x, a.y - b.y, a.z - b.z}; } +template TCNN_HOST_DEVICE tquat operator*(T a, const tquat& b) { return {a * b.w, a * b.x, a * b.y, a * b.z}; } +template TCNN_HOST_DEVICE tquat operator*(const tquat& a, T b) { return {a.w * b, a.x * b, a.y * b, a.z * b}; } +template TCNN_HOST_DEVICE tquat operator/(const tquat& a, T b) { return {a.w / b, a.x / b, a.y / b, a.z / b}; } + +template TCNN_HOST_DEVICE T dot(const tquat& a, const tquat& b) { return (a.w * b.w + a.x * b.x) + (a.y * b.y + a.z * b.z); } +template TCNN_HOST_DEVICE T length2(const tquat& a) { return dot(a, a); } +template TCNN_HOST_DEVICE T length(const tquat& a) { return sqrt(length2(a)); } + +template TCNN_HOST_DEVICE tquat mix(const tquat& a, const tquat& b, T t) { return a * ((T)1 - t) + b * t; } + +template +TCNN_HOST_DEVICE tquat normalize(const tquat& a) { + T len = length(a); + if (len <= (T)0) { + return {(T)1, (T)0, (T)0, (T)0}; + } + return a / len; +} + +template +TCNN_HOST_DEVICE tquat cross(const tquat& a, const tquat& b) { + return { + a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, + a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, + a.w * b.y + a.y * b.w + a.z * b.x - a.x * b.z, + a.w * b.z + a.z * b.w + a.x * b.y - a.y * b.x + }; +} + +template +TCNN_HOST_DEVICE tquat slerp(const tquat& x, const tquat& y, T t) { + tquat z = y; + + T cos_theta = dot(x, y); + + // If cos_theta < 0, the interpolation will take the long way around the sphere. + // To fix this, one quat must be negated. 
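+	// (A unit quaternion q and its negation -q represent the same rotation, so negating
+	// one endpoint changes only which of the two great-circle arcs gets interpolated,
+	// not the rotations at the endpoints.)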
+ if (cos_theta < (T)0) { + z = -y; + cos_theta = -cos_theta; + } + + // Perform a linear interpolation when cos_theta is close to 1 to avoid side effect of sin(angle) becoming a zero denominator + if (cos_theta > (T)1 - std::numeric_limits::epsilon()) { + return mix(x, z, t); + } else { + // Essential Mathematics, page 467 + T angle = acos(cos_theta); + return (sin(((T)1 - t) * angle) * x + sin(t * angle) * z) / sin(angle); + } +} + +template +TCNN_HOST_DEVICE T angle(const tquat& x) { + return acos(clamp(x.w, (T)-1, (T)1)) * (T)2; +} + +template +TCNN_HOST_DEVICE tvec axis(const tquat& x) { + const T tmp1 = (T)1 - x.w * x.w; + if (tmp1 <= (T)0) { + return {(T)0, (T)0, (T)1}; + } + + const T tmp2 = (T)1 / sqrt(tmp1); + return {x.x * tmp2, x.y * tmp2, x.z * tmp2}; +} + +template +TCNN_HOST_DEVICE tmat to_mat3(const tquat& q) { + T qxx = q.x * q.x, qyy = q.y * q.y, qzz = q.z * q.z; + T qxz = q.x * q.z, qxy = q.x * q.y, qyz = q.y * q.z; + T qwx = q.w * q.x, qwy = q.w * q.y, qwz = q.w * q.z; + + return { + (T)1 - (T)2 * (qyy + qzz), (T)2 * (qxy + qwz), (T)2 * (qxz - qwy), + (T)2 * (qxy - qwz), (T)1 - (T)2 * (qxx + qzz), (T)2 * (qyz + qwx), + (T)2 * (qxz + qwy), (T)2 * (qyz - qwx), (T)1 - (T)2 * (qxx + qyy), + }; +} + +template +TCNN_HOST_DEVICE tmat slerp(const tmat& a, const tmat& b, float t) { + return to_mat3(normalize(slerp(normalize(tquat(a)), normalize(tquat(b)), t))); +} + +template +TCNN_HOST_DEVICE tvec rotvec(const tmat& mat) { + tquat tmp = mat; + return axis(tmp) * angle(tmp); +} + +using quat = tquat; + +} diff --git a/gui/include/tiny-cuda-nn/vec_json.h b/gui/include/tiny-cuda-nn/vec_json.h new file mode 100644 index 0000000000000000000000000000000000000000..e0eb641e1414723421bb671f952a5291929d2192 --- /dev/null +++ b/gui/include/tiny-cuda-nn/vec_json.h @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file vec_json.h + * @author Thomas Müller, NVIDIA + * @brief Conversion between tcnn's vector / matrix / quaternion types + * and nlohmann::json. 
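+ *
+ *         Matrices are serialized row by row (j[row][col] == mat[col][row]); vectors become
+ *         flat JSON arrays, and quaternions are stored as [x, y, z, w].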
+ */ + +#pragma once + +#include + +#include + +namespace tcnn { + +template +void to_json(nlohmann::json& j, const tmat& mat) { + for (int row = 0; row < M; ++row) { + nlohmann::json column = nlohmann::json::array(); + for (int col = 0; col < N; ++col) { + column.push_back(mat[col][row]); + } + j.push_back(column); + } +} + +template +void from_json(const nlohmann::json& j, tmat& mat) { + for (std::size_t row = 0; row < M; ++row) { + const auto& jrow = j.at(row); + for (std::size_t col = 0; col < N; ++col) { + const auto& value = jrow.at(col); + mat[col][row] = value.get(); + } + } +} + +template +void to_json(nlohmann::json& j, const tvec& v) { + for (uint32_t i = 0; i < N; ++i) { + j.push_back(v[i]); + } +} + +template +void from_json(const nlohmann::json& j, tvec& v) { + for (uint32_t i = 0; i < N; ++i) { + v[i] = j.at(i).get(); + } +} + +template +void to_json(nlohmann::json& j, const tquat& q) { + j.push_back(q.x); + j.push_back(q.y); + j.push_back(q.z); + j.push_back(q.w); +} + +template +void from_json(const nlohmann::json& j, tquat& q) { + q.x = j.at(0).get(); + q.y = j.at(1).get(); + q.z = j.at(2).get(); + q.w = j.at(3).get(); +} + +} diff --git a/gui/include/tiny-cuda-nn/vec_pybind11.h b/gui/include/tiny-cuda-nn/vec_pybind11.h new file mode 100644 index 0000000000000000000000000000000000000000..be10c0d31472acbc8e007ef0512b6498c0fd5900 --- /dev/null +++ b/gui/include/tiny-cuda-nn/vec_pybind11.h @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file vec_pybind11.h + * @author Thomas Müller, NVIDIA + * @brief pybind11 bindings for NGP's vector and matrix types. Adapted from + * Patrik Huber's glm binding code per the BSD license of pybind11. + */ + +#pragma once + +#include + +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +namespace pybind11 { +namespace detail { + +template +struct type_caster> { + using vector_type = tcnn::tvec; + using Scalar = T; + static constexpr std::size_t num_elements = N; + + bool load(handle src, bool) { + auto buf = array_t::ensure(src); + if (!buf) { + return false; + } + + if (buf.ndim() != 1) { + return false; // not a rank-1 tensor (i.e. 
vector) + } + + if (buf.shape(0) != num_elements) { + return false; // not a 2-elements vector + } + + for (size_t i = 0; i < num_elements; ++i) { + value[i] = *buf.data(i); + } + + return true; + } + + static handle cast(const vector_type& src, return_value_policy, handle) { + return array( + num_elements, + src.data() + ).release(); + } + + // Specifies the doc-string for the type in Python: + PYBIND11_TYPE_CASTER(vector_type, _("vec")); +}; + +template +struct type_caster> { + using matrix_type = tcnn::tmat; + using Scalar = T; + static constexpr std::size_t num_rows = M; + static constexpr std::size_t num_cols = N; + + bool load(handle src, bool) { + auto buf = array_t::ensure(src); + if (!buf) { + return false; + } + + if (buf.ndim() != 2) { + return false; // not a rank-2 tensor (i.e. matrix) + } + + if (buf.shape(0) != num_rows || buf.shape(1) != num_cols) { + return false; // not a 4x4 matrix + } + + for (size_t i = 0; i < num_cols; ++i) { + for (size_t j = 0; j < num_rows; ++j) { + value[i][j] = *buf.data(j, i); + } + } + + return true; + } + + static handle cast(const matrix_type& src, return_value_policy, handle) { + return array( + { num_rows, num_cols }, + { sizeof(Scalar), sizeof(Scalar) * num_rows }, // strides - flip the row/col layout! + src.data() + ).release(); + } + + // Specifies the doc-string for the type in Python: + PYBIND11_TYPE_CASTER(matrix_type, _("mat")); +}; + +} +} + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/gui/scripts/common.py b/gui/scripts/common.py new file mode 100644 index 0000000000000000000000000000000000000000..385f13c85e2aae91062b73c9650f9b78d84771e3 --- /dev/null +++ b/gui/scripts/common.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import code +import glob +import imageio +import numpy as np +import os +from pathlib import PurePosixPath +from scipy.ndimage.filters import convolve1d +import struct +import sys + +from constants import * + +# Search for pyngp in the build folder. +sys.path += [os.path.dirname(pyd) for pyd in glob.iglob(os.path.join(ROOT_DIR, "build*", "**/*.pyd"), recursive=True)] +sys.path += [os.path.dirname(pyd) for pyd in glob.iglob(os.path.join(ROOT_DIR, "build*", "**/*.so"), recursive=True)] + +def repl(testbed): + print("-------------------\npress Ctrl-Z to return to gui\n---------------------------") + code.InteractiveConsole(locals=locals()).interact() + print("------- returning to gui...") + +def mse2psnr(x): return -10.*np.log(x)/np.log(10.) 
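+
+# Note: mse2psnr assumes signals normalized to a peak value of 1, i.e. it computes
+# PSNR = 10*log10(1 / MSE) in dB; e.g. an MSE of 1e-3 maps to roughly 30 dB.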
+ +def sanitize_path(path): + return str(PurePosixPath(path.relative_to(PAPER_FOLDER))) + +# from https://stackoverflow.com/questions/31638651/how-can-i-draw-lines-into-numpy-arrays +def trapez(y,y0,w): + return np.clip(np.minimum(y+1+w/2-y0, -y+1+w/2+y0),0,1) + +def weighted_line(r0, c0, r1, c1, w, rmin=0, rmax=np.inf): + # The algorithm below works fine if c1 >= c0 and c1-c0 >= abs(r1-r0). + # If either of these cases are violated, do some switches. + if abs(c1-c0) < abs(r1-r0): + # Switch x and y, and switch again when returning. + xx, yy, val = weighted_line(c0, r0, c1, r1, w, rmin=rmin, rmax=rmax) + return (yy, xx, val) + + # At this point we know that the distance in columns (x) is greater + # than that in rows (y). Possibly one more switch if c0 > c1. + if c0 > c1: + return weighted_line(r1, c1, r0, c0, w, rmin=rmin, rmax=rmax) + + # The following is now always < 1 in abs + slope = (r1-r0) / (c1-c0) + + # Adjust weight by the slope + w *= np.sqrt(1+np.abs(slope)) / 2 + + # We write y as a function of x, because the slope is always <= 1 + # (in absolute value) + x = np.arange(c0, c1+1, dtype=float) + y = x * slope + (c1*r0-c0*r1) / (c1-c0) + + # Now instead of 2 values for y, we have 2*np.ceil(w/2). + # All values are 1 except the upmost and bottommost. + thickness = np.ceil(w/2) + yy = (np.floor(y).reshape(-1,1) + np.arange(-thickness-1,thickness+2).reshape(1,-1)) + xx = np.repeat(x, yy.shape[1]) + vals = trapez(yy, y.reshape(-1,1), w).flatten() + + yy = yy.flatten() + + # Exclude useless parts and those outside of the interval + # to avoid parts outside of the picture + mask = np.logical_and.reduce((yy >= rmin, yy < rmax, vals > 0)) + + return (yy[mask].astype(int), xx[mask].astype(int), vals[mask]) + +def diagonally_truncated_mask(shape, x_threshold, angle): + result = np.zeros(shape, dtype=bool) + for x in range(shape[1]): + for y in range(shape[0]): + thres = x_threshold * shape[1] - (angle * shape[0] / 2) + y * angle + result[y, x, ...] 
= x < thres + return result + +def diagonally_combine_two_images(img1, img2, x_threshold, angle, gap=0, color=1): + if img2.shape != img1.shape: + raise ValueError(f"img1 and img2 must have the same shape; {img1.shape} vs {img2.shape}") + mask = diagonally_truncated_mask(img1.shape, x_threshold, angle) + result = img2.copy() + result[mask] = img1[mask] + if gap > 0: + rr, cc, val = weighted_line(0, int(x_threshold * img1.shape[1] - (angle * img1.shape[0] / 2)), img1.shape[0]-1, int(x_threshold * img1.shape[1] + (angle * img1.shape[0] / 2)), gap) + result[rr, cc, :] = result[rr, cc, :] * (1 - val[...,np.newaxis]) + val[...,np.newaxis] * color + return result + +def diagonally_combine_images(images, x_thresholds, angle, gap=0, color=1): + result = images[0] + for img, thres in zip(images[1:], x_thresholds): + result = diagonally_combine_two_images(result, img, thres, angle, gap, color) + return result + +def write_image_imageio(img_file, img, quality): + img = (np.clip(img, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) + kwargs = {} + if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]: + if img.ndim >= 3 and img.shape[2] > 3: + img = img[:,:,:3] + kwargs["quality"] = quality + kwargs["subsampling"] = 0 + imageio.imwrite(img_file, img, **kwargs) + +def read_image_imageio(img_file): + img = imageio.imread(img_file) + img = np.asarray(img).astype(np.float32) + if len(img.shape) == 2: + img = img[:,:,np.newaxis] + return img / 255.0 + +def srgb_to_linear(img): + limit = 0.04045 + return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) + +def linear_to_srgb(img): + limit = 0.0031308 + return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) + +def read_image(file): + if os.path.splitext(file)[1] == ".bin": + with open(file, "rb") as f: + bytes = f.read() + h, w = struct.unpack("ii", bytes[:8]) + img = np.frombuffer(bytes, dtype=np.float16, count=h*w*4, offset=8).astype(np.float32).reshape([h, w, 4]) + else: + img = read_image_imageio(file) + if img.shape[2] == 4: + img[...,0:3] = srgb_to_linear(img[...,0:3]) + # Premultiply alpha + img[...,0:3] *= img[...,3:4] + else: + img = srgb_to_linear(img) + return img + +def write_image(file, img, quality=95): + if os.path.splitext(file)[1] == ".bin": + if img.shape[2] < 4: + img = np.dstack((img, np.ones([img.shape[0], img.shape[1], 4 - img.shape[2]]))) + with open(file, "wb") as f: + f.write(struct.pack("ii", img.shape[0], img.shape[1])) + f.write(img.astype(np.float16).tobytes()) + else: + if img.shape[2] == 4: + img = np.copy(img) + # Unmultiply alpha + img[...,0:3] = np.divide(img[...,0:3], img[...,3:4], out=np.zeros_like(img[...,0:3]), where=img[...,3:4] != 0) + img[...,0:3] = linear_to_srgb(img[...,0:3]) + else: + img = linear_to_srgb(img) + write_image_imageio(file, img, quality) + +def trim(error, skip=0.000001): + error = np.sort(error.flatten()) + size = error.size + skip = int(skip * size) + return error[skip:size-skip].mean() + +def luminance(a): + return 0.2126 * a[:,:,0] + 0.7152 * a[:,:,1] + 0.0722 * a[:,:,2] + +def SSIM(a, b): + def blur(a): + k = np.array([0.120078, 0.233881, 0.292082, 0.233881, 0.120078]) + x = convolve1d(a, k, axis=0) + return convolve1d(x, k, axis=1) + a = luminance(a) + b = luminance(b) + mA = blur(a) + mB = blur(b) + sA = blur(a*a) - mA**2 + sB = blur(b*b) - mB**2 + sAB = blur(a*b) - mA*mB + c1 = 0.01**2 + c2 = 0.03**2 + p1 = (2.0*mA*mB + c1)/(mA*mA + mB*mB + c1) + p2 = (2.0*sAB + c2)/(sA + sB + c2) + error = p1 * p2 + return error + +def L1(img, ref): + return 
np.abs(img - ref) + +def APE(img, ref): + return L1(img, ref) / (1e-2 + ref) + +def SAPE(img, ref): + return L1(img, ref) / (1e-2 + (ref + img) / 2.) + +def L2(img, ref): + return (img - ref)**2 + +def RSE(img, ref): + return L2(img, ref) / (1e-2 + ref**2) + +def rgb_mean(img): + return np.mean(img, axis=2) + +def compute_error_img(metric, img, ref): + img[np.logical_not(np.isfinite(img))] = 0 + img = np.maximum(img, 0.) + + if metric == "MAE": + return L1(img, ref) + elif metric == "MAPE": + return APE(img, ref) + elif metric == "SMAPE": + return SAPE(img, ref) + elif metric == "MSE": + return L2(img, ref) + elif metric == "MScE": + return L2(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0)) + elif metric == "MRSE": + return RSE(img, ref) + elif metric == "MtRSE": + return trim(RSE(img, ref)) + elif metric == "MRScE": + return RSE(np.clip(img, 0, 100), np.clip(ref, 0, 100)) + elif metric == "SSIM": + return SSIM(np.clip(img, 0.0, 1.0), np.clip(ref, 0.0, 1.0)) + + raise ValueError(f"Unknown metric: {metric}.") + +def compute_error(metric, img, ref): + metric_map = compute_error_img(metric, img, ref) + metric_map[np.logical_not(np.isfinite(metric_map))] = 0 + if len(metric_map.shape) == 3: + metric_map = np.mean(metric_map, axis=2) + mean = np.mean(metric_map) + return mean diff --git a/gui/scripts/constants.py b/gui/scripts/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..96966938b9163866c19060233690580fcdddc93c --- /dev/null +++ b/gui/scripts/constants.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path + + +PAPER_FOLDER = Path(__file__).resolve().parent.parent +SUPPL_FOLDER = PAPER_FOLDER/"supplemental" +SCRIPTS_FOLDER = PAPER_FOLDER/"scripts" +TEMPLATE_FOLDER = SCRIPTS_FOLDER/"template" +DATA_FOLDER = SCRIPTS_FOLDER/"data" + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +RESULTS_DIR = os.path.join(ROOT_DIR, "results") diff --git a/gui/scripts/download_ffmpeg.bat b/gui/scripts/download_ffmpeg.bat new file mode 100644 index 0000000000000000000000000000000000000000..4a92c39fce23fff59cc2b900bc4a580c88a1d17b --- /dev/null +++ b/gui/scripts/download_ffmpeg.bat @@ -0,0 +1,22 @@ +:: Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +:: +:: NVIDIA CORPORATION and its licensors retain all intellectual property +:: and proprietary rights in and to this software, related documentation +:: and any modifications thereto. Any use, reproduction, disclosure or +:: distribution of this software and related documentation without an express +:: license agreement from NVIDIA CORPORATION is strictly prohibited. + +@echo off + +set cwd=%cd% +cd /D %~dp0 + +echo Downloading FFmpeg... 
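+:: Fetches a pinned FFmpeg 5.1.2 "essentials" build from the GyanD/codexffmpeg mirror
+:: and unpacks it into ..\external\ffmpeg via the Expand-Archive step below.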
+powershell -Command "(New-Object Net.WebClient).DownloadFile('https://github.com/GyanD/codexffmpeg/releases/download/5.1.2/ffmpeg-5.1.2-essentials_build.zip', 'ffmpeg.zip')" + +echo Unzipping... +powershell Expand-Archive ffmpeg.zip -DestinationPath ..\external\ffmpeg -Force + +echo Cleaning up... +if exist ffmpeg.zip del /f /q ffmpeg.zip +exit /b diff --git a/gui/src/camera_path.cu b/gui/src/camera_path.cu new file mode 100644 index 0000000000000000000000000000000000000000..b1fb9c24a9bc9e8a44d82dbda73f49b1ff6dd112 --- /dev/null +++ b/gui/src/camera_path.cu @@ -0,0 +1,693 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file camera_path.cpp + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#include +#include +#include + +#ifdef NGP_GUI +# include +# include +#endif + +#include +#include + +using namespace nlohmann; + +namespace ngp { + +CameraKeyframe lerp(const CameraKeyframe& p0, const CameraKeyframe& p1, float t, float t0, float t1) { + t = (t - t0) / (t1 - t0); + quat R1 = p1.R; + + // take the short path + if (dot(R1, p0.R) < 0.0f) { + R1 = -R1; + } + + return { + normalize(slerp(p0.R, R1, t)), + p0.T + (p1.T - p0.T) * t, + p0.fov + (p1.fov - p0.fov) * t, + p0.timestamp + (p1.timestamp - p0.timestamp) * t, + }; +} + +CameraKeyframe normalize(const CameraKeyframe& p0) { + CameraKeyframe result = p0; + result.R = normalize(result.R); + return result; +} + +CameraKeyframe spline_cm(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2, const CameraKeyframe& p3) { + CameraKeyframe q0 = lerp(p0, p1, t, -1.f, 0.f); + CameraKeyframe q1 = lerp(p1, p2, t, 0.f, 1.f); + CameraKeyframe q2 = lerp(p2, p3, t, 1.f, 2.f); + CameraKeyframe r0 = lerp(q0, q1, t, -1.f, 1.f); + CameraKeyframe r1 = lerp(q1, q2, t, 0.f, 2.f); + return lerp(r0, r1, t, 0.f, 1.f); +} + +CameraKeyframe spline_cubic(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2, const CameraKeyframe& p3) { + float tt = t * t; + float ttt = t * t * t; + float a = (1 - t) * (1 - t) * (1 - t) * (1.f / 6.f); + float b = (3.f * ttt - 6.f * tt + 4.f) * (1.f / 6.f); + float c = (-3.f * ttt + 3.f * tt + 3.f * t + 1.f) * (1.f / 6.f); + float d = ttt * (1.f / 6.f); + return normalize(p0 * a + p1 * b + p2 * c + p3 * d); +} + +CameraKeyframe spline_quadratic(float t, const CameraKeyframe& p0, const CameraKeyframe& p1, const CameraKeyframe& p2) { + float tt = t * t; + float a = (1 - t) * (1 - t) * 0.5f; + float b = (-2.f * tt + 2.f * t + 1.f) * 0.5f; + float c = tt * 0.5f; + return normalize(p0 * a + p1 * b + p2 * c); +} + +CameraKeyframe spline_linear(float t, const CameraKeyframe& p0, const CameraKeyframe& p1) { return normalize(p0 * (1.0f - t) + p1 * t); } + +void to_json(json& j, const CameraKeyframe& p) { + j = json{ + {"R", p.R }, + {"T", p.T }, + {"fov", p.fov }, + {"timestamp", p.timestamp}, + }; +} + +bool 
load_relative_to_first = false; // set to true when using a camera path that is aligned with the first training image, such that it is + // invariant to changes in the space of the training data + +void from_json(bool is_first, const json& j, CameraKeyframe& p, const CameraKeyframe& first, const mat4x3& ref) { + if (is_first && load_relative_to_first) { + p.from_m(ref); + } else { + p.R = j.at("R"); + p.T = j.at("T"); + + if (load_relative_to_first) { + mat4 ref4 = {ref}; + mat4 first4 = {first.m()}; + mat4 p4 = {p.m()}; + p.from_m(mat4x3(ref4 * inverse(first4) * p4)); + } + } + j.at("fov").get_to(p.fov); + if (j.contains("timestamp")) { + j.at("timestamp").get_to(p.timestamp); + } else { + p.timestamp = 0.f; + } +} + +void CameraPath::save(const fs::path& path) { + json j = { + {"loop", loop }, + {"time", play_time }, + {"path", keyframes }, + {"duration_seconds", duration_seconds()}, + {"spline_order", spline_order }, + }; + std::ofstream f(native_string(path)); + f << j; +} + +void CameraPath::load(const fs::path& path, const mat4x3& first_xform) { + std::ifstream f{native_string(path)}; + if (!f) { + throw std::runtime_error{fmt::format("Camera path {} does not exist.", path.str())}; + } + + json j; + f >> j; + + CameraKeyframe first; + + keyframes.clear(); + if (j.contains("loop")) { + loop = j["loop"]; + } + if (j.contains("time")) { + play_time = j["time"]; + } + if (j.contains("path")) { + for (auto& el : j["path"]) { + CameraKeyframe p; + bool is_first = keyframes.empty(); + from_json(is_first, el, p, first, first_xform); + if (is_first) { + first = p; + } + keyframes.push_back(p); + } + } + + spline_order = j.value("spline_order", 3); + sanitize_keyframes(); + + play_time = 0.0f; + + if (keyframes.size() >= 16) { + keyframe_subsampling = keyframes.size() - 1; + editing_kernel_type = EEditingKernel::Gaussian; + } +} + +void CameraPath::add_camera(const mat4x3& camera, float fov, float timestamp) { + int n = std::max(0, int(keyframes.size()) - 1); + int i = (int)ceil(play_time * (float)n + 0.001f); + if (i > keyframes.size()) { + i = keyframes.size(); + } + if (i < 0) { + i = 0; + } + keyframes.insert(keyframes.begin() + i, CameraKeyframe(camera, fov, timestamp)); + update_cam_from_path = false; + play_time = get_playtime(i); + + sanitize_keyframes(); +} + +float editing_kernel(float x, EEditingKernel kernel) { + x = kernel == EEditingKernel::Gaussian ? x : clamp(x, -1.0f, 1.0f); + switch (kernel) { + case EEditingKernel::Gaussian: return expf(-2.0f * x * x); + case EEditingKernel::Quartic: return (1.0f - x * x) * (1.0f - x * x); + case EEditingKernel::Hat: return 1.0f - fabsf(x); + case EEditingKernel::Box: return x > -1.0f && x < 1.0f ? 1.0f : 0.0f; + case EEditingKernel::None: return fabs(x) < 0.0001f ? 
1.0f : 0.0f; + default: throw std::runtime_error{"Unknown editing kernel"}; + } +} + +#ifdef NGP_GUI +int CameraPath::imgui(char path_filename_buf[1024], float frame_milliseconds, const mat4x3& camera, float fov, const mat4x3& first_xform) { + int n = std::max(0, int(keyframes.size()) - 1); + int read = 0; // 1=smooth, 2=hard + + ImGui::InputText("##PathFile", path_filename_buf, 1024); + ImGui::SameLine(); + static std::string camera_path_load_error_string = ""; + + if (rendering) { + ImGui::BeginDisabled(); + } + + if (ImGui::Button("Load")) { + try { + load(path_filename_buf, first_xform); + } catch (const std::exception& e) { + ImGui::OpenPopup("Camera path load error"); + camera_path_load_error_string = std::string{"Failed to load camera path: "} + e.what(); + } + } + + if (rendering) { + ImGui::EndDisabled(); + } + + if (ImGui::BeginPopupModal("Camera path load error", NULL, ImGuiWindowFlags_AlwaysAutoResize)) { + ImGui::Text("%s", camera_path_load_error_string.c_str()); + if (ImGui::Button("OK", ImVec2(120, 0))) { + ImGui::CloseCurrentPopup(); + } + ImGui::EndPopup(); + } + + if (!keyframes.empty()) { + ImGui::SameLine(); + if (ImGui::Button("Save")) { + save(path_filename_buf); + } + } + + if (rendering) { + ImGui::BeginDisabled(); + } + + if (ImGui::Button("Add from cam")) { + const float duration = duration_seconds(); + add_camera(camera, fov, 0.0f); + make_keyframe_timestamps_equidistant(duration); + read = 2; + } + + auto p = get_pos(play_time); + + if (!keyframes.empty()) { + ImGui::SameLine(); + if (ImGui::Button("Split")) { + update_cam_from_path = false; + int i = clamp(p.kfidx + 1, 0, (int)keyframes.size()); + const float duration = duration_seconds(); + keyframes.insert(keyframes.begin() + i, eval_camera_path(play_time)); + make_keyframe_timestamps_equidistant(duration); + play_time = get_playtime(i); + read = 2; + } + ImGui::SameLine(); + + int i = p.kfidx; + if (!loop) { + i += (int)round(p.t); + } + + if (ImGui::Button("|<")) { + play_time = 0.f; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button("<")) { + play_time = n ? (get_playtime(i - 1) + 0.0001f) : 0.f; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button(update_cam_from_path ? "Stop" : "Read")) { + update_cam_from_path = !update_cam_from_path; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button(">")) { + play_time = n ? 
(get_playtime(i + 1) + 0.0001f) : 1.0f; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button(">|")) { + play_time = 1.0f; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button("Dup")) { + update_cam_from_path = false; + const float duration = duration_seconds(); + keyframes.insert(keyframes.begin() + i, keyframes[i]); + make_keyframe_timestamps_equidistant(duration); + play_time = get_playtime(i); + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button("Del")) { + update_cam_from_path = false; + const float duration = duration_seconds(); + keyframes.erase(keyframes.begin() + i); + make_keyframe_timestamps_equidistant(duration); + play_time = get_playtime(i - 1); + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button("Set")) { + keyframes[i] = CameraKeyframe(camera, fov, keyframes[i].timestamp); + read = 2; + if (n) { + play_time = get_playtime(i); + } + } + + if (ImGui::RadioButton("Translate", m_gizmo_op == ImGuizmo::TRANSLATE)) { + m_gizmo_op = ImGuizmo::TRANSLATE; + } + ImGui::SameLine(); + if (ImGui::RadioButton("Rotate", m_gizmo_op == ImGuizmo::ROTATE)) { + m_gizmo_op = ImGuizmo::ROTATE; + } + ImGui::SameLine(); + if (ImGui::RadioButton("Local", m_gizmo_mode == ImGuizmo::LOCAL)) { + m_gizmo_mode = ImGuizmo::LOCAL; + } + ImGui::SameLine(); + if (ImGui::RadioButton("World", m_gizmo_mode == ImGuizmo::WORLD)) { + m_gizmo_mode = ImGuizmo::WORLD; + } + ImGui::SameLine(); + ImGui::Checkbox("Loop path", &loop); + + if (ImGui::Button("Start") && !keyframes.empty()) { + auto_play_speed = 0.0f; + play_time = 0.0f; + read = 2; + } + ImGui::SameLine(); + if (ImGui::Button("Rev") && !keyframes.empty()) { + auto_play_speed = -1.0f / duration_seconds(); + } + ImGui::SameLine(); + if (ImGui::Button(auto_play_speed != 0 ? "Pause" : "Play") && !keyframes.empty()) { + auto_play_speed = auto_play_speed == 0.0f ? (1.0f / duration_seconds()) : 0.0f; + } + ImGui::SameLine(); + if (ImGui::Button("End") && !keyframes.empty()) { + auto_play_speed = 0.0f; + play_time = 1.0f; + read = 2; + } + + ImGui::SliderFloat("Playback speed", &auto_play_speed, -1.0f, 1.0f); + if (auto_play_speed != 0.0f) { + float prev = play_time; + play_time = clamp(play_time + auto_play_speed * (frame_milliseconds / 1000.f), 0.0f, 1.0f); + + if (play_time != prev) { + read = 1; + } + } + + if (ImGui::SliderFloat("Camera path time", &play_time, 0.0f, 1.0f)) { + read = 1; + } + ImGui::Text("Current keyframe %d/%d:", i, n + 1); + + if (ImGui::SliderFloat("Field of view", &keyframes[i].fov, 0.0f, 120.0f)) { + read = 2; + } + if (ImGui::Button("Apply to all keyframes")) { + for (auto& k : keyframes) { + k.fov = keyframes[i].fov; + } + } + + + if (ImGui::TreeNodeEx("Batch keyframe editing")) { + ImGui::Combo("Editing kernel", (int*)&editing_kernel_type, EditingKernelStr); + ImGui::SliderFloat( + "Editing kernel radius", &editing_kernel_radius, 0.001f, 10.0f, "%.4f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat + ); + + ImGui::TreePop(); + } + + if (ImGui::TreeNodeEx("Advanced camera path settings")) { + ImGui::SliderInt("Spline order", &spline_order, 0, 3); + ImGui::SliderInt("Keyframe subsampling", &keyframe_subsampling, 1, max((int)keyframes.size() - 1, 1)); + ImGui::TreePop(); + } + } + + if (rendering) { + ImGui::EndDisabled(); + } + + return keyframes.empty() ? 
0 : read; +} + +bool debug_project(const mat4& proj, vec3 p, ImVec2& o) { + vec4 ph{p.x, p.y, p.z, 1.0f}; + vec4 pa = proj * ph; + if (pa.w <= 0.f) { + return false; + } + + o.x = pa.x / pa.w; + o.y = pa.y / pa.w; + return true; +} + +void add_debug_line(ImDrawList* list, const mat4& proj, vec3 a, vec3 b, uint32_t col, float thickness) { + ImVec2 aa, bb; + if (debug_project(proj, a, aa) && debug_project(proj, b, bb)) { + list->AddLine(aa, bb, col, thickness * 2.0f); + } +} + +void visualize_cube(ImDrawList* list, const mat4& world2proj, const vec3& a, const vec3& b, const mat3& render_aabb_to_local) { + mat3 m = transpose(render_aabb_to_local); + add_debug_line(list, world2proj, m * vec3{a.x, a.y, a.z}, m * vec3{a.x, a.y, b.z}, 0xffff4040); // Z + add_debug_line(list, world2proj, m * vec3{b.x, a.y, a.z}, m * vec3{b.x, a.y, b.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{a.x, b.y, a.z}, m * vec3{a.x, b.y, b.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{b.x, b.y, a.z}, m * vec3{b.x, b.y, b.z}, 0xffffffff); + + add_debug_line(list, world2proj, m * vec3{a.x, a.y, a.z}, m * vec3{b.x, a.y, a.z}, 0xff4040ff); // X + add_debug_line(list, world2proj, m * vec3{a.x, b.y, a.z}, m * vec3{b.x, b.y, a.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{a.x, a.y, b.z}, m * vec3{b.x, a.y, b.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{a.x, b.y, b.z}, m * vec3{b.x, b.y, b.z}, 0xffffffff); + + add_debug_line(list, world2proj, m * vec3{a.x, a.y, a.z}, m * vec3{a.x, b.y, a.z}, 0xff40ff40); // Y + add_debug_line(list, world2proj, m * vec3{b.x, a.y, a.z}, m * vec3{b.x, b.y, a.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{a.x, a.y, b.z}, m * vec3{a.x, b.y, b.z}, 0xffffffff); + add_debug_line(list, world2proj, m * vec3{b.x, a.y, b.z}, m * vec3{b.x, b.y, b.z}, 0xffffffff); +} + +void visualize_camera(ImDrawList* list, const mat4& world2proj, const mat4x3& xform, float aspect, uint32_t col, float thickness) { + const float axis_size = 0.025f; + const vec3* xforms = (const vec3*)&xform; + vec3 pos = xforms[3]; + add_debug_line(list, world2proj, pos, pos + axis_size * xforms[0], 0xff4040ff, thickness); + add_debug_line(list, world2proj, pos, pos + axis_size * xforms[1], 0xff40ff40, thickness); + add_debug_line(list, world2proj, pos, pos + axis_size * xforms[2], 0xffff4040, thickness); + float xs = axis_size * aspect; + float ys = axis_size; + float zs = axis_size * 2.0f * aspect; + vec3 a = pos + xs * xforms[0] + ys * xforms[1] + zs * xforms[2]; + vec3 b = pos - xs * xforms[0] + ys * xforms[1] + zs * xforms[2]; + vec3 c = pos - xs * xforms[0] - ys * xforms[1] + zs * xforms[2]; + vec3 d = pos + xs * xforms[0] - ys * xforms[1] + zs * xforms[2]; + add_debug_line(list, world2proj, pos, a, col, thickness); + add_debug_line(list, world2proj, pos, b, col, thickness); + add_debug_line(list, world2proj, pos, c, col, thickness); + add_debug_line(list, world2proj, pos, d, col, thickness); + add_debug_line(list, world2proj, a, b, col, thickness); + add_debug_line(list, world2proj, b, c, col, thickness); + add_debug_line(list, world2proj, c, d, col, thickness); + add_debug_line(list, world2proj, d, a, col, thickness); +} + +bool CameraPath::has_valid_timestamps() const { + float prev_timestamp = 0.0f; + for (size_t i = 0; i < keyframes.size(); ++i) { + if (!(keyframes[i].timestamp > prev_timestamp)) { + return false; + } + + prev_timestamp = keyframes[i].timestamp; + } + + return true; +} + +void CameraPath::make_keyframe_timestamps_equidistant(const float 
duration_seconds) { + const float sanitized_duration = duration_seconds > 0.0f ? duration_seconds : default_duration_seconds; + for (size_t i = 0; i < keyframes.size(); ++i) { + keyframes[i].timestamp = sanitized_duration * (i + 1) / (float)keyframes.size(); + } +} + +void CameraPath::sanitize_keyframes() { + if (has_valid_timestamps()) { + return; + } + + // Timestamps are invalid. Best effort is to equally space all frames. Default to 3 seconds duration. + make_keyframe_timestamps_equidistant(default_duration_seconds); +} + +float CameraPath::duration_seconds() const { + if (keyframes.empty()) { + return 0.0f; + } + + return keyframes.back().timestamp; +} + +void CameraPath::set_duration_seconds(const float duration) { + const float old_duration = duration_seconds(); + if (!(old_duration > 0.0f)) { + make_keyframe_timestamps_equidistant(duration); + return; + } + + const float multiplier = duration / old_duration; + for (auto& kf : keyframes) { + kf.timestamp *= multiplier; + } +} + +CameraPath::Pos CameraPath::get_pos(float playtime) { + if (keyframes.empty()) { + return {-1, 0.0f}; + } else if (keyframes.size() == 1) { + return {0, playtime}; + } + + const float duration = loop ? keyframes.back().timestamp : keyframes[keyframes.size() - 2].timestamp; + playtime *= duration; + + CameraKeyframe dummy; + dummy.timestamp = playtime; + + // Binary search to obtain relevant keyframe in O(log(n_keyframes)) time + auto it = std::upper_bound(keyframes.begin(), keyframes.end(), dummy, [](const auto& a, const auto& b) { + return a.timestamp < b.timestamp; + }); + + int i = clamp((int)std::distance(keyframes.begin(), it), 0, (int)keyframes.size() - (loop ? 1 : 2)); + float prev_timestamp = i == 0 ? 0.0f : keyframes[i - 1].timestamp; + + return { + i, + (playtime - prev_timestamp) / (keyframes[i].timestamp - prev_timestamp), + }; +} + +bool CameraPath::imgui_viz( + ImDrawList* list, + mat4& view2proj, + mat4& world2proj, + mat4& world2view, + vec2 focal, + float aspect, + float znear, + float zfar +) { + bool changed = false; + // float flx = focal.x; + float fly = focal.y; + mat4 view2proj_guizmo = transpose( + mat4{ + fly * 2.0f / aspect, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + -fly * 2.0f, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + (zfar + znear) / (zfar - znear), + -(2.0f * zfar * znear) / (zfar - znear), + 0.0f, + 0.0f, + 1.0f, + 0.0f, + } + ); + + if (!update_cam_from_path && !keyframes.empty()) { + auto p = get_pos(play_time); + int cur_cam_i = p.kfidx; + if (!loop) { + cur_cam_i += (int)round(p.t); + } + + vec3 prevp; + for (int i = 0; i < keyframes.size(); i += max(min(keyframe_subsampling, (int)keyframes.size() - 1 - i), 1)) { + visualize_camera(list, world2proj, keyframes[i].m(), aspect, (i == cur_cam_i) ? 0xff80c0ff : 0x8080c0ff); + vec3 p = keyframes[i].T; + if (i && keyframe_subsampling == 1) { + add_debug_line(list, world2proj, prevp, p, 0xccffc040); + } + prevp = p; + } + + ImGuiIO& io = ImGui::GetIO(); + mat4 matrix = keyframes[cur_cam_i].m(); + ImGuizmo::SetRect(0, 0, io.DisplaySize.x, io.DisplaySize.y); + if (ImGuizmo::Manipulate( + (const float*)&world2view, + (const float*)&view2proj_guizmo, + (ImGuizmo::OPERATION)m_gizmo_op, + (ImGuizmo::MODE)m_gizmo_mode, + (float*)&matrix, + NULL, + NULL + )) { + // Find overlapping keypoints... 
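+				// (Keyframes that share a position with the current one -- e.g. ones created via
+				// "Dup" -- are collected into the range [i0, i1] so that, after the kernel-weighted
+				// edit below, they can be snapped back to identical transforms and move in tandem.)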
+ int i0 = cur_cam_i; + while (i0 > 0 && keyframes[cur_cam_i].same_pos_as(keyframes[i0 - 1])) { + i0--; + } + int i1 = cur_cam_i; + while (i1 < keyframes.size() - 1 && keyframes[cur_cam_i].same_pos_as(keyframes[i1 + 1])) { + i1++; + } + + vec3 tdiff = matrix[3].xyz() - keyframes[cur_cam_i].T; + mat3 rdiff = mat_log(mat3(matrix) * inverse(to_mat3(normalize(keyframes[cur_cam_i].R)))); + + for (int i = 0; i < keyframes.size(); ++i) { + float x = (get_playtime(i) - get_playtime(cur_cam_i)) / editing_kernel_radius; + float w = editing_kernel(x, editing_kernel_type); + + keyframes[i].T += w * tdiff; + keyframes[i].R = quat(mat_exp(w * rdiff) * to_mat3(normalize(keyframes[i].R))); + } + + // ...and ensure overlapping keypoints were edited exactly in tandem + for (int i = i0; i <= i1; ++i) { + keyframes[i].T = keyframes[cur_cam_i].T; + keyframes[i].R = keyframes[cur_cam_i].R; + } + + changed = true; + } + + visualize_camera(list, world2proj, eval_camera_path(play_time).m(), aspect, 0xff80ff80); + + float dt = 0.001f; + float total_length = 0.0f; + for (float t = 0.0f;; t += dt) { + if (t > 1.0f) { + t = 1.0f; + } + vec3 p = eval_camera_path(t).T; + if (t) { + total_length += distance(prevp, p); + } + prevp = p; + if (t >= 1.0f) { + break; + } + } + + dt = 0.001f / total_length; + static const uint32_t N_DASH_STEPS = 10; + uint32_t i = 0; + for (float t = 0.0f;; t += dt, ++i) { + if (t > 1.0f) { + t = 1.0f; + } + vec3 p = eval_camera_path(t).T; + if (t && (i / N_DASH_STEPS) % 2 == 0) { + float thickness = 1.0f; + if (editing_kernel_type != EEditingKernel::None) { + float x = (t + dt / 2.0f - get_playtime(cur_cam_i)) / editing_kernel_radius; + thickness += 4.0f * editing_kernel(x, editing_kernel_type); + } + + add_debug_line(list, world2proj, prevp, p, 0xff80c0ff, thickness); + } + + prevp = p; + if (t >= 1.0f) { + break; + } + } + + } + + return changed; +} +#endif // NGP_GUI + +} // namespace ngp diff --git a/gui/src/common_host.cu b/gui/src/common_host.cu new file mode 100644 index 0000000000000000000000000000000000000000..62b07c192a06446676181289326b2e4a92246f9b --- /dev/null +++ b/gui/src/common_host.cu @@ -0,0 +1,281 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** @file common_host.cu + * @author Thomas Müller, NVIDIA + */ + +#include +#include +#include +#include + +#include + +#include + +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_WRITE_IMPLEMENTATION + +#ifdef __CUDACC__ +# ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ +# pragma nv_diag_suppress 550 +# else +# pragma diag_suppress 550 +# endif +#endif +#include +#include +#ifdef __CUDACC__ +# ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ +# pragma nv_diag_default 550 +# else +# pragma diag_default 550 +# endif +#endif + +#ifdef _WIN32 +# include +#else +# include +# include +#endif + +#undef min +#undef max +#undef near +#undef far + +namespace ngp { + +bool is_wsl() { +#ifdef _WIN32 + return false; +#else + fs::path path = "/proc/sys/kernel/osrelease"; + if (!path.exists()) { + return false; + } + + std::ifstream f{native_string(path)}; + std::string content((std::istreambuf_iterator(f)), (std::istreambuf_iterator())); + return content.find("microsoft") != std::string::npos; +#endif +} + +#ifdef _WIN32 +std::string utf16_to_utf8(const std::wstring& utf16) { + std::string utf8; + if (!utf16.empty()) { + int size = WideCharToMultiByte(CP_UTF8, 0, &utf16[0], (int)utf16.size(), NULL, 0, NULL, NULL); + utf8.resize(size, 0); + WideCharToMultiByte(CP_UTF8, 0, &utf16[0], (int)utf16.size(), &utf8[0], size, NULL, NULL); + } + return utf8; +} + +std::wstring utf8_to_utf16(const std::string& utf8) { + std::wstring utf16; + if (!utf8.empty()) { + int size = MultiByteToWideChar(CP_UTF8, 0, &utf8[0], (int)utf8.size(), NULL, 0); + utf16.resize(size, 0); + MultiByteToWideChar(CP_UTF8, 0, &utf8[0], (int)utf8.size(), &utf16[0], size); + } + return utf16; +} + +std::wstring native_string(const fs::path& path) { return path.wstr(); } +#else +std::string native_string(const fs::path& path) { return path.str(); } +#endif + +fs::path discover_executable_dir() { +#ifdef _WIN32 + WCHAR path[1024]; + if (GetModuleFileNameW(NULL, path, 1024) == 0) { + return "."; + } + return fs::path{std::wstring{path}}.parent_path(); +#else + char path[PATH_MAX]; + ssize_t count = readlink("/proc/self/exe", path, PATH_MAX); + if (count == -1) { + return "."; + } + return fs::path{std::string{path}}.parent_path(); +#endif +} + +fs::path discover_root_dir() { + auto executable_dir = discover_executable_dir(); + fs::path exists_in_root_dir = "scripts"; + for (const auto& candidate : { + fs::path{"."} / exists_in_root_dir, + fs::path{".."} / exists_in_root_dir, + executable_dir / exists_in_root_dir, + executable_dir / ".." 
/ exists_in_root_dir, + }) { + if (candidate.exists()) { + return candidate.parent_path(); + } + } + + tlog::warning() << "Could not find root directory."; + return "."; +} + +bool ends_with(const std::string& str, const std::string& ending) { + if (ending.length() > str.length()) { + return false; + } + return std::equal(std::rbegin(ending), std::rend(ending), std::rbegin(str)); +} + +bool ends_with_case_insensitive(const std::string& str, const std::string& ending) { return ends_with(to_lower(str), to_lower(ending)); } + +ETestbedMode mode_from_scene(const std::string& scene) { + return ETestbedMode::None; +} + +ETestbedMode mode_from_string(const std::string& str) { + if (equals_case_insensitive(str, "image")) { + return ETestbedMode::Gen3c; + } else { + return ETestbedMode::None; + } +} + +std::string to_string(ETestbedMode mode) { + switch (mode) { + case ETestbedMode::Gen3c: return "gen3c"; + case ETestbedMode::None: return "none"; + default: throw std::runtime_error{fmt::format("Can not convert mode {} to string.", (int)mode)}; + } +} + +static const stbi_io_callbacks istream_stbi_callbacks = { + // Read + [](void* context, char* data, int size) { + auto stream = reinterpret_cast(context); + stream->read(data, size); + return (int)stream->gcount(); + }, + // Seek + [](void* context, int size) { reinterpret_cast(context)->seekg(size, std::ios_base::cur); }, + // EOF + [](void* context) { return (int)!!(*reinterpret_cast(context)); }, +}; + +void istream_stbi_write_func(void* context, void* data, int size) { + reinterpret_cast(context)->write(reinterpret_cast(data), size); +} + +uint8_t* load_stbi(const fs::path& path, int* width, int* height, int* comp, int req_comp) { + std::ifstream f{native_string(path), std::ios::in | std::ios::binary}; + return stbi_load_from_callbacks(&istream_stbi_callbacks, &f, width, height, comp, req_comp); +} + +float* load_stbi_float(const fs::path& path, int* width, int* height, int* comp, int req_comp) { + std::ifstream f{native_string(path), std::ios::in | std::ios::binary}; + return stbi_loadf_from_callbacks(&istream_stbi_callbacks, &f, width, height, comp, req_comp); +} + +uint16_t* load_stbi_16(const fs::path& path, int* width, int* height, int* comp, int req_comp) { + std::ifstream f{native_string(path), std::ios::in | std::ios::binary}; + return stbi_load_16_from_callbacks(&istream_stbi_callbacks, &f, width, height, comp, req_comp); +} + +bool is_hdr_stbi(const fs::path& path) { + std::ifstream f{native_string(path), std::ios::in | std::ios::binary}; + return stbi_is_hdr_from_callbacks(&istream_stbi_callbacks, &f); +} + +int write_stbi(const fs::path& path, int width, int height, int comp, const uint8_t* pixels, int quality) { + std::ofstream f{native_string(path), std::ios::out | std::ios::binary}; + + if (equals_case_insensitive(path.extension(), "jpg") || equals_case_insensitive(path.extension(), "jpeg")) { + return stbi_write_jpg_to_func(istream_stbi_write_func, &f, width, height, comp, pixels, quality); + } else if (equals_case_insensitive(path.extension(), "png")) { + return stbi_write_png_to_func(istream_stbi_write_func, &f, width, height, comp, pixels, width * comp); + } else if (equals_case_insensitive(path.extension(), "tga")) { + return stbi_write_tga_to_func(istream_stbi_write_func, &f, width, height, comp, pixels); + } else if (equals_case_insensitive(path.extension(), "bmp")) { + return stbi_write_bmp_to_func(istream_stbi_write_func, &f, width, height, comp, pixels); + } else { + throw std::runtime_error{fmt::format("write_stbi: 
unknown image extension '{}'", path.extension())}; + } +} + +FILE* native_fopen(const fs::path& path, const char* mode) { +#ifdef _WIN32 + return _wfopen(path.wstr().c_str(), utf8_to_utf16(mode).c_str()); +#else + return fopen(path.str().c_str(), mode); +#endif +} + +GPUMemory load_stbi_gpu(const fs::path& path, int* width, int* height) { + bool is_hdr = is_hdr_stbi(path); + + void* data; // width * height * RGBA + int comp; + if (is_hdr) { + data = load_stbi_float(path, width, height, &comp, 4); + } else { + data = load_stbi(path, width, height, &comp, 4); + } + + if (!data) { + throw std::runtime_error{std::string{stbi_failure_reason()}}; + } + + ScopeGuard mem_guard{[&]() { stbi_image_free(data); }}; + + if (*width == 0 || *height == 0) { + throw std::runtime_error{"Image has zero pixels."}; + } + + GPUMemory result((*width) * (*height) * 4); + if (is_hdr) { + result.copy_from_host((float*)data); + } else { + GPUMemory bytes((*width) * (*height) * 4); + bytes.copy_from_host((uint8_t*)data); + linear_kernel(from_rgba32, 0, nullptr, (*width) * (*height), bytes.data(), result.data(), false, false, 0); + } + + return result; +} + +std::ostream& operator<<(std::ostream& os, const BoundingBox& bb) { + os << "["; + os << "min=[" << bb.min.x << "," << bb.min.y << "," << bb.min.z << "], "; + os << "max=[" << bb.max.x << "," << bb.max.y << "," << bb.max.z << "]"; + os << "]"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const Triangle& triangle) { + os << "["; + os << "a=[" << triangle.a.x << "," << triangle.a.y << "," << triangle.a.z << "], "; + os << "b=[" << triangle.b.x << "," << triangle.b.y << "," << triangle.b.z << "], "; + os << "c=[" << triangle.c.x << "," << triangle.c.y << "," << triangle.c.z << "]"; + os << "]"; + return os; +} + +} // namespace ngp diff --git a/gui/src/dlss.cu b/gui/src/dlss.cu new file mode 100644 index 0000000000000000000000000000000000000000..e140cffdea9fb81c9faec5be9dd243e0d562d589 --- /dev/null +++ b/gui/src/dlss.cu @@ -0,0 +1,1230 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file dlss.cu + * @author Thomas Müller, NVIDIA + */ + +#include +#include + +#include + +#include + +#if !defined(NGP_VULKAN) || !defined(NGP_GUI) +static_assert(false, "DLSS can only be compiled when both Vulkan and GUI support is enabled.") +#endif + +#ifdef _WIN32 +# include +#else +# include +#endif +#include + +#ifdef _WIN32 +# include +#endif + +// NGX's macro `NVSDK_NGX_FAILED` results in a change of sign, which does not affect correctness. +// Thus, suppress the corresponding warning. 
+#ifdef __CUDACC__ +# ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ +# pragma nv_diag_suppress = integer_sign_change +# else +# pragma diag_suppress = integer_sign_change +# endif +#endif +#include +#include +#include + +#include +#include +#include + +namespace ngp { + +extern std::atomic g_total_n_bytes_allocated; + +/// Checks the result of a vkXXXXXX call and throws an error on failure +#define VK_CHECK_THROW(x) \ + do { \ + VkResult result = x; \ + if (result != VK_SUCCESS) \ + throw std::runtime_error(std::string(FILE_LINE " " #x " failed")); \ + } while(0) + +std::string ngx_error_string(NVSDK_NGX_Result result) { + std::wstring wstr = GetNGXResultAsString(result); + std::wstring_convert, wchar_t> converter; + return converter.to_bytes(wstr); +}; + +/// Checks the result of a NVSDK_NGX_XXXXXX call and throws an error on failure +#define NGX_CHECK_THROW(x) \ + do { \ + NVSDK_NGX_Result result = x; \ + if (NVSDK_NGX_FAILED(result)) \ + throw std::runtime_error(std::string(FILE_LINE " " #x " failed with error ") + ngx_error_string(result)); \ + } while(0) + +static VKAPI_ATTR VkBool32 VKAPI_CALL vk_debug_callback( + VkDebugUtilsMessageSeverityFlagBitsEXT message_severity, + VkDebugUtilsMessageTypeFlagsEXT message_type, + const VkDebugUtilsMessengerCallbackDataEXT* callback_data, + void* user_data +) { + // Ignore json files that couldn't be found... third party tools sometimes install bogus layers + // that manifest as warnings like this. + if (std::string{callback_data->pMessage}.find("Failed to open JSON file") != std::string::npos) { + return VK_FALSE; + } + + if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) { + tlog::warning() << "Vulkan error: " << callback_data->pMessage; + } else if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT) { + tlog::warning() << "Vulkan: " << callback_data->pMessage; + } else { + tlog::info() << "Vulkan: " << callback_data->pMessage; + } + + return VK_FALSE; +} + +std::set vk_supported_instance_layers() { + uint32_t count = 0; + VK_CHECK_THROW(vkEnumerateInstanceLayerProperties(&count, nullptr)); + std::vector layer_properties(count); + VK_CHECK_THROW(vkEnumerateInstanceLayerProperties(&count, layer_properties.data())); + + std::set layers; + for (auto& l : layer_properties) { + layers.insert(l.layerName); + } + + return layers; +} + +std::set vk_supported_device_layers(VkPhysicalDevice device) { + uint32_t count = 0; + VK_CHECK_THROW(vkEnumerateDeviceLayerProperties(device, &count, nullptr)); + std::vector layer_properties(count); + VK_CHECK_THROW(vkEnumerateDeviceLayerProperties(device, &count, layer_properties.data())); + + std::set layers; + for (auto& l : layer_properties) { + layers.insert(l.layerName); + } + + return layers; +} + +std::set vk_supported_instance_extensions(const char* layer_name) { + uint32_t count = 0; + VK_CHECK_THROW(vkEnumerateInstanceExtensionProperties(layer_name, &count, nullptr)); + std::vector extension_properties(count); + VK_CHECK_THROW(vkEnumerateInstanceExtensionProperties(layer_name, &count, extension_properties.data())); + + std::set extensions; + for (auto& e : extension_properties) { + extensions.insert(e.extensionName); + } + + return extensions; +} + +std::set vk_supported_device_extensions(VkPhysicalDevice device, const char* layer_name) { + uint32_t count = 0; + VK_CHECK_THROW(vkEnumerateDeviceExtensionProperties(device, layer_name, &count, nullptr)); + std::vector extension_properties(count); + VK_CHECK_THROW(vkEnumerateDeviceExtensionProperties(device, layer_name, &count, 
extension_properties.data())); + + std::set extensions; + for (auto& e : extension_properties) { + extensions.insert(e.extensionName); + } + + return extensions; +} + +class VulkanAndNgx : public IDlssProvider, public std::enable_shared_from_this { +public: + VulkanAndNgx() { + ScopeGuard cleanup_guard{[&]() { clear(); }}; + + if (!glfwVulkanSupported()) { + throw std::runtime_error{"!glfwVulkanSupported()"}; + } + + // ------------------------------- + // Vulkan Instance + // ------------------------------- + VkApplicationInfo app_info{}; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pApplicationName = "NGP"; + app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0); + app_info.pEngineName = "No engine"; + app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0); + app_info.apiVersion = VK_API_VERSION_1_0; + + VkInstanceCreateInfo instance_create_info = {}; + instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + instance_create_info.pApplicationInfo = &app_info; + + std::vector instance_extensions; + std::vector device_extensions; + + uint32_t n_ngx_instance_extensions = 0; + const char** ngx_instance_extensions; + + uint32_t n_ngx_device_extensions = 0; + const char** ngx_device_extensions; + + NVSDK_NGX_VULKAN_RequiredExtensions(&n_ngx_instance_extensions, &ngx_instance_extensions, &n_ngx_device_extensions, &ngx_device_extensions); + + for (uint32_t i = 0; i < n_ngx_instance_extensions; ++i) { + instance_extensions.emplace_back(ngx_instance_extensions[i]); + } + + instance_extensions.emplace_back(VK_KHR_DEVICE_GROUP_CREATION_EXTENSION_NAME); + instance_extensions.emplace_back(VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME); + instance_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME); + instance_extensions.emplace_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + + auto supported_instance_layers = vk_supported_instance_layers(); + + const char* validation_layer_name = "VK_LAYER_KHRONOS_validation"; + bool instance_validation_layer_enabled = supported_instance_layers.count(validation_layer_name) > 0; + if (!instance_validation_layer_enabled) { + tlog::warning() << "Vulkan instance validation layer is not available. Vulkan errors will be difficult to diagnose."; + } + + std::vector instance_layers; + if (instance_validation_layer_enabled) { + instance_layers.emplace_back(validation_layer_name); + } + + instance_create_info.enabledLayerCount = static_cast(instance_layers.size()); + instance_create_info.ppEnabledLayerNames = instance_layers.empty() ? 
nullptr : instance_layers.data(); + + if (instance_validation_layer_enabled) { + instance_extensions.emplace_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + } + + auto supported_instance_extensions = vk_supported_instance_extensions(nullptr); + for (const auto& e : instance_extensions) { + if (supported_instance_extensions.count(e) == 0) { + throw std::runtime_error{fmt::format("Required instance extension '{}' is not supported.", e)}; + } + } + + instance_create_info.enabledExtensionCount = (uint32_t)instance_extensions.size(); + instance_create_info.ppEnabledExtensionNames = instance_extensions.data(); + + VkDebugUtilsMessengerCreateInfoEXT debug_messenger_create_info = {}; + debug_messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; + debug_messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT; + debug_messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; + debug_messenger_create_info.pfnUserCallback = vk_debug_callback; + debug_messenger_create_info.pUserData = nullptr; + + if (instance_validation_layer_enabled) { + instance_create_info.pNext = &debug_messenger_create_info; + } + + VK_CHECK_THROW(vkCreateInstance(&instance_create_info, nullptr, &m_vk_instance)); + + if (instance_validation_layer_enabled) { + auto CreateDebugUtilsMessengerEXT = [](VkInstance instance, const VkDebugUtilsMessengerCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugUtilsMessengerEXT* pDebugMessenger) { + auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkCreateDebugUtilsMessengerEXT"); + if (func != nullptr) { + return func(instance, pCreateInfo, pAllocator, pDebugMessenger); + } else { + return VK_ERROR_EXTENSION_NOT_PRESENT; + } + }; + + if (CreateDebugUtilsMessengerEXT(m_vk_instance, &debug_messenger_create_info, nullptr, &m_vk_debug_messenger) != VK_SUCCESS) { + tlog::warning() << "Vulkan: could not initialize debug messenger."; + } + } + + // ------------------------------- + // Vulkan Physical Device + // ------------------------------- + uint32_t n_devices = 0; + vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, nullptr); + + if (n_devices == 0) { + throw std::runtime_error{"Failed to find GPUs with Vulkan support."}; + } + + std::vector devices(n_devices); + vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, devices.data()); + + struct QueueFamilyIndices { + int graphics_family = -1; + int compute_family = -1; + int transfer_family = -1; + int all_family = -1; + }; + + auto find_queue_families = [](VkPhysicalDevice device) { + QueueFamilyIndices indices; + + uint32_t queue_family_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr); + + std::vector queue_families(queue_family_count); + vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data()); + + int i = 0; + for (const auto& queue_family : queue_families) { + if (queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) { + indices.graphics_family = i; + } + + if (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) { + indices.compute_family = i; + } + + if (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT) { + indices.transfer_family = i; + } + + if ((queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) && (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) && (queue_family.queueFlags & 
VK_QUEUE_TRANSFER_BIT)) { + indices.all_family = i; + } + + i++; + } + + return indices; + }; + + cudaDeviceProp cuda_device_prop; + CUDA_CHECK_THROW(cudaGetDeviceProperties(&cuda_device_prop, cuda_device())); + + auto is_same_as_cuda_device = [&](VkPhysicalDevice device) { + VkPhysicalDeviceIDProperties physical_device_id_properties = {}; + physical_device_id_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES; + physical_device_id_properties.pNext = NULL; + + VkPhysicalDeviceProperties2 physical_device_properties = {}; + physical_device_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + physical_device_properties.pNext = &physical_device_id_properties; + + vkGetPhysicalDeviceProperties2(device, &physical_device_properties); + + return !memcmp(&cuda_device_prop.uuid, physical_device_id_properties.deviceUUID, VK_UUID_SIZE) && find_queue_families(device).all_family >= 0; + }; + + uint32_t device_id = 0; + for (uint32_t i = 0; i < n_devices; ++i) { + if (is_same_as_cuda_device(devices[i])) { + m_vk_physical_device = devices[i]; + device_id = i; + break; + } + } + + if (m_vk_physical_device == VK_NULL_HANDLE) { + throw std::runtime_error{"Failed to find Vulkan device corresponding to CUDA device."}; + } + + for (uint32_t i = 0; i < n_ngx_device_extensions; ++i) { + device_extensions.emplace_back(ngx_device_extensions[i]); + } + + device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME); +#ifdef _WIN32 + device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME); +#else + device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME); +#endif + device_extensions.emplace_back(VK_KHR_DEVICE_GROUP_EXTENSION_NAME); + + auto supported_device_extensions = vk_supported_device_extensions(m_vk_physical_device, nullptr); + for (const auto& e : device_extensions) { + if (supported_device_extensions.count(e) == 0) { + throw std::runtime_error{fmt::format("Required device extension '{}' is not supported.", e)}; + } + } + + // ------------------------------- + // Vulkan Logical Device + // ------------------------------- + VkPhysicalDeviceProperties physical_device_properties; + vkGetPhysicalDeviceProperties(m_vk_physical_device, &physical_device_properties); + + QueueFamilyIndices indices = find_queue_families(m_vk_physical_device); + + VkDeviceQueueCreateInfo queue_create_info{}; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.queueFamilyIndex = indices.all_family; + queue_create_info.queueCount = 1; + + float queue_priority = 1.0f; + queue_create_info.pQueuePriorities = &queue_priority; + + VkPhysicalDeviceFeatures device_features = {}; + device_features.shaderStorageImageWriteWithoutFormat = true; + + VkDeviceCreateInfo device_create_info = {}; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pEnabledFeatures = &device_features; + device_create_info.enabledExtensionCount = (uint32_t)device_extensions.size(); + device_create_info.ppEnabledExtensionNames = device_extensions.data(); + +#ifdef VK_EXT_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME + VkPhysicalDeviceBufferDeviceAddressFeaturesEXT buffer_device_address_feature = {}; + buffer_device_address_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_EXT; + buffer_device_address_feature.bufferDeviceAddress = VK_TRUE; + device_create_info.pNext = 
&buffer_device_address_feature; +#else + throw std::runtime_error{"Buffer device address extension not available."}; +#endif + + VK_CHECK_THROW(vkCreateDevice(m_vk_physical_device, &device_create_info, nullptr, &m_vk_device)); + + // ----------------------------------------------- + // Vulkan queue / command pool / command buffer + // ----------------------------------------------- + vkGetDeviceQueue(m_vk_device, indices.all_family, 0, &m_vk_queue); + + VkCommandPoolCreateInfo command_pool_info = {}; + command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + command_pool_info.queueFamilyIndex = indices.all_family; + + VK_CHECK_THROW(vkCreateCommandPool(m_vk_device, &command_pool_info, nullptr, &m_vk_command_pool)); + + VkCommandBufferAllocateInfo command_buffer_alloc_info = {}; + command_buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + command_buffer_alloc_info.commandPool = m_vk_command_pool; + command_buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + command_buffer_alloc_info.commandBufferCount = 1; + + VK_CHECK_THROW(vkAllocateCommandBuffers(m_vk_device, &command_buffer_alloc_info, &m_vk_command_buffer)); + + // ------------------------------- + // NGX init + // ------------------------------- + std::wstring path; +#ifdef _WIN32 + path = fs::path::getcwd().wstr(); +#else + std::string tmp = fs::path::getcwd().str(); + std::wstring_convert, wchar_t> converter; + path = converter.from_bytes(tmp); +#endif + + NGX_CHECK_THROW(NVSDK_NGX_VULKAN_Init_with_ProjectID("ea75345e-5a42-4037-a5c9-59bf94dee157", NVSDK_NGX_ENGINE_TYPE_CUSTOM, "1.0.0", path.c_str(), m_vk_instance, m_vk_physical_device, m_vk_device)); + m_ngx_initialized = true; + + // ------------------------------- + // Ensure DLSS capability + // ------------------------------- + NGX_CHECK_THROW(NVSDK_NGX_VULKAN_GetCapabilityParameters(&m_ngx_parameters)); + + int needs_updated_driver = 0; + unsigned int min_driver_version_major = 0; + unsigned int min_driver_version_minor = 0; + NVSDK_NGX_Result result_updated_driver = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_NeedsUpdatedDriver, &needs_updated_driver); + NVSDK_NGX_Result result_min_driver_version_major = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMajor, &min_driver_version_major); + NVSDK_NGX_Result result_min_driver_version_minor = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMinor, &min_driver_version_minor); + if (result_updated_driver == NVSDK_NGX_Result_Success && result_min_driver_version_major == NVSDK_NGX_Result_Success && result_min_driver_version_minor == NVSDK_NGX_Result_Success) { + if (needs_updated_driver) { + throw std::runtime_error{fmt::format("Driver too old. 
Minimum version required is {}.{}", min_driver_version_major, min_driver_version_minor)}; + } + } + + int dlss_available = 0; + NVSDK_NGX_Result ngx_result = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_Available, &dlss_available); + if (ngx_result != NVSDK_NGX_Result_Success || !dlss_available) { + ngx_result = NVSDK_NGX_Result_Fail; + NVSDK_NGX_Parameter_GetI(m_ngx_parameters, NVSDK_NGX_Parameter_SuperSampling_FeatureInitResult, (int*)&ngx_result); + throw std::runtime_error{fmt::format("DLSS not available: {}", ngx_error_string(ngx_result))}; + } + + cleanup_guard.disarm(); + + tlog::success() << "Initialized Vulkan and NGX on GPU #" << device_id << ": " << physical_device_properties.deviceName; + } + + virtual ~VulkanAndNgx() { + clear(); + } + + void clear() { + if (m_ngx_parameters) { + NVSDK_NGX_VULKAN_DestroyParameters(m_ngx_parameters); + m_ngx_parameters = nullptr; + } + + if (m_ngx_initialized) { + NVSDK_NGX_VULKAN_Shutdown(); + m_ngx_initialized = false; + } + + if (m_vk_command_pool) { + vkDestroyCommandPool(m_vk_device, m_vk_command_pool, nullptr); + m_vk_command_pool = VK_NULL_HANDLE; + } + + if (m_vk_device) { + vkDestroyDevice(m_vk_device, nullptr); + m_vk_device = VK_NULL_HANDLE; + } + + if (m_vk_debug_messenger) { + auto DestroyDebugUtilsMessengerEXT = [](VkInstance instance, VkDebugUtilsMessengerEXT debugMessenger, const VkAllocationCallbacks* pAllocator) { + auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkDestroyDebugUtilsMessengerEXT"); + if (func != nullptr) { + func(instance, debugMessenger, pAllocator); + } + }; + + DestroyDebugUtilsMessengerEXT(m_vk_instance, m_vk_debug_messenger, nullptr); + m_vk_debug_messenger = VK_NULL_HANDLE; + } + + if (m_vk_instance) { + vkDestroyInstance(m_vk_instance, nullptr); + m_vk_instance = VK_NULL_HANDLE; + } + } + + uint32_t vk_find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags properties) { + VkPhysicalDeviceMemoryProperties mem_properties; + vkGetPhysicalDeviceMemoryProperties(m_vk_physical_device, &mem_properties); + + for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++) { + if (type_filter & (1 << i) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties) { + return i; + } + } + + throw std::runtime_error{"Failed to find suitable memory type."}; + } + + void vk_command_buffer_begin() { + VkCommandBufferBeginInfo begin_info = {}; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + begin_info.pInheritanceInfo = nullptr; + + VK_CHECK_THROW(vkBeginCommandBuffer(m_vk_command_buffer, &begin_info)); + } + + void vk_command_buffer_end() { + VK_CHECK_THROW(vkEndCommandBuffer(m_vk_command_buffer)); + } + + void vk_command_buffer_submit() { + VkSubmitInfo submit_info = { VK_STRUCTURE_TYPE_SUBMIT_INFO }; + submit_info.commandBufferCount = 1; + submit_info.pCommandBuffers = &m_vk_command_buffer; + + VK_CHECK_THROW(vkQueueSubmit(m_vk_queue, 1, &submit_info, VK_NULL_HANDLE)); + } + + void vk_synchronize() { + VK_CHECK_THROW(vkDeviceWaitIdle(m_vk_device)); + } + + void vk_command_buffer_submit_sync() { + vk_command_buffer_submit(); + vk_synchronize(); + } + + void vk_command_buffer_end_and_submit_sync() { + vk_command_buffer_end(); + vk_command_buffer_submit_sync(); + } + + const VkCommandBuffer& vk_command_buffer() const { + return m_vk_command_buffer; + } + + const VkDevice& vk_device() const { + return m_vk_device; + } + + NVSDK_NGX_Parameter* ngx_parameters() const { + 
return m_ngx_parameters; + } + + size_t allocated_bytes() const override { + unsigned long long allocated_bytes = 0; + if (!m_ngx_parameters) { + return 0; + } + + try { + NGX_CHECK_THROW(NGX_DLSS_GET_STATS(m_ngx_parameters, &allocated_bytes)); + } catch (...) { + return 0; + } + + return allocated_bytes; + } + + std::unique_ptr init_dlss(const ivec2& out_resolution) override; + +private: + VkInstance m_vk_instance = VK_NULL_HANDLE; + VkDebugUtilsMessengerEXT m_vk_debug_messenger = VK_NULL_HANDLE; + VkPhysicalDevice m_vk_physical_device = VK_NULL_HANDLE; + VkDevice m_vk_device = VK_NULL_HANDLE; + VkQueue m_vk_queue = VK_NULL_HANDLE; + VkCommandPool m_vk_command_pool = VK_NULL_HANDLE; + VkCommandBuffer m_vk_command_buffer = VK_NULL_HANDLE; + NVSDK_NGX_Parameter* m_ngx_parameters = nullptr; + bool m_ngx_initialized = false; +}; + +std::shared_ptr init_vulkan_and_ngx() { + return std::make_shared(); +} + +class VulkanTexture { +public: + VulkanTexture(std::shared_ptr vk, const ivec2& size, uint32_t n_channels) : m_vk{vk}, m_size{size}, m_n_channels{n_channels} { + ScopeGuard cleanup_guard{[&]() { clear(); }}; + + VkImageCreateInfo image_info{}; + image_info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO; + image_info.imageType = VK_IMAGE_TYPE_2D; + image_info.extent.width = static_cast(m_size.x); + image_info.extent.height = static_cast(m_size.y); + image_info.extent.depth = 1; + image_info.mipLevels = 1; + image_info.arrayLayers = 1; + + switch (n_channels) { + case 1: image_info.format = VK_FORMAT_R32_SFLOAT; break; + case 2: image_info.format = VK_FORMAT_R32G32_SFLOAT; break; + case 3: image_info.format = VK_FORMAT_R32G32B32_SFLOAT; break; + case 4: image_info.format = VK_FORMAT_R32G32B32A32_SFLOAT; break; + default: throw std::runtime_error{"VulkanTexture only supports 1, 2, 3, or 4 channels."}; + } + + image_info.tiling = VK_IMAGE_TILING_OPTIMAL; + image_info.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; + image_info.usage = VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT; + image_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + image_info.samples = VK_SAMPLE_COUNT_1_BIT; + image_info.flags = 0; + + VkExternalMemoryImageCreateInfoKHR ext_image_info = {}; + ext_image_info.sType = VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO_KHR; + +#ifdef _WIN32 + ext_image_info.handleTypes |= VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR; +#else + ext_image_info.handleTypes |= VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR; +#endif + + image_info.pNext = &ext_image_info; + + VK_CHECK_THROW(vkCreateImage(m_vk->vk_device(), &image_info, nullptr, &m_vk_image)); + + // Create device memory to back up the image + VkMemoryRequirements mem_requirements = {}; + + vkGetImageMemoryRequirements(m_vk->vk_device(), m_vk_image, &mem_requirements); + + VkMemoryAllocateInfo mem_alloc_info = {}; + mem_alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + mem_alloc_info.allocationSize = mem_requirements.size; + mem_alloc_info.memoryTypeIndex = m_vk->vk_find_memory_type(mem_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); + + VkExportMemoryAllocateInfoKHR export_info = {}; + export_info.sType = VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR; + export_info.handleTypes = ext_image_info.handleTypes; + + mem_alloc_info.pNext = &export_info; + + VK_CHECK_THROW(vkAllocateMemory(m_vk->vk_device(), &mem_alloc_info, nullptr, &m_vk_device_memory)); + VK_CHECK_THROW(vkBindImageMemory(m_vk->vk_device(), m_vk_image, m_vk_device_memory, 0)); + + 
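+ // What follows wires the freshly created image up for use by both NGX and CUDA:
+ // 1. transition the image to VK_IMAGE_LAYOUT_GENERAL via a one-off command buffer,
+ // 2. wrap it in an NVSDK_NGX_Resource_VK so the DLSS feature can consume it,
+ // 3. export the backing VkDeviceMemory as an OS handle (a file descriptor on Linux, a HANDLE on
+ //    Windows), import it with cudaImportExternalMemory, and map it both as a linear buffer
+ //    (m_cuda_data) and as a mipmapped array backing a cudaSurfaceObject_t (m_cuda_surface_object).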
m_vk->vk_command_buffer_begin(); + + VkImageMemoryBarrier barrier = {}; + barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER; + barrier.oldLayout = VK_IMAGE_LAYOUT_UNDEFINED; + barrier.newLayout = VK_IMAGE_LAYOUT_GENERAL; + barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + barrier.image = m_vk_image; + barrier.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + barrier.subresourceRange.baseMipLevel = 0; + barrier.subresourceRange.levelCount = 1; + barrier.subresourceRange.baseArrayLayer = 0; + barrier.subresourceRange.layerCount = 1; + barrier.srcAccessMask = 0; + barrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT | VK_ACCESS_MEMORY_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_COLOR_ATTACHMENT_READ_BIT | VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT; + + vkCmdPipelineBarrier( + m_vk->vk_command_buffer(), + VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 0, nullptr, + 0, nullptr, + 1, &barrier + ); + + m_vk->vk_command_buffer_end_and_submit_sync(); + + // Image view + VkImageViewCreateInfo view_info = {}; + view_info.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO; + view_info.image = m_vk_image; + view_info.viewType = VK_IMAGE_VIEW_TYPE_2D; + view_info.format = image_info.format; + view_info.subresourceRange = barrier.subresourceRange; + + VK_CHECK_THROW(vkCreateImageView(m_vk->vk_device(), &view_info, nullptr, &m_vk_image_view)); + + // Map to NGX + m_ngx_resource = NVSDK_NGX_Create_ImageView_Resource_VK(m_vk_image_view, m_vk_image, view_info.subresourceRange, image_info.format, m_size.x, m_size.y, true); + + // Map to CUDA memory: VkDeviceMemory->FD/HANDLE->cudaExternalMemory->CUDA pointer +#ifdef _WIN32 + HANDLE handle = nullptr; + VkMemoryGetWin32HandleInfoKHR handle_info = {}; + handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_WIN32_HANDLE_INFO_KHR; + handle_info.memory = m_vk_device_memory; + handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT; + auto pfn_vkGetMemory = (PFN_vkGetMemoryWin32HandleKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryWin32HandleKHR"); +#else + int handle = -1; + VkMemoryGetFdInfoKHR handle_info = {}; + handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_FD_INFO_KHR; + handle_info.memory = m_vk_device_memory; + handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR; + auto pfn_vkGetMemory = (PFN_vkGetMemoryFdKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryFdKHR"); +#endif + + if (!pfn_vkGetMemory) { + throw std::runtime_error{"Failed to locate pfn_vkGetMemory."}; + } + + VK_CHECK_THROW(pfn_vkGetMemory(m_vk->vk_device(), &handle_info, &handle)); + + // Map handle to CUDA memory + cudaExternalMemoryHandleDesc external_memory_handle_desc = {}; + memset(&external_memory_handle_desc, 0, sizeof(external_memory_handle_desc)); + +#ifdef _WIN32 + external_memory_handle_desc.type = cudaExternalMemoryHandleTypeOpaqueWin32; + external_memory_handle_desc.handle.win32.handle = handle; +#else + external_memory_handle_desc.type = cudaExternalMemoryHandleTypeOpaqueFd; + external_memory_handle_desc.handle.fd = handle; +#endif + external_memory_handle_desc.size = mem_requirements.size; + + CUDA_CHECK_THROW(cudaImportExternalMemory(&m_cuda_external_memory, &external_memory_handle_desc)); + + cudaExternalMemoryBufferDesc external_memory_buffer_desc = {}; + memset(&external_memory_buffer_desc, 0, sizeof(external_memory_buffer_desc)); + external_memory_buffer_desc.offset = 0; + 
external_memory_buffer_desc.size = mem_requirements.size; + + void* ptr; + CUDA_CHECK_THROW(cudaExternalMemoryGetMappedBuffer(&ptr, m_cuda_external_memory, &external_memory_buffer_desc)); + m_cuda_data = (float*)ptr; + + // ---------------- + // Also get a surface object array, as the above buffer might be too cumbersome to deal with + // ---------------- + cudaExternalMemoryMipmappedArrayDesc external_memory_mipmapped_array_desc = {}; + memset(&external_memory_mipmapped_array_desc, 0, sizeof(external_memory_mipmapped_array_desc)); + + cudaChannelFormatDesc channel_format = {}; + channel_format.f = cudaChannelFormatKindFloat; + switch (n_channels) { + case 1: channel_format.x = 32; channel_format.y = 0; channel_format.z = 0; channel_format.w = 0; break; + case 2: channel_format.x = 32; channel_format.y = 32; channel_format.z = 0; channel_format.w = 0; break; + case 3: channel_format.x = 32; channel_format.y = 32; channel_format.z = 32; channel_format.w = 0; break; + case 4: channel_format.x = 32; channel_format.y = 32; channel_format.z = 32; channel_format.w = 32; break; + default: throw std::runtime_error{"VulkanTexture only supports 1, 2, 3, or 4 channels."}; + } + + cudaExtent extent = {}; + extent.width = m_size.x; + extent.height = m_size.y; + extent.depth = 0; + + external_memory_mipmapped_array_desc.offset = 0; + external_memory_mipmapped_array_desc.formatDesc = channel_format; + external_memory_mipmapped_array_desc.extent = extent; + external_memory_mipmapped_array_desc.flags = cudaArraySurfaceLoadStore; + external_memory_mipmapped_array_desc.numLevels = 1; + + cudaExternalMemoryGetMappedMipmappedArray(&m_cuda_mipmapped_array, m_cuda_external_memory, &external_memory_mipmapped_array_desc); + + cudaArray_t first_level_array; + CUDA_CHECK_THROW(cudaGetMipmappedArrayLevel(&first_level_array, m_cuda_mipmapped_array, 0)); + + struct cudaResourceDesc resource_desc; + memset(&resource_desc, 0, sizeof(resource_desc)); + resource_desc.resType = cudaResourceTypeArray; + resource_desc.res.array.array = first_level_array; + + CUDA_CHECK_THROW(cudaCreateSurfaceObject(&m_cuda_surface_object, &resource_desc)); + + m_n_bytes = mem_requirements.size; + g_total_n_bytes_allocated += m_n_bytes; + + cleanup_guard.disarm(); + } + + virtual ~VulkanTexture() { + clear(); + } + + void clear() { + g_total_n_bytes_allocated -= m_n_bytes; + + if (m_cuda_data) { + cudaFree(m_cuda_data); + m_cuda_data = nullptr; + } + + if (m_cuda_surface_object) { + cudaDestroySurfaceObject(m_cuda_surface_object); + m_cuda_surface_object = {}; + } + + if (m_cuda_mipmapped_array) { + cudaFreeMipmappedArray(m_cuda_mipmapped_array); + m_cuda_mipmapped_array = {}; + } + + if (m_cuda_external_memory) { + cudaDestroyExternalMemory(m_cuda_external_memory); + m_cuda_external_memory = {}; + } + + if (m_vk_image_view) { + vkDestroyImageView(m_vk->vk_device(), m_vk_image_view, nullptr); + m_vk_image_view = {}; + } + + if (m_vk_image) { + vkDestroyImage(m_vk->vk_device(), m_vk_image, nullptr); + m_vk_image = {}; + } + + if (m_vk_device_memory) { + vkFreeMemory(m_vk->vk_device(), m_vk_device_memory, nullptr); + m_vk_device_memory = {}; + } + } + + float* data() { + return m_cuda_data; + } + + cudaSurfaceObject_t surface() { + return m_cuda_surface_object; + } + + NVSDK_NGX_Resource_VK& ngx_resource() { + return m_ngx_resource; + } + + size_t bytes() const { + return m_size.x * (size_t)m_size.y * sizeof(float) * m_n_channels; + } + + ivec2 size() const { + return m_size; + } + +private: + std::shared_ptr m_vk; + + ivec2 m_size; + uint32_t 
m_n_channels; + + size_t m_n_bytes = 0; + + VkImage m_vk_image = {}; + VkImageView m_vk_image_view = {}; + VkDeviceMemory m_vk_device_memory = {}; + + cudaExternalMemory_t m_cuda_external_memory = {}; + cudaMipmappedArray_t m_cuda_mipmapped_array = {}; + cudaSurfaceObject_t m_cuda_surface_object = {}; + float* m_cuda_data = nullptr; + + NVSDK_NGX_Resource_VK m_ngx_resource = {}; +}; + +NVSDK_NGX_PerfQuality_Value ngx_dlss_quality(EDlssQuality quality) { + switch (quality) { + case EDlssQuality::UltraPerformance: return NVSDK_NGX_PerfQuality_Value_UltraPerformance; + case EDlssQuality::MaxPerformance: return NVSDK_NGX_PerfQuality_Value_MaxPerf; + case EDlssQuality::Balanced: return NVSDK_NGX_PerfQuality_Value_Balanced; + case EDlssQuality::MaxQuality: return NVSDK_NGX_PerfQuality_Value_MaxQuality; + case EDlssQuality::UltraQuality: return NVSDK_NGX_PerfQuality_Value_UltraQuality; + default: throw std::runtime_error{"Unknown DLSS quality setting."}; + } +} + +struct DlssFeatureSpecs { + EDlssQuality quality; + ivec2 out_resolution; + ivec2 optimal_in_resolution; + ivec2 min_in_resolution; + ivec2 max_in_resolution; + float optimal_sharpness; + + float distance(const ivec2& resolution) const { + return length(vec2(max(max(min_in_resolution - resolution, resolution - max_in_resolution), ivec2(0)))); + } + + ivec2 clamp_resolution(const ivec2& resolution) const { + return clamp(resolution, min_in_resolution, max_in_resolution); + } +}; + +DlssFeatureSpecs dlss_feature_specs(NVSDK_NGX_Parameter* ngx_parameters, const ivec2& out_resolution, EDlssQuality quality) { + DlssFeatureSpecs specs; + specs.quality = quality; + specs.out_resolution = out_resolution; + + NGX_CHECK_THROW(NGX_DLSS_GET_OPTIMAL_SETTINGS( + ngx_parameters, + specs.out_resolution.x, specs.out_resolution.y, + ngx_dlss_quality(quality), + (uint32_t*)&specs.optimal_in_resolution.x, (uint32_t*)&specs.optimal_in_resolution.y, + (uint32_t*)&specs.max_in_resolution.x, (uint32_t*)&specs.max_in_resolution.y, + (uint32_t*)&specs.min_in_resolution.x, (uint32_t*)&specs.min_in_resolution.y, + &specs.optimal_sharpness + )); + + // Don't permit input resolutions larger than the output. (Just in case DLSS allows it.) + specs.optimal_in_resolution = min(specs.optimal_in_resolution, out_resolution); + specs.max_in_resolution = min(specs.max_in_resolution, out_resolution); + specs.min_in_resolution = min(specs.min_in_resolution, out_resolution); + + return specs; +} + +class DlssFeature { +public: + DlssFeature(std::shared_ptr vk_and_ngx, const DlssFeatureSpecs& specs, bool is_hdr, bool sharpen) : m_vk_and_ngx{vk_and_ngx}, m_specs{specs}, m_is_hdr{is_hdr}, m_sharpen{sharpen} { + // Initialize DLSS + unsigned int creation_node_mask = 1; + unsigned int visibility_node_mask = 1; + + int dlss_create_feature_flags = NVSDK_NGX_DLSS_Feature_Flags_None; + dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_MVLowRes : 0; + dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_MVJittered : 0; + dlss_create_feature_flags |= is_hdr ? NVSDK_NGX_DLSS_Feature_Flags_IsHDR : 0; + dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_DepthInverted : 0; + dlss_create_feature_flags |= sharpen ? NVSDK_NGX_DLSS_Feature_Flags_DoSharpening : 0; + dlss_create_feature_flags |= false ? 
NVSDK_NGX_DLSS_Feature_Flags_AutoExposure : 0; + + NVSDK_NGX_DLSS_Create_Params dlss_create_params; + + memset(&dlss_create_params, 0, sizeof(dlss_create_params)); + + dlss_create_params.Feature.InWidth = m_specs.optimal_in_resolution.x; + dlss_create_params.Feature.InHeight = m_specs.optimal_in_resolution.y; + dlss_create_params.Feature.InTargetWidth = m_specs.out_resolution.x; + dlss_create_params.Feature.InTargetHeight = m_specs.out_resolution.y; + dlss_create_params.Feature.InPerfQualityValue = ngx_dlss_quality(m_specs.quality); + dlss_create_params.InFeatureCreateFlags = dlss_create_feature_flags; + + { + m_vk_and_ngx->vk_command_buffer_begin(); + ScopeGuard command_buffer_guard{[&]() { m_vk_and_ngx->vk_command_buffer_end_and_submit_sync(); }}; + + NGX_CHECK_THROW(NGX_VULKAN_CREATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), creation_node_mask, visibility_node_mask, &m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_create_params)); + } + } + + DlssFeature(std::shared_ptr vk_and_ngx, const ivec2& out_resolution, bool is_hdr, bool sharpen, EDlssQuality quality) + : DlssFeature{vk_and_ngx, dlss_feature_specs(vk_and_ngx->ngx_parameters(), out_resolution, quality), is_hdr, sharpen} {} + + ~DlssFeature() { + cudaDeviceSynchronize(); + + if (m_ngx_dlss) { + NVSDK_NGX_VULKAN_ReleaseFeature(m_ngx_dlss); + } + + m_vk_and_ngx->vk_synchronize(); + } + + void run( + const ivec2& in_resolution, + const vec2& jitter_offset, + float sharpening, + bool shall_reset, + NVSDK_NGX_Resource_VK& frame, + NVSDK_NGX_Resource_VK& depth, + NVSDK_NGX_Resource_VK& mvec, + NVSDK_NGX_Resource_VK& exposure, + NVSDK_NGX_Resource_VK& output + ) { + if (!m_sharpen && sharpening != 0.0f) { + throw std::runtime_error{"May only specify non-zero sharpening, when DlssFeature has been created with sharpen option."}; + } + + m_vk_and_ngx->vk_command_buffer_begin(); + + NVSDK_NGX_VK_DLSS_Eval_Params dlss_params; + memset(&dlss_params, 0, sizeof(dlss_params)); + + dlss_params.Feature.pInColor = &frame; + dlss_params.Feature.pInOutput = &output; + dlss_params.pInDepth = &depth; + dlss_params.pInMotionVectors = &mvec; + dlss_params.pInExposureTexture = &exposure; + dlss_params.InJitterOffsetX = jitter_offset.x; + dlss_params.InJitterOffsetY = jitter_offset.y; + dlss_params.Feature.InSharpness = sharpening; + dlss_params.InReset = shall_reset; + dlss_params.InMVScaleX = 1.0f; + dlss_params.InMVScaleY = 1.0f; + dlss_params.InRenderSubrectDimensions = {(uint32_t)in_resolution.x, (uint32_t)in_resolution.y}; + + NGX_CHECK_THROW(NGX_VULKAN_EVALUATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_params)); + + m_vk_and_ngx->vk_command_buffer_end_and_submit_sync(); + } + + bool is_hdr() const { + return m_is_hdr; + } + + bool sharpen() const { + return m_sharpen; + } + + EDlssQuality quality() const { + return m_specs.quality; + } + + ivec2 out_resolution() const { + return m_specs.out_resolution; + } + + ivec2 clamp_resolution(const ivec2& resolution) const { + return m_specs.clamp_resolution(resolution); + } + + ivec2 optimal_in_resolution() const { + return m_specs.optimal_in_resolution; + } + +private: + std::shared_ptr m_vk_and_ngx; + + NVSDK_NGX_Handle* m_ngx_dlss = {}; + DlssFeatureSpecs m_specs; + bool m_is_hdr; + bool m_sharpen; +}; + +class Dlss : public IDlss { +public: + Dlss(std::shared_ptr vk_and_ngx, const ivec2& max_out_resolution) + : + m_vk_and_ngx{vk_and_ngx}, + m_max_out_resolution{max_out_resolution}, + // Allocate all buffers at output resolution and use dynamic 
sub-rects + // to use subsets of them. This avoids re-allocations when using DLSS + // with dynamically changing input resolution. + m_frame_buffer{m_vk_and_ngx, max_out_resolution, 4}, + m_depth_buffer{m_vk_and_ngx, max_out_resolution, 1}, + m_mvec_buffer{m_vk_and_ngx, max_out_resolution, 2}, + m_exposure_buffer{m_vk_and_ngx, {1, 1}, 1}, + m_output_buffer{m_vk_and_ngx, max_out_resolution, 4} + { + // Various quality modes of DLSS + for (int i = 0; i < (int)EDlssQuality::NumDlssQualitySettings; ++i) { + try { + auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), max_out_resolution, (EDlssQuality)i); + + // Only emplace the specs if the feature can be created in practice! + DlssFeature{m_vk_and_ngx, specs, true, true}; + DlssFeature{m_vk_and_ngx, specs, true, false}; + DlssFeature{m_vk_and_ngx, specs, false, true}; + DlssFeature{m_vk_and_ngx, specs, false, false}; + m_dlss_specs.emplace_back(specs); + } catch (...) {} + } + + // For super insane performance requirements (more than 3x upscaling) try UltraPerformance + // with reduced output resolutions for 4.5x, 6x, 9x. + std::vector reduced_out_resolutions = { + max_out_resolution / 3 * 2, + max_out_resolution / 2, + max_out_resolution / 3, + // max_out_resolution / 4, + }; + + for (const auto& out_resolution : reduced_out_resolutions) { + try { + auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), out_resolution, EDlssQuality::UltraPerformance); + + // Only emplace the specs if the feature can be created in practice! + DlssFeature{m_vk_and_ngx, specs, true, true}; + DlssFeature{m_vk_and_ngx, specs, true, false}; + DlssFeature{m_vk_and_ngx, specs, false, true}; + DlssFeature{m_vk_and_ngx, specs, false, false}; + m_dlss_specs.emplace_back(specs); + } catch (...) {} + } + } + + virtual ~Dlss() { + // Destroy DLSS feature prior to killing underlying buffers. 
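+ // (C++ destroys members in reverse declaration order, so the VulkanTexture buffers, which are
+ // declared after m_dlss_feature, would be torn down before the feature if we relied on default
+ // member destruction; the explicit reset below enforces the required order.)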
+ m_dlss_feature = nullptr; + } + + void update_feature(const ivec2& in_resolution, bool is_hdr, bool sharpen) override { + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + DlssFeatureSpecs specs; + bool found = false; + for (const auto& s : m_dlss_specs) { + if (s.distance(in_resolution) == 0.0f) { + specs = s; + found = true; + } + } + + if (!found) { + throw std::runtime_error{"Dlss::run called with invalid input resolution."}; + } + + if (!m_dlss_feature || m_dlss_feature->is_hdr() != is_hdr || m_dlss_feature->sharpen() != sharpen || m_dlss_feature->quality() != specs.quality || m_dlss_feature->out_resolution() != specs.out_resolution) { + m_dlss_feature.reset(new DlssFeature{m_vk_and_ngx, specs.out_resolution, is_hdr, sharpen, specs.quality}); + } + } + + void run( + const ivec2& in_resolution, + bool is_hdr, + float sharpening, + const vec2& jitter_offset, + bool shall_reset + ) override { + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + update_feature(in_resolution, is_hdr, sharpening != 0.0f); + + m_dlss_feature->run( + in_resolution, + jitter_offset, + sharpening, + shall_reset, + m_frame_buffer.ngx_resource(), + m_depth_buffer.ngx_resource(), + m_mvec_buffer.ngx_resource(), + m_exposure_buffer.ngx_resource(), + m_output_buffer.ngx_resource() + ); + } + + cudaSurfaceObject_t frame() override { + return m_frame_buffer.surface(); + } + + cudaSurfaceObject_t depth() override { + return m_depth_buffer.surface(); + } + + cudaSurfaceObject_t mvec() override { + return m_mvec_buffer.surface(); + } + + cudaSurfaceObject_t exposure() override { + return m_exposure_buffer.surface(); + } + + cudaSurfaceObject_t output() override { + return m_output_buffer.surface(); + } + + ivec2 clamp_resolution(const ivec2& resolution) const { + float min_distance = std::numeric_limits::infinity(); + DlssFeatureSpecs min_distance_specs = {}; + for (const auto& specs : m_dlss_specs) { + float distance = specs.distance(resolution); + if (distance <= min_distance) { + min_distance = distance; + min_distance_specs = specs; + } + } + + return min_distance_specs.clamp_resolution(resolution); + } + + ivec2 out_resolution() const override { + return m_dlss_feature ? m_dlss_feature->out_resolution() : m_max_out_resolution; + } + + ivec2 max_out_resolution() const override { + return m_max_out_resolution; + } + + bool is_hdr() const override { + return m_dlss_feature && m_dlss_feature->is_hdr(); + } + + bool sharpen() const override { + return m_dlss_feature && m_dlss_feature->sharpen(); + } + + EDlssQuality quality() const override { + return m_dlss_feature ? m_dlss_feature->quality() : EDlssQuality::None; + } + +private: + std::shared_ptr m_vk_and_ngx; + + std::unique_ptr m_dlss_feature; + std::vector m_dlss_specs; + + VulkanTexture m_frame_buffer; + VulkanTexture m_depth_buffer; + VulkanTexture m_mvec_buffer; + VulkanTexture m_exposure_buffer; + VulkanTexture m_output_buffer; + + ivec2 m_max_out_resolution; +}; + +std::unique_ptr VulkanAndNgx::init_dlss(const ivec2& out_resolution) { + return std::make_unique(shared_from_this(), out_resolution); +} + +} diff --git a/gui/src/main.cu b/gui/src/main.cu new file mode 100644 index 0000000000000000000000000000000000000000..caaa2cf9d7693af51992f9d42c1e5bef044bf3fb --- /dev/null +++ b/gui/src/main.cu @@ -0,0 +1,163 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file main.cu + * @author Thomas Müller, NVIDIA + */ + +#include + +#include + +#include + +#include + +using namespace args; +using namespace ngp; +using namespace std; + +namespace ngp { + +int main_func(const std::vector& arguments) { + ArgumentParser parser{ + "Gen3C GUI" + "Version " NGP_VERSION, + "", + }; + + HelpFlag help_flag{ + parser, + "HELP", + "Display this help menu.", + {'h', "help"}, + }; + + Flag vr_flag{parser, "VR", "Enables VR", {"vr"}}; + + ValueFlag snapshot_flag{ + parser, + "SNAPSHOT", + "Optional snapshot to load upon startup.", + {"snapshot", "load_snapshot"}, + }; + + ValueFlag width_flag{ + parser, + "WIDTH", + "Resolution width of the GUI.", + {"width"}, + }; + + ValueFlag height_flag{ + parser, + "HEIGHT", + "Resolution height of the GUI.", + {"height"}, + }; + + Flag version_flag{ + parser, + "VERSION", + "Display the version of Gen3C GUI.", + {'v', "version"}, + }; + + PositionalList files{ + parser, + "files", + "Files to be loaded. Can be a scene, network config, snapshot, camera path, or a combination of those.", + }; + + // Parse command line arguments and react to parsing + // errors using exceptions. + try { + if (arguments.empty()) { + tlog::error() << "Number of arguments must be bigger than 0."; + return -3; + } + + parser.Prog(arguments.front()); + parser.ParseArgs(begin(arguments) + 1, end(arguments)); + } catch (const Help&) { + cout << parser; + return 0; + } catch (const ParseError& e) { + cerr << e.what() << endl; + cerr << parser; + return -1; + } catch (const ValidationError& e) { + cerr << e.what() << endl; + cerr << parser; + return -2; + } + + if (version_flag) { + tlog::none() << "Gen3C GUI v" NGP_VERSION; + return 0; + } + + Testbed testbed{ETestbedMode::Gen3c}; + + for (auto file : get(files)) { + testbed.load_file(file); + } + +#ifdef NGP_GUI + bool gui = true; +#else + bool gui = false; +#endif + + if (gui) { + testbed.init_window(width_flag ? get(width_flag) : 1920, height_flag ? get(height_flag) : 1080); + } + + if (vr_flag) { + testbed.init_vr(); + } + + // Render loop + while (testbed.frame()) {} + + return 0; +} + +} // namespace ngp + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { + SetConsoleOutputCP(CP_UTF8); +#else +int main(int argc, char* argv[]) { +#endif + try { + std::vector arguments; + for (int i = 0; i < argc; ++i) { +#ifdef _WIN32 + arguments.emplace_back(ngp::utf16_to_utf8(argv[i])); +#else + arguments.emplace_back(argv[i]); +#endif + } + + return ngp::main_func(arguments); + } catch (const exception& e) { + tlog::error() << fmt::format("Uncaught exception: {}", e.what()); + return 1; + } +} diff --git a/gui/src/openxr_hmd.cu b/gui/src/openxr_hmd.cu new file mode 100644 index 0000000000000000000000000000000000000000..f33d304d94fe5deddeca0e8d052c4e51cd523765 --- /dev/null +++ b/gui/src/openxr_hmd.cu @@ -0,0 +1,1264 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file openxr_hmd.cu + * @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA + * @brief Wrapper around the OpenXR API, providing access to + * per-eye framebuffers, lens parameters, visible area, + * view, hand, and eye poses, as well as controller inputs. + */ + +#include +#include +#include +#include + +#include + +#include + +#include + +#include + +#include +#include + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" //TODO: XR struct are uninitiaized apart from their type +#endif + +namespace ngp { + +// function XrEnumStr turns enum into string for printing +// uses expansion macro and data provided in openxr_reflection.h +#define XR_ENUM_CASE_STR(name, val) \ + case name: \ + return #name; +#define XR_ENUM_STR(enum_type) \ + constexpr const char* XrEnumStr(enum_type e) { \ + switch (e) { \ + XR_LIST_ENUM_##enum_type(XR_ENUM_CASE_STR) default : return "Unknown"; \ + } \ + } + +XR_ENUM_STR(XrViewConfigurationType) +XR_ENUM_STR(XrEnvironmentBlendMode) +XR_ENUM_STR(XrReferenceSpaceType) +XR_ENUM_STR(XrStructureType) +XR_ENUM_STR(XrSessionState) + +/// Checks the result of a xrXXXXXX call and throws an error on failure +#define XR_CHECK_THROW(x) \ + do { \ + XrResult result = x; \ + if (XR_FAILED(result)) { \ + char buffer[XR_MAX_RESULT_STRING_SIZE]; \ + XrResult result_to_string_result = xrResultToString(m_instance, result, buffer); \ + if (XR_FAILED(result_to_string_result)) { \ + throw std::runtime_error{std::string(FILE_LINE " " #x " failed, but could not obtain error string")}; \ + } else { \ + throw std::runtime_error{std::string(FILE_LINE " " #x " failed with error ") + buffer}; \ + } \ + } \ + } while(0) + +OpenXRHMD::Swapchain::Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& m_instance) { + ScopeGuard cleanup_guard{[&]() { clear(); }}; + + XR_CHECK_THROW(xrCreateSwapchain(session, &rgba_create_info, &handle)); + + width = rgba_create_info.width; + height = rgba_create_info.height; + + { + uint32_t size; + XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, 0, &size, nullptr)); + + images_gl.resize(size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR}); + XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, size, &size, (XrSwapchainImageBaseHeader*)images_gl.data())); + + // One framebuffer per swapchain image + framebuffers_gl.resize(size); + } + + if (depth_create_info.format != 0) { + XR_CHECK_THROW(xrCreateSwapchain(session, &depth_create_info, &depth_handle)); + + uint32_t depth_size; + XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, 0, &depth_size, nullptr)); + + depth_images_gl.resize(depth_size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR}); + XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, depth_size, &depth_size, (XrSwapchainImageBaseHeader*)depth_images_gl.data())); + + // We might have a different number of 
depth swapchain images as we have framebuffers, + // so we will need to bind an acquired depth image to the current framebuffer on the + // fly later on. + } + + glGenFramebuffers(framebuffers_gl.size(), framebuffers_gl.data()); + + cleanup_guard.disarm(); +} + +OpenXRHMD::Swapchain::~Swapchain() { + clear(); +} + +void OpenXRHMD::Swapchain::clear() { + if (!framebuffers_gl.empty()) { + glDeleteFramebuffers(framebuffers_gl.size(), framebuffers_gl.data()); + } + + if (depth_handle != XR_NULL_HANDLE) { + xrDestroySwapchain(depth_handle); + depth_handle = XR_NULL_HANDLE; + } + + if (handle != XR_NULL_HANDLE) { + xrDestroySwapchain(handle); + handle = XR_NULL_HANDLE; + } +} + +#if defined(XR_USE_PLATFORM_WIN32) +OpenXRHMD::OpenXRHMD(HDC hdc, HGLRC hglrc) { +#elif defined(XR_USE_PLATFORM_XLIB) +OpenXRHMD::OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) { +#elif defined(XR_USE_PLATFORM_WAYLAND) +OpenXRHMD::OpenXRHMD(wl_display* display) { +#endif + ScopeGuard cleanup_guard{[&]() { clear(); }}; + + init_create_xr_instance(); + init_get_xr_system(); + init_configure_xr_views(); + init_check_for_xr_blend_mode(); +#if defined(XR_USE_PLATFORM_WIN32) + init_open_gl(hdc, hglrc); +#elif defined(XR_USE_PLATFORM_XLIB) + init_open_gl(xDisplay, visualid, glxFBConfig, glxDrawable, glxContext); +#elif defined(XR_USE_PLATFORM_WAYLAND) + init_open_gl(display); +#endif + init_xr_session(); + init_xr_actions(); + init_xr_spaces(); + init_xr_swapchain_open_gl(); + init_open_gl_shaders(); + + cleanup_guard.disarm(); + tlog::success() << "Initialized OpenXR for " << m_system_properties.systemName; + // tlog::success() << " " + // << " depth=" << (m_supports_composition_layer_depth ? "true" : "false") + // << " mask=" << (m_supports_hidden_area_mask ? "true" : "false") + // << " eye=" << (m_supports_eye_tracking ? 
"true" : "false") + // ; +} + +OpenXRHMD::~OpenXRHMD() { + clear(); +} + +void OpenXRHMD::clear() { + auto xr_destroy = [&](auto& handle, auto destroy_fun) { + if (handle != XR_NULL_HANDLE) { + destroy_fun(handle); + handle = XR_NULL_HANDLE; + } + }; + + xr_destroy(m_pose_action, xrDestroyAction); + xr_destroy(m_thumbstick_actions[0], xrDestroyAction); + xr_destroy(m_thumbstick_actions[1], xrDestroyAction); + xr_destroy(m_press_action, xrDestroyAction); + xr_destroy(m_grab_action, xrDestroyAction); + + xr_destroy(m_action_set, xrDestroyActionSet); + + m_swapchains.clear(); + xr_destroy(m_space, xrDestroySpace); + xr_destroy(m_session, xrDestroySession); + xr_destroy(m_instance, xrDestroyInstance); +} + +void OpenXRHMD::init_create_xr_instance() { + std::vector layers = {}; + std::vector extensions = { + XR_KHR_OPENGL_ENABLE_EXTENSION_NAME, + }; + + auto print_extension_properties = [](const char* layer_name) { + uint32_t size; + xrEnumerateInstanceExtensionProperties(layer_name, 0, &size, nullptr); + std::vector props(size, {XR_TYPE_EXTENSION_PROPERTIES}); + xrEnumerateInstanceExtensionProperties(layer_name, size, &size, props.data()); + tlog::info() << fmt::format("Extensions ({}):", props.size()); + for (XrExtensionProperties extension : props) { + tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion); + } + }; + + uint32_t size; + xrEnumerateApiLayerProperties(0, &size, nullptr); + m_api_layer_properties.clear(); + m_api_layer_properties.resize(size, {XR_TYPE_API_LAYER_PROPERTIES}); + xrEnumerateApiLayerProperties(size, &size, m_api_layer_properties.data()); + + if (m_print_api_layers) { + tlog::info() << fmt::format("API Layers ({}):", m_api_layer_properties.size()); + for (auto p : m_api_layer_properties) { + tlog::info() << fmt::format( + "{} (v {}.{}.{}, {}) {}", + p.layerName, + XR_VERSION_MAJOR(p.specVersion), + XR_VERSION_MINOR(p.specVersion), + XR_VERSION_PATCH(p.specVersion), + p.layerVersion, + p.description + ); + print_extension_properties(p.layerName); + } + } + + if (layers.size() != 0) { + for (const auto& e : layers) { + bool found = false; + for (XrApiLayerProperties layer : m_api_layer_properties) { + if (strcmp(e, layer.layerName) == 0) { + found = true; + break; + } + } + + if (!found) { + throw std::runtime_error{fmt::format("OpenXR API layer {} not found", e)}; + } + } + } + + xrEnumerateInstanceExtensionProperties(nullptr, 0, &size, nullptr); + m_instance_extension_properties.clear(); + m_instance_extension_properties.resize(size, {XR_TYPE_EXTENSION_PROPERTIES}); + xrEnumerateInstanceExtensionProperties(nullptr, size, &size, m_instance_extension_properties.data()); + + if (m_print_extensions) { + tlog::info() << fmt::format("Instance extensions ({}):", m_instance_extension_properties.size()); + for (XrExtensionProperties extension : m_instance_extension_properties) { + tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion); + } + } + + auto has_extension = [&](const char* e) { + for (XrExtensionProperties extension : m_instance_extension_properties) { + if (strcmp(e, extension.extensionName) == 0) { + return true; + } + } + + return false; + }; + + for (const auto& e : extensions) { + if (!has_extension(e)) { + throw std::runtime_error{fmt::format("Required OpenXR extension {} not found", e)}; + } + } + + auto add_extension_if_supported = [&](const char* extension) { + if (has_extension(extension)) { + extensions.emplace_back(extension); + return true; + } + + return 
false; + }; + + if (add_extension_if_supported(XR_KHR_COMPOSITION_LAYER_DEPTH_EXTENSION_NAME)) { + m_supports_composition_layer_depth = true; + } + + if (add_extension_if_supported(XR_KHR_VISIBILITY_MASK_EXTENSION_NAME)) { + m_supports_hidden_area_mask = true; + } + + if (add_extension_if_supported(XR_EXT_EYE_GAZE_INTERACTION_EXTENSION_NAME)) { + m_supports_eye_tracking = true; + } + + XrInstanceCreateInfo instance_create_info = {XR_TYPE_INSTANCE_CREATE_INFO}; + instance_create_info.applicationInfo = {}; + strncpy(instance_create_info.applicationInfo.applicationName, "Gen3C GUI v" NGP_VERSION, XR_MAX_APPLICATION_NAME_SIZE); + instance_create_info.applicationInfo.applicationVersion = 1; + strncpy(instance_create_info.applicationInfo.engineName, "Gen3C GUI v" NGP_VERSION, XR_MAX_ENGINE_NAME_SIZE); + instance_create_info.applicationInfo.engineVersion = 1; + instance_create_info.applicationInfo.apiVersion = XR_CURRENT_API_VERSION; + instance_create_info.enabledExtensionCount = (uint32_t)extensions.size(); + instance_create_info.enabledExtensionNames = extensions.data(); + instance_create_info.enabledApiLayerCount = (uint32_t)layers.size(); + instance_create_info.enabledApiLayerNames = layers.data(); + + if (XR_FAILED(xrCreateInstance(&instance_create_info, &m_instance))) { + throw std::runtime_error{"Failed to create OpenXR instance"}; + } + + XR_CHECK_THROW(xrGetInstanceProperties(m_instance, &m_instance_properties)); + if (m_print_instance_properties) { + tlog::info() << "Instance Properties"; + tlog::info() << fmt::format("\t runtime name: '{}'", m_instance_properties.runtimeName); + const auto& v = m_instance_properties.runtimeVersion; + tlog::info() << fmt::format( + "\t runtime version: {}.{}.{}", + XR_VERSION_MAJOR(v), + XR_VERSION_MINOR(v), + XR_VERSION_PATCH(v) + ); + } +} + +void OpenXRHMD::init_get_xr_system() { + XrSystemGetInfo system_get_info = {XR_TYPE_SYSTEM_GET_INFO, nullptr, XR_FORM_FACTOR_HEAD_MOUNTED_DISPLAY}; + XR_CHECK_THROW(xrGetSystem(m_instance, &system_get_info, &m_system_id)); + + XR_CHECK_THROW(xrGetSystemProperties(m_instance, m_system_id, &m_system_properties)); + if (m_print_system_properties) { + tlog::info() << "System Properties"; + tlog::info() << fmt::format("\t name: '{}'", m_system_properties.systemName); + tlog::info() << fmt::format("\t vendorId: {:#x}", m_system_properties.vendorId); + tlog::info() << fmt::format("\t systemId: {:#x}", m_system_properties.systemId); + tlog::info() << fmt::format("\t max layer count: {}", m_system_properties.graphicsProperties.maxLayerCount); + tlog::info() << fmt::format("\t max img width: {}", m_system_properties.graphicsProperties.maxSwapchainImageWidth); + tlog::info() << fmt::format("\t max img height: {}", m_system_properties.graphicsProperties.maxSwapchainImageHeight); + tlog::info() << fmt::format("\torientation tracking: {}", m_system_properties.trackingProperties.orientationTracking ? "YES" : "NO"); + tlog::info() << fmt::format("\t position tracking: {}", m_system_properties.trackingProperties.orientationTracking ? 
"YES" : "NO"); + } +} + +void OpenXRHMD::init_configure_xr_views() { + uint32_t size; + XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, 0, &size, nullptr)); + std::vector view_config_types(size); + XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, size, &size, view_config_types.data())); + + if (m_print_view_configuration_types) { + tlog::info() << fmt::format("View Configuration Types ({}):", view_config_types.size()); + for (const auto& t : view_config_types) { + tlog::info() << fmt::format("\t{}", XrEnumStr(t)); + } + } + + // view configurations we support, in descending preference + const std::vector preferred_view_config_types = { + //XR_VIEW_CONFIGURATION_TYPE_PRIMARY_QUAD_VARJO, + XR_VIEW_CONFIGURATION_TYPE_PRIMARY_STEREO + }; + + bool found = false; + for (const auto& p : preferred_view_config_types) { + for (const auto& t : view_config_types) { + if (p == t) { + found = true; + m_view_configuration_type = t; + } + } + } + + if (!found) { + throw std::runtime_error{"Could not find a suitable OpenXR view configuration type"}; + } + + // get view configuration properties + XR_CHECK_THROW(xrGetViewConfigurationProperties(m_instance, m_system_id, m_view_configuration_type, &m_view_configuration_properties)); + if (m_print_view_configuration_properties) { + tlog::info() << "View Configuration Properties:"; + tlog::info() << fmt::format("\t Type: {}", XrEnumStr(m_view_configuration_type)); + tlog::info() << fmt::format("\t FOV Mutable: {}", m_view_configuration_properties.fovMutable ? "YES" : "NO"); + } + + // enumerate view configuration views + XR_CHECK_THROW(xrEnumerateViewConfigurationViews(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr)); + m_view_configuration_views.clear(); + m_view_configuration_views.resize(size, {XR_TYPE_VIEW_CONFIGURATION_VIEW}); + XR_CHECK_THROW(xrEnumerateViewConfigurationViews( + m_instance, + m_system_id, + m_view_configuration_type, + size, + &size, + m_view_configuration_views.data() + )); + + if (m_print_view_configuration_view) { + tlog::info() << "View Configuration Views, Width x Height x Samples"; + for (size_t i = 0; i < m_view_configuration_views.size(); ++i) { + const auto& view = m_view_configuration_views[i]; + tlog::info() << fmt::format( + "\tView {}\tRecommended: {}x{}x{} Max: {}x{}x{}", + i, + view.recommendedImageRectWidth, + view.recommendedImageRectHeight, + view.recommendedSwapchainSampleCount, + view.maxImageRectWidth, + view.maxImageRectHeight, + view.maxSwapchainSampleCount + ); + } + } +} + +void OpenXRHMD::init_check_for_xr_blend_mode() { + // enumerate environment blend modes + uint32_t size; + XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr)); + std::vector supported_blend_modes(size); + XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes( + m_instance, + m_system_id, + m_view_configuration_type, + size, + &size, + supported_blend_modes.data() + )); + + if (supported_blend_modes.empty()) { + throw std::runtime_error{"No OpenXR environment blend modes found"}; + } + + std::sort(std::begin(supported_blend_modes), std::end(supported_blend_modes)); + if (m_print_environment_blend_modes) { + tlog::info() << fmt::format("Environment Blend Modes ({}):", supported_blend_modes.size()); + } + + m_supported_environment_blend_modes.resize(supported_blend_modes.size()); + m_supported_environment_blend_modes_imgui_string.clear(); + for (size_t i = 0; i < supported_blend_modes.size(); ++i) { + if 
(m_print_environment_blend_modes) { + tlog::info() << fmt::format("\t{}", XrEnumStr(supported_blend_modes[i])); + } + + auto b = (EEnvironmentBlendMode)supported_blend_modes[i]; + m_supported_environment_blend_modes[i] = b; + + auto b_str = to_string(b); + std::copy(std::begin(b_str), std::end(b_str), std::back_inserter(m_supported_environment_blend_modes_imgui_string)); + m_supported_environment_blend_modes_imgui_string.emplace_back('\0'); + } + + m_supported_environment_blend_modes_imgui_string.emplace_back('\0'); + m_environment_blend_mode = m_supported_environment_blend_modes.front(); +} + +void OpenXRHMD::init_xr_actions() { + // paths for left (0) and right (1) hands + XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/left", &m_hand_paths[0])); + XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/right", &m_hand_paths[1])); + + // create action set + XrActionSetCreateInfo action_set_create_info{XR_TYPE_ACTION_SET_CREATE_INFO, nullptr, "actionset", "actionset", 0}; + XR_CHECK_THROW(xrCreateActionSet(m_instance, &action_set_create_info, &m_action_set)); + + { + XrActionCreateInfo action_create_info{ + XR_TYPE_ACTION_CREATE_INFO, + nullptr, + "hand_pose", + XR_ACTION_TYPE_POSE_INPUT, + (uint32_t)m_hand_paths.size(), + m_hand_paths.data(), + "Hand pose" + }; + XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_pose_action)); + } + + { + XrActionCreateInfo action_create_info{ + XR_TYPE_ACTION_CREATE_INFO, + nullptr, + "thumbstick_left", + XR_ACTION_TYPE_VECTOR2F_INPUT, + 0, + nullptr, + "Left thumbstick" + }; + XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[0])); + } + + { + XrActionCreateInfo action_create_info{ + XR_TYPE_ACTION_CREATE_INFO, + nullptr, + "thumbstick_right", + XR_ACTION_TYPE_VECTOR2F_INPUT, + 0, + nullptr, + "Right thumbstick" + }; + XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[1])); + } + + { + XrActionCreateInfo action_create_info{ + XR_TYPE_ACTION_CREATE_INFO, + nullptr, + "press", + XR_ACTION_TYPE_BOOLEAN_INPUT, + (uint32_t)m_hand_paths.size(), + m_hand_paths.data(), + "Press" + }; + XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_press_action)); + } + + { + XrActionCreateInfo action_create_info{ + XR_TYPE_ACTION_CREATE_INFO, + nullptr, + "grab", + XR_ACTION_TYPE_FLOAT_INPUT, + (uint32_t)m_hand_paths.size(), + m_hand_paths.data(), + "Grab" + }; + XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_grab_action)); + } + + auto create_binding = [&](XrAction action, const std::string& binding_path_str) { + XrPath binding; + XR_CHECK_THROW(xrStringToPath(m_instance, binding_path_str.c_str(), &binding)); + return XrActionSuggestedBinding{action, binding}; + }; + + auto suggest_bindings = [&](const std::string& interaction_profile_path_str, const std::vector& bindings) { + XrPath interaction_profile; + XR_CHECK_THROW(xrStringToPath(m_instance, interaction_profile_path_str.c_str(), &interaction_profile)); + XrInteractionProfileSuggestedBinding suggested_binding{ + XR_TYPE_INTERACTION_PROFILE_SUGGESTED_BINDING, + nullptr, + interaction_profile, + (uint32_t)bindings.size(), + bindings.data() + }; + XR_CHECK_THROW(xrSuggestInteractionProfileBindings(m_instance, &suggested_binding)); + }; + + suggest_bindings("/interaction_profiles/khr/simple_controller", { + create_binding(m_pose_action, "/user/hand/left/input/grip/pose"), + create_binding(m_pose_action, "/user/hand/right/input/grip/pose"), + }); + + auto suggest_controller_bindings = 
[&](const std::string& xy, const std::string& press, const std::string& grab, const std::string& squeeze, const std::string& interaction_profile_path_str) { + std::vector bindings = { + create_binding(m_pose_action, "/user/hand/left/input/grip/pose"), + create_binding(m_pose_action, "/user/hand/right/input/grip/pose"), + create_binding(m_thumbstick_actions[0], std::string{"/user/hand/left/input/"} + xy), + create_binding(m_thumbstick_actions[1], std::string{"/user/hand/right/input/"} + xy), + create_binding(m_press_action, std::string{"/user/hand/left/input/"} + press), + create_binding(m_press_action, std::string{"/user/hand/right/input/"} + press), + create_binding(m_grab_action, std::string{"/user/hand/left/input/"} + grab), + create_binding(m_grab_action, std::string{"/user/hand/right/input/"} + grab), + }; + + if (!squeeze.empty()) { + bindings.emplace_back(create_binding(m_grab_action, std::string{"/user/hand/left/input/"} + squeeze)); + bindings.emplace_back(create_binding(m_grab_action, std::string{"/user/hand/right/input/"} + squeeze)); + } + + suggest_bindings(interaction_profile_path_str, bindings); + }; + + suggest_controller_bindings("trackpad", "select/click", "trackpad/click", "", "/interaction_profiles/google/daydream_controller"); + suggest_controller_bindings("trackpad", "trackpad/click", "trigger/click", "squeeze/click", "/interaction_profiles/htc/vive_controller"); + suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value", "squeeze/click", "/interaction_profiles/microsoft/motion_controller"); + suggest_controller_bindings("trackpad", "trackpad/click", "trigger/click", "", "/interaction_profiles/oculus/go_controller"); + suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value", "squeeze/value", "/interaction_profiles/oculus/touch_controller"); + + // Valve Index force squeeze is very sensitive and can cause unwanted grabbing. Only permit trigger-grabbing for now. 
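+	// Passing an empty string for the squeeze path below skips the squeeze binding entirely (the commented-out path would be "squeeze/force").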
+ suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value", ""/*squeeze/force*/, "/interaction_profiles/valve/index_controller"); + + // Xbox controller (currently not functional) + suggest_bindings("/interaction_profiles/microsoft/xbox_controller", { + create_binding(m_thumbstick_actions[0], std::string{"/user/gamepad/input/thumbstick_left"}), + create_binding(m_thumbstick_actions[1], std::string{"/user/gamepad/input/thumbstick_right"}), + }); +} + +#if defined(XR_USE_PLATFORM_WIN32) +void OpenXRHMD::init_open_gl(HDC hdc, HGLRC hglrc) { +#elif defined(XR_USE_PLATFORM_XLIB) +void OpenXRHMD::init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) { +#elif defined(XR_USE_PLATFORM_WAYLAND) +void OpenXRHMD::init_open_gl(wl_display* display) { +#endif + // GL graphics requirements + PFN_xrGetOpenGLGraphicsRequirementsKHR xrGetOpenGLGraphicsRequirementsKHR = nullptr; + XR_CHECK_THROW(xrGetInstanceProcAddr( + m_instance, + "xrGetOpenGLGraphicsRequirementsKHR", + reinterpret_cast(&xrGetOpenGLGraphicsRequirementsKHR) + )); + + XrGraphicsRequirementsOpenGLKHR graphics_requirements{XR_TYPE_GRAPHICS_REQUIREMENTS_OPENGL_KHR}; + xrGetOpenGLGraphicsRequirementsKHR(m_instance, m_system_id, &graphics_requirements); + XrVersion min_version = graphics_requirements.minApiVersionSupported; + GLint major = 0; + GLint minor = 0; + glGetIntegerv(GL_MAJOR_VERSION, &major); + glGetIntegerv(GL_MINOR_VERSION, &minor); + const XrVersion have_version = XR_MAKE_VERSION(major, minor, 0); + + if (have_version < min_version) { + tlog::info() << fmt::format( + "Required OpenGL version: {}.{}, found OpenGL version: {}.{}", + XR_VERSION_MAJOR(min_version), + XR_VERSION_MINOR(min_version), + major, + minor + ); + + throw std::runtime_error{"Insufficient graphics API support"}; + } + +#if defined(XR_USE_PLATFORM_WIN32) + m_graphics_binding.hDC = hdc; + m_graphics_binding.hGLRC = hglrc; +#elif defined(XR_USE_PLATFORM_XLIB) + m_graphics_binding.xDisplay = xDisplay; + m_graphics_binding.visualid = visualid; + m_graphics_binding.glxFBConfig = glxFBConfig; + m_graphics_binding.glxDrawable = glxDrawable; + m_graphics_binding.glxContext = glxContext; +#elif defined(XR_USE_PLATFORM_WAYLAND) + m_graphics_binding.display = display; +#endif +} + +void OpenXRHMD::init_xr_session() { + // create session + XrSessionCreateInfo create_info{ + XR_TYPE_SESSION_CREATE_INFO, + reinterpret_cast(&m_graphics_binding), + 0, + m_system_id + }; + + XR_CHECK_THROW(xrCreateSession(m_instance, &create_info, &m_session)); + + // tlog::info() << fmt::format("Created session {}", fmt::ptr(m_session)); +} + +void OpenXRHMD::init_xr_spaces() { + // reference space + uint32_t size; + XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, 0, &size, nullptr)); + m_reference_spaces.clear(); + m_reference_spaces.resize(size); + XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, size, &size, m_reference_spaces.data())); + + if (m_print_reference_spaces) { + tlog::info() << fmt::format("Reference spaces ({}):", m_reference_spaces.size()); + for (const auto& r : m_reference_spaces) { + tlog::info() << fmt::format("\t{}", XrEnumStr(r)); + } + } + + XrReferenceSpaceCreateInfo reference_space_create_info{XR_TYPE_REFERENCE_SPACE_CREATE_INFO}; + reference_space_create_info.referenceSpaceType = XR_REFERENCE_SPACE_TYPE_LOCAL; + reference_space_create_info.poseInReferenceSpace = XrPosef{}; + reference_space_create_info.poseInReferenceSpace.orientation.w = 1.0f; + 
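+	// An identity pose in XR_REFERENCE_SPACE_TYPE_LOCAL keeps the application's world origin aligned with the runtime's local-space origin.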
XR_CHECK_THROW(xrCreateReferenceSpace(m_session, &reference_space_create_info, &m_space)); + XR_CHECK_THROW(xrGetReferenceSpaceBoundsRect(m_session, reference_space_create_info.referenceSpaceType, &m_bounds)); + + if (m_print_reference_spaces) { + tlog::info() << fmt::format("Using reference space {}", XrEnumStr(reference_space_create_info.referenceSpaceType)); + tlog::info() << fmt::format("Reference space boundaries: {} x {}", m_bounds.width, m_bounds.height); + } + + // action space + XrActionSpaceCreateInfo action_space_create_info{XR_TYPE_ACTION_SPACE_CREATE_INFO}; + action_space_create_info.action = m_pose_action; + action_space_create_info.poseInActionSpace.orientation.w = 1.0f; + action_space_create_info.subactionPath = m_hand_paths[0]; + XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[0])); + action_space_create_info.subactionPath = m_hand_paths[1]; + XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[1])); + + // attach action set + XrSessionActionSetsAttachInfo attach_info{XR_TYPE_SESSION_ACTION_SETS_ATTACH_INFO}; + attach_info.countActionSets = 1; + attach_info.actionSets = &m_action_set; + XR_CHECK_THROW(xrAttachSessionActionSets(m_session, &attach_info)); +} + +void OpenXRHMD::init_xr_swapchain_open_gl() { + // swap chains + uint32_t size; + XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, 0, &size, nullptr)); + std::vector swapchain_formats(size); + XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, size, &size, swapchain_formats.data())); + + if (m_print_available_swapchain_formats) { + tlog::info() << fmt::format("Swapchain formats ({}):", swapchain_formats.size()); + for (const auto& f : swapchain_formats) { + tlog::info() << fmt::format("\t{:#x}", f); + } + } + + auto find_compatible_swapchain_format = [&](const std::vector& candidates) { + for (auto format : candidates) { + if (std::find(std::begin(swapchain_formats), std::end(swapchain_formats), format) != std::end(swapchain_formats)) { + return format; + } + } + + throw std::runtime_error{"No compatible OpenXR swapchain format found"}; + }; + + m_swapchain_rgba_format = find_compatible_swapchain_format({ + GL_SRGB8_ALPHA8, + GL_SRGB8, + GL_RGBA8, + }); + + if (m_supports_composition_layer_depth) { + m_swapchain_depth_format = find_compatible_swapchain_format({ + GL_DEPTH_COMPONENT32F, + GL_DEPTH_COMPONENT24, + GL_DEPTH_COMPONENT16, + }); + } + + // tlog::info() << fmt::format("Chosen swapchain format: {:#x}", m_swapchain_rgba_format); + for (const auto& vcv : m_view_configuration_views) { + XrSwapchainCreateInfo rgba_swapchain_create_info{XR_TYPE_SWAPCHAIN_CREATE_INFO}; + rgba_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_COLOR_ATTACHMENT_BIT; + rgba_swapchain_create_info.format = m_swapchain_rgba_format; + rgba_swapchain_create_info.sampleCount = 1; + rgba_swapchain_create_info.width = vcv.recommendedImageRectWidth; + rgba_swapchain_create_info.height = vcv.recommendedImageRectHeight; + rgba_swapchain_create_info.faceCount = 1; + rgba_swapchain_create_info.arraySize = 1; + rgba_swapchain_create_info.mipCount = 1; + + XrSwapchainCreateInfo depth_swapchain_create_info = rgba_swapchain_create_info; + depth_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT; + depth_swapchain_create_info.format = m_swapchain_depth_format; + + m_swapchains.emplace_back(rgba_swapchain_create_info, depth_swapchain_create_info, m_session, m_instance); + 
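+		// Each Swapchain entry bundles the color and (optional) depth XrSwapchain handles, plus their GL framebuffers, for one view.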
} +} + +void OpenXRHMD::init_open_gl_shaders() { + // Hidden area mask program + { + static const char* shader_vert = R"(#version 140 + in vec2 pos; + uniform mat4 project; + void main() { + vec4 pos = project * vec4(pos, -1.0, 1.0); + pos.xyz /= pos.w; + pos.y *= -1.0; + gl_Position = pos; + })"; + + static const char* shader_frag = R"(#version 140 + out vec4 frag_color; + void main() { + frag_color = vec4(0.0, 0.0, 0.0, 1.0); + })"; + + GLuint vert = glCreateShader(GL_VERTEX_SHADER); + glShaderSource(vert, 1, &shader_vert, NULL); + glCompileShader(vert); + check_shader(vert, "OpenXR hidden area mask vertex shader", false); + + GLuint frag = glCreateShader(GL_FRAGMENT_SHADER); + glShaderSource(frag, 1, &shader_frag, NULL); + glCompileShader(frag); + check_shader(frag, "OpenXR hidden area mask fragment shader", false); + + m_hidden_area_mask_program = glCreateProgram(); + glAttachShader(m_hidden_area_mask_program, vert); + glAttachShader(m_hidden_area_mask_program, frag); + glLinkProgram(m_hidden_area_mask_program); + check_shader(m_hidden_area_mask_program, "OpenXR hidden area mask shader program", true); + + glDeleteShader(vert); + glDeleteShader(frag); + } +} + +void OpenXRHMD::session_state_change(XrSessionState state, EControlFlow& flow) { + //tlog::info() << fmt::format("New session state {}", XrEnumStr(state)); + switch (state) { + case XR_SESSION_STATE_READY: { + XrSessionBeginInfo sessionBeginInfo {XR_TYPE_SESSION_BEGIN_INFO}; + sessionBeginInfo.primaryViewConfigurationType = m_view_configuration_type; + XR_CHECK_THROW(xrBeginSession(m_session, &sessionBeginInfo)); + break; + } + case XR_SESSION_STATE_STOPPING: { + XR_CHECK_THROW(xrEndSession(m_session)); + break; + } + case XR_SESSION_STATE_EXITING: { + flow = EControlFlow::Quit; + break; + } + case XR_SESSION_STATE_LOSS_PENDING: { + flow = EControlFlow::Restart; + break; + } + default: { + break; + } + } +} + +OpenXRHMD::EControlFlow OpenXRHMD::poll_events() { + bool more = true; + EControlFlow flow = EControlFlow::Continue; + while (more) { + // poll events + XrEventDataBuffer event {XR_TYPE_EVENT_DATA_BUFFER, nullptr}; + XrResult result = xrPollEvent(m_instance, &event); + + if (XR_FAILED(result)) { + tlog::error() << "xrPollEvent failed"; + } else if (XR_SUCCESS == result) { + switch (event.type) { + case XR_TYPE_EVENT_DATA_SESSION_STATE_CHANGED: { + const XrEventDataSessionStateChanged& e = *reinterpret_cast(&event); + //tlog::info() << "Session state change"; + //tlog::info() << fmt::format("\t from {}\t to {}", XrEnumStr(m_session_state), XrEnumStr(e.state)); + //tlog::info() << fmt::format("\t session {}, time {}", fmt::ptr(e.session), e.time); + m_session_state = e.state; + session_state_change(e.state, flow); + break; + } + + case XR_TYPE_EVENT_DATA_INSTANCE_LOSS_PENDING: { + flow = EControlFlow::Restart; + break; + } + + case XR_TYPE_EVENT_DATA_VISIBILITY_MASK_CHANGED_KHR: { + m_hidden_area_masks.clear(); + break; + } + + case XR_TYPE_EVENT_DATA_INTERACTION_PROFILE_CHANGED: { + break; // Can ignore + } + + default: { + tlog::info() << fmt::format("Unhandled event type {}", XrEnumStr(event.type)); + break; + } + } + } else if (XR_EVENT_UNAVAILABLE == result) { + more = false; + } + } + return flow; +} + +__global__ void read_hidden_area_mask_kernel(const ivec2 resolution, cudaSurfaceObject_t surface, uint8_t* __restrict__ mask) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + uint32_t idx = x 
+ resolution.x * y; + surf2Dread(&mask[idx], surface, x, y); +} + +std::shared_ptr> OpenXRHMD::rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view) { + if (!m_supports_hidden_area_mask) { + return {}; + } + + PFN_xrGetVisibilityMaskKHR xrGetVisibilityMaskKHR = nullptr; + XR_CHECK_THROW(xrGetInstanceProcAddr( + m_instance, + "xrGetVisibilityMaskKHR", + reinterpret_cast(&xrGetVisibilityMaskKHR) + )); + + XrVisibilityMaskKHR visibility_mask{XR_TYPE_VISIBILITY_MASK_KHR}; + XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask)); + + if (visibility_mask.vertexCountOutput == 0 || visibility_mask.indexCountOutput == 0) { + return nullptr; + } + + std::vector vertices(visibility_mask.vertexCountOutput); + std::vector indices(visibility_mask.indexCountOutput); + + visibility_mask.vertices = vertices.data(); + visibility_mask.indices = indices.data(); + + visibility_mask.vertexCapacityInput = visibility_mask.vertexCountOutput; + visibility_mask.indexCapacityInput = visibility_mask.indexCountOutput; + + XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask)); + + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + ivec2 size = {view.subImage.imageRect.extent.width, view.subImage.imageRect.extent.height}; + + bool tex = glIsEnabled(GL_TEXTURE_2D); + bool depth = glIsEnabled(GL_DEPTH_TEST); + bool cull = glIsEnabled(GL_CULL_FACE); + GLint previous_texture_id; + glGetIntegerv(GL_TEXTURE_BINDING_2D, &previous_texture_id); + + if (!tex) glEnable(GL_TEXTURE_2D); + if (depth) glDisable(GL_DEPTH_TEST); + if (cull) glDisable(GL_CULL_FACE); + + // Generate texture to hold hidden area mask. 
Single channel, value of 1 means visible and 0 means masked away + ngp::GLTexture mask_texture; + mask_texture.resize(size, 1, true); + glBindTexture(GL_TEXTURE_2D, mask_texture.texture()); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + + GLuint framebuffer = 0; + glGenFramebuffers(1, &framebuffer); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer); + glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, mask_texture.texture(), 0); + + GLenum draw_buffers[1] = {GL_COLOR_ATTACHMENT0}; + glDrawBuffers(1, draw_buffers); + + glViewport(0, 0, size.x, size.y); + + // Draw hidden area mask + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + GLuint vertex_buffer; + glGenBuffers(1, &vertex_buffer); + glEnableVertexAttribArray(0); + glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer); + glBufferData(GL_ARRAY_BUFFER, sizeof(XrVector2f) * vertices.size(), vertices.data(), GL_STATIC_DRAW); + glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, (void*)0); + + GLuint index_buffer; + glGenBuffers(1, &index_buffer); + glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, index_buffer); + glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(uint32_t) * indices.size(), indices.data(), GL_STATIC_DRAW); + + glClearColor(1.0f, 1.0f, 1.0f, 1.0f); + glClear(GL_COLOR_BUFFER_BIT); + glUseProgram(m_hidden_area_mask_program); + + XrMatrix4x4f proj; + XrMatrix4x4f_CreateProjectionFov(&proj, GRAPHICS_OPENGL, view.fov, 1.0f / 128.0f, 128.0f); + + GLuint project_id = glGetUniformLocation(m_hidden_area_mask_program, "project"); + glUniformMatrix4fv(project_id, 1, GL_FALSE, &proj.m[0]); + + glDrawElements(GL_TRIANGLES, indices.size(), GL_UNSIGNED_INT, (void*)0); + glFinish(); + + glDisableVertexAttribArray(0); + glDeleteBuffers(1, &vertex_buffer); + glDeleteBuffers(1, &index_buffer); + glDeleteVertexArrays(1, &vao); + glDeleteFramebuffers(1, &framebuffer); + + glBindVertexArray(0); + glUseProgram(0); + + // restore old state + if (!tex) glDisable(GL_TEXTURE_2D); + if (depth) glEnable(GL_DEPTH_TEST); + if (cull) glEnable(GL_CULL_FACE); + glBindTexture(GL_TEXTURE_2D, previous_texture_id); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + + std::shared_ptr> mask = std::make_shared>(size); + + const dim3 threads = { 16, 8, 1 }; + const dim3 blocks = { div_round_up((uint32_t)size.x, threads.x), div_round_up((uint32_t)size.y, threads.y), 1 }; + + read_hidden_area_mask_kernel<<>>(size, mask_texture.surface(), mask->data()); + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + return mask; +} + +mat4x3 convert_xr_matrix_to_glm(const XrMatrix4x4f& m) { + mat4x3 out; + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 4; ++j) { + out[j][i] = m.m[i + j * 4]; + } + } + + // Flip Y and Z axes to match NGP conventions + out[1][0] *= -1.f; + out[0][1] *= -1.f; + + out[2][0] *= -1.f; + out[0][2] *= -1.f; + + out[3][1] *= -1.f; + out[3][2] *= -1.f; + + return out; +} + +mat4x3 convert_xr_pose_to_eigen(const XrPosef& pose) { + XrMatrix4x4f matrix; + XrVector3f unit_scale{1.0f, 1.0f, 1.0f}; + XrMatrix4x4f_CreateTranslationRotationScale(&matrix, &pose.position, &pose.orientation, &unit_scale); + return convert_xr_matrix_to_glm(matrix); +} + +OpenXRHMD::FrameInfoPtr OpenXRHMD::begin_frame() { + XrFrameWaitInfo frame_wait_info{XR_TYPE_FRAME_WAIT_INFO}; + XR_CHECK_THROW(xrWaitFrame(m_session, &frame_wait_info, &m_frame_state)); + + XrFrameBeginInfo frame_begin_info{XR_TYPE_FRAME_BEGIN_INFO}; + XR_CHECK_THROW(xrBeginFrame(m_session, &frame_begin_info)); + + if 
(!m_frame_state.shouldRender) { + return std::make_shared(); + } + + uint32_t num_views = (uint32_t)m_swapchains.size(); + // TODO assert m_view_configuration_views.size() == m_swapchains.size() + + // locate views + std::vector views(num_views, {XR_TYPE_VIEW}); + + XrViewState viewState{XR_TYPE_VIEW_STATE}; + + XrViewLocateInfo view_locate_info{XR_TYPE_VIEW_LOCATE_INFO}; + view_locate_info.viewConfigurationType = m_view_configuration_type; + view_locate_info.displayTime = m_frame_state.predictedDisplayTime; + view_locate_info.space = m_space; + + XR_CHECK_THROW(xrLocateViews(m_session, &view_locate_info, &viewState, uint32_t(views.size()), &num_views, views.data())); + + if (!(viewState.viewStateFlags & XR_VIEW_STATE_POSITION_VALID_BIT) || !(viewState.viewStateFlags & XR_VIEW_STATE_ORIENTATION_VALID_BIT)) { + return std::make_shared(); + } + + m_hidden_area_masks.resize(num_views); + + // Fill frame information + if (!m_previous_frame_info) { + m_previous_frame_info = std::make_shared(); + } + + FrameInfoPtr frame_info = std::make_shared(*m_previous_frame_info); + frame_info->views.resize(m_swapchains.size()); + + for (size_t i = 0; i < m_swapchains.size(); ++i) { + const auto& sc = m_swapchains[i]; + + XrSwapchainImageAcquireInfo image_acquire_info{XR_TYPE_SWAPCHAIN_IMAGE_ACQUIRE_INFO}; + XrSwapchainImageWaitInfo image_wait_info{XR_TYPE_SWAPCHAIN_IMAGE_WAIT_INFO, nullptr, XR_INFINITE_DURATION}; + + uint32_t image_index; + XR_CHECK_THROW(xrAcquireSwapchainImage(sc.handle, &image_acquire_info, &image_index)); + XR_CHECK_THROW(xrWaitSwapchainImage(sc.handle, &image_wait_info)); + + FrameInfo::View& v = frame_info->views[i]; + v.framebuffer = sc.framebuffers_gl[image_index]; + v.view.pose = views[i].pose; + v.view.fov = views[i].fov; + v.view.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}}; + v.view.subImage.imageArrayIndex = 0; + v.view.subImage.swapchain = sc.handle; + + glBindFramebuffer(GL_FRAMEBUFFER, sc.framebuffers_gl[image_index]); + glClearColor(0.0f, 0.0f, 0.0f, 0.0f); + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, sc.images_gl.at(image_index).image, 0); + + if (sc.depth_handle != XR_NULL_HANDLE) { + uint32_t depth_image_index; + XR_CHECK_THROW(xrAcquireSwapchainImage(sc.depth_handle, &image_acquire_info, &depth_image_index)); + XR_CHECK_THROW(xrWaitSwapchainImage(sc.depth_handle, &image_wait_info)); + + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, sc.depth_images_gl.at(depth_image_index).image, 0); + + v.depth_info.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}}; + v.depth_info.subImage.imageArrayIndex = 0; + v.depth_info.subImage.swapchain = sc.depth_handle; + v.depth_info.minDepth = 0.0f; + v.depth_info.maxDepth = 1.0f; + + // To be overwritten with actual near and far planes by end_frame + v.depth_info.nearZ = 1.0f / 128.0f; + v.depth_info.farZ = 128.0f; + } + + glBindFramebuffer(GL_FRAMEBUFFER, 0); + + if (!m_hidden_area_masks.at(i)) { + m_hidden_area_masks.at(i) = rasterize_hidden_area_mask(i, v.view); + } + + v.hidden_area_mask = m_hidden_area_masks.at(i); + v.pose = convert_xr_pose_to_eigen(v.view.pose); + } + + XrActiveActionSet active_action_set{m_action_set, XR_NULL_PATH}; + XrActionsSyncInfo sync_info{XR_TYPE_ACTIONS_SYNC_INFO}; + sync_info.countActiveActionSets = 1; + sync_info.activeActionSets = &active_action_set; + XR_CHECK_THROW(xrSyncActions(m_session, &sync_info)); + + for (size_t i = 0; i < 2; ++i) { + // Hand pose + { + 
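+			// Query whether the pose action is active before locating the hand space; the located pose is only meaningful while the action is active.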
XrActionStatePose pose_state{XR_TYPE_ACTION_STATE_POSE}; + XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO}; + get_info.action = m_pose_action; + get_info.subactionPath = m_hand_paths[i]; + XR_CHECK_THROW(xrGetActionStatePose(m_session, &get_info, &pose_state)); + + frame_info->hands[i].pose_active = pose_state.isActive; + if (frame_info->hands[i].pose_active) { + XrSpaceLocation space_location{XR_TYPE_SPACE_LOCATION}; + XR_CHECK_THROW(xrLocateSpace(m_hand_spaces[i], m_space, m_frame_state.predictedDisplayTime, &space_location)); + frame_info->hands[i].pose = convert_xr_pose_to_eigen(space_location.pose); + } + } + + // Stick + { + XrActionStateVector2f thumbstick_state{XR_TYPE_ACTION_STATE_VECTOR2F}; + XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO}; + get_info.action = m_thumbstick_actions[i]; + XR_CHECK_THROW(xrGetActionStateVector2f(m_session, &get_info, &thumbstick_state)); + + if (thumbstick_state.isActive) { + frame_info->hands[i].thumbstick.x = thumbstick_state.currentState.x; + frame_info->hands[i].thumbstick.y = thumbstick_state.currentState.y; + } else { + frame_info->hands[i].thumbstick = vec2(0.0f); + } + } + + // Press + { + XrActionStateBoolean press_state{XR_TYPE_ACTION_STATE_BOOLEAN}; + XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO}; + get_info.action = m_press_action; + get_info.subactionPath = m_hand_paths[i]; + XR_CHECK_THROW(xrGetActionStateBoolean(m_session, &get_info, &press_state)); + + if (press_state.isActive) { + frame_info->hands[i].pressing = press_state.currentState; + } else { + frame_info->hands[i].pressing = 0.0f; + } + } + + // Grab + { + XrActionStateFloat grab_state{XR_TYPE_ACTION_STATE_FLOAT}; + XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO}; + get_info.action = m_grab_action; + get_info.subactionPath = m_hand_paths[i]; + XR_CHECK_THROW(xrGetActionStateFloat(m_session, &get_info, &grab_state)); + + if (grab_state.isActive) { + frame_info->hands[i].grab_strength = grab_state.currentState; + } else { + frame_info->hands[i].grab_strength = 0.0f; + } + + bool was_grabbing = frame_info->hands[i].grabbing; + frame_info->hands[i].grabbing = frame_info->hands[i].grab_strength >= 0.5f; + + if (frame_info->hands[i].grabbing) { + frame_info->hands[i].prev_grab_pos = was_grabbing ? frame_info->hands[i].grab_pos : frame_info->hands[i].pose[3]; + frame_info->hands[i].grab_pos = frame_info->hands[i].pose[3]; + } + } + } + + m_previous_frame_info = frame_info; + return frame_info; +} + +void OpenXRHMD::end_frame(FrameInfoPtr frame_info, float znear, float zfar, bool submit_depth) { + std::vector layer_projection_views(frame_info->views.size()); + for (size_t i = 0; i < layer_projection_views.size(); ++i) { + auto& v = frame_info->views[i]; + auto& view = layer_projection_views[i]; + + view = v.view; + + // release swapchain image + XrSwapchainImageReleaseInfo release_info{XR_TYPE_SWAPCHAIN_IMAGE_RELEASE_INFO}; + XR_CHECK_THROW(xrReleaseSwapchainImage(v.view.subImage.swapchain, &release_info)); + + if (v.depth_info.subImage.swapchain != XR_NULL_HANDLE) { + XR_CHECK_THROW(xrReleaseSwapchainImage(v.depth_info.subImage.swapchain, &release_info)); + v.depth_info.nearZ = znear; + v.depth_info.farZ = zfar; + + // Submitting the depth buffer to the runtime for reprojection is optional, + // because, while depth-based reprojection can make the experience smoother, + // it also results in distortion around geometric edges. Many users prefer + // a more stuttery experience without this distortion. 
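+			// Depth submission is therefore left to the caller via the submit_depth argument.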
+ if (submit_depth) { + view.next = &v.depth_info; + } + } + } + + XrCompositionLayerProjection layer{XR_TYPE_COMPOSITION_LAYER_PROJECTION}; + layer.space = m_space; + if (m_environment_blend_mode != EEnvironmentBlendMode::Opaque) { + layer.layerFlags = XR_COMPOSITION_LAYER_BLEND_TEXTURE_SOURCE_ALPHA_BIT; + } + + layer.viewCount = uint32_t(layer_projection_views.size()); + layer.views = layer_projection_views.data(); + + std::vector layers; + if (layer.viewCount) { + layers.push_back(reinterpret_cast(&layer)); + } + + XrFrameEndInfo frame_end_info{XR_TYPE_FRAME_END_INFO}; + frame_end_info.displayTime = m_frame_state.predictedDisplayTime; + frame_end_info.environmentBlendMode = (XrEnvironmentBlendMode)m_environment_blend_mode; + frame_end_info.layerCount = (uint32_t)layers.size(); + frame_end_info.layers = layers.data(); + XR_CHECK_THROW(xrEndFrame(m_session, &frame_end_info)); +} + +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif diff --git a/gui/src/python_api.cu b/gui/src/python_api.cu new file mode 100644 index 0000000000000000000000000000000000000000..a3dc9e21c37f30b8d016f41cc016c5203c89f054 --- /dev/null +++ b/gui/src/python_api.cu @@ -0,0 +1,829 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file python_api.cpp + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef NGP_GUI +# include +# ifdef _WIN32 +# include +# else +# include +# endif +# include +#endif + +using namespace nlohmann; +namespace py = pybind11; + +namespace ngp { + +// Returns RGBA and depth buffers +std::pair, py::array_t> + Testbed::render_to_cpu(int width, int height, int spp, bool linear, float start_time, float end_time, float fps, float shutter_fraction) { + m_windowless_render_surface.resize({width, height}); + m_windowless_render_surface.reset_accumulation(); + + if (end_time < 0.f) { + end_time = start_time; + } + + bool path_animation_enabled = start_time >= 0.f; + if (!path_animation_enabled) { // the old code disabled camera smoothing for non-path renders; so we preserve that behaviour + m_smoothed_camera = m_camera; + } + + // this rendering code assumes that the intra-frame camera motion starts from m_smoothed_camera (ie where we left off) to allow for EMA + // camera smoothing. in the case of a camera path animation, at the very start of the animation, we have yet to initialize + // smoothed_camera to something sensible + // - it will just be the default boot position. oops! + // that led to the first frame having a crazy streak from the default camera position to the start of the path. 
+ // so we detect that case and explicitly force the current matrix to the start of the path + if (start_time == 0.f) { + set_camera_from_time(start_time); + m_smoothed_camera = m_camera; + } + + auto start_cam_matrix = m_smoothed_camera; + + // now set up the end-of-frame camera matrix if we are moving along a path + if (path_animation_enabled) { + set_camera_from_time(end_time); + apply_camera_smoothing(1000.f / fps); + } + + auto end_cam_matrix = m_smoothed_camera; + auto prev_camera_matrix = m_smoothed_camera; + + for (int i = 0; i < spp; ++i) { + float start_alpha = ((float)i) / (float)spp * shutter_fraction; + float end_alpha = ((float)i + 1.0f) / (float)spp * shutter_fraction; + + auto sample_start_cam_matrix = start_cam_matrix; + auto sample_end_cam_matrix = camera_log_lerp(start_cam_matrix, end_cam_matrix, shutter_fraction); + if (i == 0) { + prev_camera_matrix = sample_start_cam_matrix; + } + + if (path_animation_enabled) { + set_camera_from_time(start_time + (end_time - start_time) * (start_alpha + end_alpha) / 2.0f); + m_smoothed_camera = m_camera; + } + + if (m_autofocus) { + autofocus(); + } + + render_frame( + m_stream.get(), + sample_start_cam_matrix, + sample_end_cam_matrix, + prev_camera_matrix, + m_screen_center, + m_relative_focal_length, + {}, // foveation + {}, // prev foveation + {}, // lens + m_visualized_dimension, + m_windowless_render_surface, + !linear + ); + prev_camera_matrix = sample_start_cam_matrix; + } + + // For cam smoothing when rendering the next frame. + m_smoothed_camera = end_cam_matrix; + + py::array_t result_rgba({height, width, 4}); + py::buffer_info buf_rgba = result_rgba.request(); + + py::array_t result_depth({height, width}); + py::buffer_info buf_depth = result_depth.request(); + + CUDA_CHECK_THROW(cudaMemcpy2DFromArray( + buf_rgba.ptr, width * sizeof(float) * 4, m_windowless_render_surface.surface_provider().array(), 0, 0, width * sizeof(float) * 4, height, cudaMemcpyDeviceToHost + )); + + CUDA_CHECK_THROW( + cudaMemcpy(buf_depth.ptr, m_windowless_render_surface.depth_buffer(), height * width * sizeof(float), cudaMemcpyDeviceToHost) + ); + + return {result_rgba, result_depth}; +} + +py::array_t Testbed::render_to_cpu_rgba( + int width, int height, int spp, bool linear, float start_time, float end_time, float fps, float shutter_fraction +) { + return render_to_cpu(width, height, spp, linear, start_time, end_time, fps, shutter_fraction).first; +} + +py::array_t Testbed::view(bool linear, size_t view_idx) const { + if (m_views.size() <= view_idx) { + throw std::runtime_error{fmt::format("View #{} does not exist.", view_idx)}; + } + + auto& view = m_views.at(view_idx); + auto& render_buffer = *view.render_buffer; + + auto res = render_buffer.out_resolution(); + + py::array_t result({res.y, res.x, 4}); + py::buffer_info buf = result.request(); + float* data = (float*)buf.ptr; + + CUDA_CHECK_THROW(cudaMemcpy2DFromArray( + data, res.x * sizeof(float) * 4, render_buffer.surface_provider().array(), 0, 0, res.x * sizeof(float) * 4, res.y, cudaMemcpyDeviceToHost + )); + + if (linear) { + ThreadPool{}.parallel_for(0, res.y, [&](size_t y) { + size_t base = y * res.x; + for (uint32_t x = 0; x < res.x; ++x) { + size_t px = base + x; + data[px * 4 + 0] = srgb_to_linear(data[px * 4 + 0]); + data[px * 4 + 1] = srgb_to_linear(data[px * 4 + 1]); + data[px * 4 + 2] = srgb_to_linear(data[px * 4 + 2]); + } + }); + } + + return result; +} + +std::pair, py::array_t> + Testbed::reproject(const mat4x3& src, const py::array_t& src_img, const py::array_t& src_depth, 
const mat4x3& dst) { + + py::buffer_info src_img_buf = src_img.request(); + py::buffer_info src_depth_buf = src_depth.request(); + + if (src_img_buf.ndim != 3) { + throw std::runtime_error{"src image should be (H,W,C) where C=4"}; + } + + if (src_img_buf.shape[2] != 4) { + throw std::runtime_error{"src image should be (H,W,C) where C=4"}; + } + + if (src_depth_buf.ndim != 2) { + throw std::runtime_error{"src depth should be (H,W)"}; + } + + if (src_img_buf.shape[0] != src_depth_buf.shape[0] || src_img_buf.shape[1] != src_depth_buf.shape[1]) { + throw std::runtime_error{"image and depth dimensions don't match"}; + } + + const ivec2 src_res = {(int)src_img_buf.shape[1], (int)src_img_buf.shape[0]}; + const ivec2 dst_res = src_res; // For now + + auto src_render_buffer = std::make_shared(std::make_shared()); + src_render_buffer->resize(src_res); + + auto dst_render_buffer = std::make_shared(std::make_shared()); + dst_render_buffer->resize(dst_res); + + View src_view, dst_view; + + src_view.camera0 = src_view.camera1 = src_view.prev_camera = src; + src_view.device = &primary_device(); + src_view.foveation = src_view.prev_foveation = {}; + src_view.screen_center = vec2(0.5f); + src_view.full_resolution = src_res; + src_view.visualized_dimension = -1; + src_view.relative_focal_length = m_relative_focal_length; + src_view.render_buffer = src_render_buffer; + + dst_view.camera0 = dst_view.camera1 = dst_view.prev_camera = dst; + dst_view.device = &primary_device(); + dst_view.foveation = dst_view.prev_foveation = {}; + dst_view.screen_center = vec2(0.5f); + dst_view.full_resolution = dst_res; + dst_view.visualized_dimension = -1; + dst_view.relative_focal_length = m_relative_focal_length; + dst_view.render_buffer = dst_render_buffer; + + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_render_buffer->frame_buffer(), src_img_buf.ptr, product(src_res) * sizeof(float) * 4, cudaMemcpyHostToDevice, m_stream.get() + )); + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_render_buffer->depth_buffer(), src_depth_buf.ptr, product(src_res) * sizeof(float), cudaMemcpyHostToDevice, m_stream.get() + )); + + std::vector src_views = {&src_view}; + reproject_views(src_views, dst_view); + + py::array_t result_rgba({dst_res.y, dst_res.x, 4}); + py::buffer_info buf_rgba = result_rgba.request(); + + py::array_t result_idx({dst_res.y, dst_res.x}); + py::buffer_info buf_idx = result_idx.request(); + + CUDA_CHECK_THROW(cudaMemcpyAsync( + buf_rgba.ptr, dst_render_buffer->frame_buffer(), product(dst_res) * sizeof(float) * 4, cudaMemcpyDeviceToHost, m_stream.get() + )); + + auto idx_buffer = GPUImage(dst_res, m_stream.get()); + + parallel_for_gpu( + m_stream.get(), + idx_buffer.n_elements(), + [out = idx_buffer.view(), in = dst_view.index_field.view(), src_width = src_res.x, dst_width = dst_res.x] __device__(size_t i) { + ivec2 idx = ivec2(i % dst_width, i / dst_width); + ivec2 src_idx = in(idx.y, idx.x).px; + out(idx.y, idx.x) = src_idx.x + src_idx.y * src_width; + } + ); + + CUDA_CHECK_THROW( + cudaMemcpyAsync(buf_idx.ptr, idx_buffer.data(), product(dst_res) * sizeof(uint32_t), cudaMemcpyDeviceToHost, m_stream.get()) + ); + + return {result_rgba, result_idx}; +} + +uint32_t Testbed::add_src_view( + mat4x3 camera_to_world, float fx, float fy, float cx, float cy, Lens lens, pybind11::array_t img, pybind11::array_t depth, float timestamp, bool is_srgb +) { + py::buffer_info src_img_buf = img.request(); + py::buffer_info src_depth_buf = depth.request(); + + if (src_img_buf.ndim != 3) { + throw std::runtime_error{"src image should be (H,W,C) 
where C=4"}; + } + + if (src_img_buf.shape[2] != 4) { + throw std::runtime_error{"src image should be (H,W,C) where C=4"}; + } + + if (src_depth_buf.ndim != 2) { + throw std::runtime_error{"src depth should be (H,W)"}; + } + + if (src_img_buf.shape[0] != src_depth_buf.shape[0] || src_img_buf.shape[1] != src_depth_buf.shape[1]) { + throw std::runtime_error{"image and depth dimensions don't match"}; + } + + const ivec2 src_res = {(int)src_img_buf.shape[1], (int)src_img_buf.shape[0]}; + + static uint32_t id = 0; + + m_reproject_src_views.emplace_back(); + if (m_reproject_max_src_view_count > 0 && m_reproject_src_views.size() > (size_t)m_reproject_max_src_view_count) { + m_reproject_src_views.pop_front(); + } + + auto& src_view = m_reproject_src_views.back(); + src_view.uid = id++; + src_view.camera0 = src_view.camera1 = src_view.prev_camera = camera_to_world; + src_view.device = &primary_device(); + src_view.foveation = src_view.prev_foveation = {}; + src_view.screen_center = vec2(cx, cy); + src_view.full_resolution = src_res; + src_view.visualized_dimension = -1; + src_view.relative_focal_length = vec2(fx, fy) / (float)src_res[m_fov_axis]; + src_view.render_buffer = std::make_shared(std::make_shared()); + src_view.render_buffer->resize(src_res); + src_view.lens = lens; + + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_view.render_buffer->frame_buffer(), src_img_buf.ptr, product(src_res) * sizeof(float) * 4, cudaMemcpyHostToDevice, m_stream.get() + )); + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_view.render_buffer->depth_buffer(), src_depth_buf.ptr, product(src_res) * sizeof(float), cudaMemcpyHostToDevice, m_stream.get() + )); + + if (is_srgb) { + // Convert from sRGB to linear on the GPU directly + parallel_for_gpu( + m_stream.get(), + product(src_res) * 4, + [values = (float *) src_view.render_buffer->frame_buffer()] __device__(size_t i) { + if ((i % 4) == 3) { + // Don't linearize the alpha channel + return; + } + values[i] = srgb_to_linear(values[i]); + } + ); + } + + return src_view.uid; +} + + +pybind11::array_t Testbed::src_view_ids() const { + py::array_t result({(int)m_reproject_src_views.size()}); + py::buffer_info buf = result.request(); + uint32_t* data = (uint32_t*)buf.ptr; + for (size_t i = 0; i < m_reproject_src_views.size(); ++i) { + data[i] = m_reproject_src_views[i].uid; + } + return result; +} + +#ifdef NGP_GUI +py::array_t Testbed::screenshot(bool linear, bool front_buffer) const { + std::vector tmp(product(m_window_res) * 4); + glReadBuffer(front_buffer ? GL_FRONT : GL_BACK); + glReadPixels(0, 0, m_window_res.x, m_window_res.y, GL_RGBA, GL_FLOAT, tmp.data()); + + py::array_t result({m_window_res.y, m_window_res.x, 4}); + py::buffer_info buf = result.request(); + float* data = (float*)buf.ptr; + + // Linear, alpha premultiplied, Y flipped + ThreadPool{}.parallel_for(0, m_window_res.y, [&](size_t y) { + size_t base = y * m_window_res.x; + size_t base_reverse = (m_window_res.y - y - 1) * m_window_res.x; + for (uint32_t x = 0; x < m_window_res.x; ++x) { + size_t px = base + x; + size_t px_reverse = base_reverse + x; + data[px_reverse * 4 + 0] = linear ? srgb_to_linear(tmp[px * 4 + 0]) : tmp[px * 4 + 0]; + data[px_reverse * 4 + 1] = linear ? srgb_to_linear(tmp[px * 4 + 1]) : tmp[px * 4 + 1]; + data[px_reverse * 4 + 2] = linear ? 
srgb_to_linear(tmp[px * 4 + 2]) : tmp[px * 4 + 2]; + data[px_reverse * 4 + 3] = tmp[px * 4 + 3]; + } + }); + + return result; +} +#endif + +PYBIND11_MODULE(pyngp, m) { + m.doc() = "Gen3C GUI"; + + m.def("free_temporary_memory", &free_all_gpu_memory_arenas); + + py::enum_(m, "TestbedMode") + .value("Gen3c", ETestbedMode::Gen3c) + .value("None", ETestbedMode::None) + .export_values(); + + m.def("mode_from_scene", &mode_from_scene); + m.def("mode_from_string", &mode_from_string); + + py::enum_(m, "GroundTruthRenderMode") + .value("Shade", EGroundTruthRenderMode::Shade) + .value("Depth", EGroundTruthRenderMode::Depth) + .export_values(); + + py::enum_(m, "RenderMode") + .value("AO", ERenderMode::AO) + .value("Shade", ERenderMode::Shade) + .value("Normals", ERenderMode::Normals) + .value("Positions", ERenderMode::Positions) + .value("Depth", ERenderMode::Depth) + .value("Distortion", ERenderMode::Distortion) + .value("Cost", ERenderMode::Cost) + .value("Slice", ERenderMode::Slice) + .export_values(); + + py::enum_(m, "RandomMode") + .value("Random", ERandomMode::Random) + .value("Halton", ERandomMode::Halton) + .value("Sobol", ERandomMode::Sobol) + .value("Stratified", ERandomMode::Stratified) + .export_values(); + + py::enum_(m, "LossType") + .value("L2", ELossType::L2) + .value("L1", ELossType::L1) + .value("Mape", ELossType::Mape) + .value("Smape", ELossType::Smape) + .value("Huber", ELossType::Huber) + // Legacy: we used to refer to the Huber loss + // (L2 near zero, L1 further away) as "SmoothL1". + .value("SmoothL1", ELossType::Huber) + .value("LogL1", ELossType::LogL1) + .value("RelativeL2", ELossType::RelativeL2) + .export_values(); + + py::enum_(m, "SDFGroundTruthMode") + .value("RaytracedMesh", ESDFGroundTruthMode::RaytracedMesh) + .value("SpheretracedMesh", ESDFGroundTruthMode::SpheretracedMesh) + .value("SDFBricks", ESDFGroundTruthMode::SDFBricks) + .export_values(); + + py::enum_(m, "MeshSdfMode") + .value("Watertight", EMeshSdfMode::Watertight) + .value("Raystab", EMeshSdfMode::Raystab) + .value("PathEscape", EMeshSdfMode::PathEscape) + .export_values(); + + py::enum_(m, "ColorSpace").value("Linear", EColorSpace::Linear).value("SRGB", EColorSpace::SRGB).export_values(); + + py::enum_(m, "TonemapCurve") + .value("Identity", ETonemapCurve::Identity) + .value("ACES", ETonemapCurve::ACES) + .value("Hable", ETonemapCurve::Hable) + .value("Reinhard", ETonemapCurve::Reinhard) + .export_values(); + + py::enum_(m, "LensMode") + .value("Perspective", ELensMode::Perspective) + .value("OpenCV", ELensMode::OpenCV) + .value("FTheta", ELensMode::FTheta) + .value("LatLong", ELensMode::LatLong) + .value("OpenCVFisheye", ELensMode::OpenCVFisheye) + .value("Equirectangular", ELensMode::Equirectangular) + .value("Orthographic", ELensMode::Orthographic) + .export_values(); + + + py::class_(m, "BoundingBox") + .def(py::init<>()) + .def(py::init()) + .def("center", &BoundingBox::center) + .def("contains", &BoundingBox::contains) + .def("diag", &BoundingBox::diag) + .def("distance", &BoundingBox::distance) + .def("distance_sq", &BoundingBox::distance_sq) + .def("enlarge", py::overload_cast(&BoundingBox::enlarge)) + .def("enlarge", py::overload_cast(&BoundingBox::enlarge)) + .def("get_vertices", &BoundingBox::get_vertices) + .def("inflate", &BoundingBox::inflate) + .def("intersection", &BoundingBox::intersection) + .def("intersects", py::overload_cast(&BoundingBox::intersects, py::const_)) + .def("ray_intersect", &BoundingBox::ray_intersect) + .def("relative_pos", &BoundingBox::relative_pos) + 
.def("signed_distance", &BoundingBox::signed_distance) + .def_readwrite("min", &BoundingBox::min) + .def_readwrite("max", &BoundingBox::max); + + py::class_ lens(m, "Lens"); + lens.def(py::init<>()).def_readwrite("mode", &Lens::mode).def_property_readonly("params", [](py::object& obj) { + Lens& o = obj.cast(); + return py::array{sizeof(o.params) / sizeof(o.params[0]), o.params, obj}; + }); + + m.def("fov_to_focal_length", py::overload_cast(&ngp::fov_to_focal_length), + py::arg("resolution"), py::arg("degrees")) + .def("fov_to_focal_length", py::overload_cast(&fov_to_focal_length), + py::arg("resolution"), py::arg("degrees")) + .def("focal_length_to_fov", py::overload_cast(&ngp::focal_length_to_fov), + py::arg("resolution"), py::arg("focal_length")) + .def("focal_length_to_fov", py::overload_cast(&ngp::focal_length_to_fov), + py::arg("resolution"), py::arg("focal_length")) + .def("relative_focal_length_to_fov", &ngp::relative_focal_length_to_fov, + py::arg("rel_focal_length")); + + py::class_(m, "path").def(py::init<>()).def(py::init()); + + py::implicitly_convertible(); + + py::class_ testbed(m, "Testbed"); + testbed.def(py::init(), py::arg("mode") = ETestbedMode::None) + .def_readonly("mode", &Testbed::m_testbed_mode) + // General control + .def( + "init_window", + &Testbed::init_window, + "Init a GLFW window that shows real-time progress and a GUI. 'second_window' creates a second copy of the output in its own window.", + py::arg("width"), + py::arg("height"), + py::arg("hidden") = false, + py::arg("second_window") = false + ) + .def("destroy_window", &Testbed::destroy_window, "Destroy the window again.") + .def( + "init_vr", + &Testbed::init_vr, + "Init rendering to a connected and active VR headset. Requires a window to have been previously created via `init_window`." 
+ ) + .def( + "view", + &Testbed::view, + "Outputs the currently displayed image by a given view (0 by default).", + py::arg("linear") = true, + py::arg("view") = 0 + ) + .def("view_camera", &Testbed::view_camera, "Outputs the current camera matrix of a given view (0 by default).", py::arg("view") = 0) + .def( + "add_src_view", + &Testbed::add_src_view, + "Adds a source view to the pool of views for reprojection.", + py::arg("camera_to_world"), + py::arg("fx"), + py::arg("fy"), + py::arg("cx"), + py::arg("cy"), + py::arg("img"), + py::arg("depth"), + py::arg("lens"), + py::arg("timestamp"), + py::arg("is_srgb") = false + ) + .def("src_view_ids", &Testbed::src_view_ids, "Returns the IDs of all source views currently registered.") + .def("clear_src_views", &Testbed::clear_src_views, "Remove all views from the pool of views for reprojection.") +#ifdef NGP_GUI + .def_readwrite("keyboard_event_callback", &Testbed::m_keyboard_event_callback) + .def_readwrite("file_drop_callback", &Testbed::m_file_drop_callback) + .def("is_key_pressed", [](py::object& obj, int key) { return ImGui::IsKeyPressed(key); }) + .def("is_key_down", [](py::object& obj, int key) { return ImGui::IsKeyDown(key); }) + .def("is_alt_down", [](py::object& obj) { return ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Alt; }) + .def("is_ctrl_down", [](py::object& obj) { return ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Ctrl; }) + .def("is_shift_down", [](py::object& obj) { return ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Shift; }) + .def("is_super_down", [](py::object& obj) { return ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Super; }) + .def( + "screenshot", + &Testbed::screenshot, + "Takes a screenshot of the current window contents.", + py::arg("linear") = true, + py::arg("front_buffer") = true + ) + .def_readwrite("vr_use_hidden_area_mask", &Testbed::m_vr_use_hidden_area_mask) + .def_readwrite("vr_use_depth_reproject", &Testbed::m_vr_use_depth_reproject) +#endif + .def("want_repl", &Testbed::want_repl, "returns true if the user clicked the 'I want a repl' button") + .def( + "frame", &Testbed::frame, py::call_guard(), "Process a single frame. Renders if a window was previously created." + ) + .def( + "render", + &Testbed::render_to_cpu_rgba, + "Renders an image at the requested resolution. Does not require a window.", + py::arg("width") = 1920, + py::arg("height") = 1080, + py::arg("spp") = 1, + py::arg("linear") = true, + py::arg("start_t") = -1.f, + py::arg("end_t") = -1.f, + py::arg("fps") = 30.f, + py::arg("shutter_fraction") = 1.0f + ) + .def( + "render_with_depth", + &Testbed::render_to_cpu, + "Renders an image at the requested resolution. Does not require a window.", + py::arg("width") = 1920, + py::arg("height") = 1080, + py::arg("spp") = 1, + py::arg("linear") = true, + py::arg("start_t") = -1.f, + py::arg("end_t") = -1.f, + py::arg("fps") = 30.f, + py::arg("shutter_fraction") = 1.0f + ) + .def("reproject", &Testbed::reproject, "Reprojects an RGBA + depth image from a known camera view to another camera view.") + .def("reset_camera", &Testbed::reset_camera, "Reset camera to default state.") + .def( + "reset_accumulation", + &Testbed::reset_accumulation, + "Reset rendering accumulation.", + py::arg("due_to_camera_movement") = false, + py::arg("immediate_redraw") = true, + py::arg("reset_pip") = false + ) + .def("load_camera_path", &Testbed::load_camera_path, py::arg("path"), "Load a camera path") + .def( + "load_file", + &Testbed::load_file, + py::arg("path"), + "Load a file and automatically determine how to handle it. 
Can be a snapshot, dataset, network config, or camera path." + ) + .def_property("loop_animation", &Testbed::loop_animation, &Testbed::set_loop_animation) + // Interesting members. + .def_readwrite("reproject_min_t", &Testbed::m_reproject_min_t) + .def_readwrite("reproject_step_factor", &Testbed::m_reproject_step_factor) + .def_readwrite("reproject_parallax", &Testbed::m_reproject_parallax) + .def_readwrite("reproject_second_view", &Testbed::m_reproject_enable) + .def_readwrite("reproject_enable", &Testbed::m_reproject_enable) + .def_readwrite("reproject_visualize_src_views", &Testbed::m_reproject_visualize_src_views) + .def_readwrite("reproject_min_src_view_index", &Testbed::m_reproject_min_src_view_index) + .def_readwrite("reproject_max_src_view_index", &Testbed::m_reproject_max_src_view_index) + .def_readwrite("reproject_max_src_view_count", &Testbed::m_reproject_max_src_view_count) + .def("reproject_src_views_count", [](const Testbed& testbed) { return testbed.m_reproject_src_views.size(); }) + .def_readwrite("reproject_reuse_last_frame", &Testbed::m_reproject_reuse_last_frame) + .def("init_camera_path_from_reproject_src_cameras", &Testbed::init_camera_path_from_reproject_src_cameras) + .def_readwrite("pm_enable", &Testbed::m_pm_enable) + .def_readwrite("dynamic_res", &Testbed::m_dynamic_res) + .def_readwrite("dynamic_res_target_fps", &Testbed::m_dynamic_res_target_fps) + .def_readwrite("fixed_res_factor", &Testbed::m_fixed_res_factor) + .def_readwrite("background_color", &Testbed::m_background_color) + .def_readwrite("render_transparency_as_checkerboard", &Testbed::m_render_transparency_as_checkerboard) + .def_readwrite("render_groundtruth", &Testbed::m_render_ground_truth) + .def_readwrite("render_ground_truth", &Testbed::m_render_ground_truth) + .def_readwrite("groundtruth_render_mode", &Testbed::m_ground_truth_render_mode) + .def_readwrite("render_mode", &Testbed::m_render_mode) + .def_readwrite("render_near_distance", &Testbed::m_render_near_distance) + .def_readwrite("slice_plane_z", &Testbed::m_slice_plane_z) + .def_readwrite("dof", &Testbed::m_aperture_size) + .def_readwrite("aperture_size", &Testbed::m_aperture_size) + .def_readwrite("autofocus", &Testbed::m_autofocus) + .def_readwrite("autofocus_target", &Testbed::m_autofocus_target) + .def_readwrite("camera_path", &Testbed::m_camera_path) + .def_readwrite("record_camera_path", &Testbed::m_record_camera_path) + .def_readwrite("floor_enable", &Testbed::m_floor_enable) + .def_readwrite("exposure", &Testbed::m_exposure) + .def_property("scale", &Testbed::scale, &Testbed::set_scale) + .def_readonly("bounding_radius", &Testbed::m_bounding_radius) + .def_readwrite("render_aabb", &Testbed::m_render_aabb) + .def_readwrite("render_aabb_to_local", &Testbed::m_render_aabb_to_local) + .def_readwrite("is_rendering", &Testbed::m_render) + .def_readwrite("aabb", &Testbed::m_aabb) + .def_readwrite("raw_aabb", &Testbed::m_raw_aabb) + .def_property("fov", &Testbed::fov, &Testbed::set_fov) + .def_property("fov_xy", &Testbed::fov_xy, &Testbed::set_fov_xy) + .def_readwrite("fov_axis", &Testbed::m_fov_axis) + .def_readwrite("relative_focal_length", &Testbed::m_relative_focal_length) + .def_readwrite("zoom", &Testbed::m_zoom) + .def_readwrite("screen_center", &Testbed::m_screen_center) + .def_readwrite("camera_matrix", &Testbed::m_camera) + .def_readwrite("up_dir", &Testbed::m_up_dir) + .def_readwrite("sun_dir", &Testbed::m_sun_dir) + .def_readwrite("default_camera", &Testbed::m_default_camera) + .def_property("look_at", &Testbed::look_at, 
&Testbed::set_look_at) + .def_property("view_dir", &Testbed::view_dir, &Testbed::set_view_dir) + .def_readwrite("camera_smoothing", &Testbed::m_camera_smoothing) + .def_readwrite("render_with_lens_distortion", &Testbed::m_render_with_lens_distortion) + .def_readwrite("render_lens", &Testbed::m_render_lens) + .def_property( + "display_gui", + [](py::object& obj) { return obj.cast().m_imgui.mode == Testbed::ImGuiMode::Enabled; }, + [](const py::object& obj, bool value) { + obj.cast().m_imgui.mode = value ? Testbed::ImGuiMode::Enabled : Testbed::ImGuiMode::Disabled; + } + ) + .def_property( + "video_path", + [](Testbed& obj) { return obj.m_imgui.video_path; }, + [](Testbed& obj, const std::string& value) { + if (value.size() > Testbed::ImGuiVars::MAX_PATH_LEN) + throw std::runtime_error{"Video path is too long."}; + strcpy(obj.m_imgui.video_path, value.c_str()); + } + ) + .def_readwrite("visualize_unit_cube", &Testbed::m_visualize_unit_cube) + .def_readwrite("snap_to_pixel_centers", &Testbed::m_snap_to_pixel_centers) + .def_readwrite("parallax_shift", &Testbed::m_parallax_shift) + .def_readwrite("color_space", &Testbed::m_color_space) + .def_readwrite("tonemap_curve", &Testbed::m_tonemap_curve) + .def_property( + "dlss", + [](py::object& obj) { return obj.cast().m_dlss; }, + [](const py::object& obj, bool value) { + if (value && !obj.cast().m_dlss_provider) { + if (obj.cast().m_render_window) { + throw std::runtime_error{"DLSS not supported."}; + } else { + throw std::runtime_error{"DLSS requires a Window to be initialized via `init_window`."}; + } + } + + obj.cast().m_dlss = value; + } + ) + .def_readwrite("dlss_sharpening", &Testbed::m_dlss_sharpening) + .def_property( + "root_dir", + [](py::object& obj) { return obj.cast().root_dir().str(); }, + [](const py::object& obj, const std::string& value) { obj.cast().set_root_dir(value); } + ); + + py::enum_(m, "Gen3cCameraSource") + .value("Fake", EGen3cCameraSource::Fake) + .value("Viewpoint", EGen3cCameraSource::Viewpoint) + .value("Authored", EGen3cCameraSource::Authored); + + testbed + .def( + "set_gen3c_cb", + [](Testbed& testbed, const Testbed::gen3c_cb_t& cb) { + // testbed.m_gen3c_cb.reset(cb); + testbed.m_gen3c_cb = cb; + } + ) + .def_readwrite("gen3c_info", &Testbed::m_gen3c_info) + .def_readwrite("gen3c_seed_path", &Testbed::m_gen3c_seed_path) + .def_readwrite("gen3c_auto_inference", &Testbed::m_gen3c_auto_inference) + .def_readwrite("gen3c_camera_source", &Testbed::m_gen3c_camera_source) + .def_readwrite("gen3c_translation_speed", &Testbed::m_gen3c_translation_speed) + .def_readwrite("gen3c_rotation_speed", &Testbed::m_gen3c_rotation_speed) + .def_readwrite("gen3c_inference_info", &Testbed::m_gen3c_inference_info) + .def_readwrite("gen3c_seeding_progress", &Testbed::m_gen3c_seeding_progress) + .def_readwrite("gen3c_inference_progress", &Testbed::m_gen3c_inference_progress) + .def_readwrite("gen3c_inference_is_connected", &Testbed::m_gen3c_inference_is_connected) + .def_readwrite("gen3c_render_with_gen3c", &Testbed::m_gen3c_render_with_gen3c) + // Output + .def_readwrite("gen3c_save_frames", &Testbed::m_gen3c_save_frames) + .def_readwrite("gen3c_display_frames", &Testbed::m_gen3c_display_frames) + .def_readwrite("gen3c_output_dir", &Testbed::m_gen3c_output_dir) + .def_readwrite("gen3c_show_cache_renderings", &Testbed::m_gen3c_show_cache_renderings); + + py::class_(m, "CameraKeyframe") + .def(py::init<>()) + .def( + py::init(), + py::arg("r"), + py::arg("t"), + py::arg("fov"), + py::arg("timestamp") + ) + .def( + py::init(), + 
py::arg("m"), + py::arg("fov"), + py::arg("timestamp") + ) + .def_readwrite("R", &CameraKeyframe::R) + .def_readwrite("T", &CameraKeyframe::T) + .def_readwrite("fov", &CameraKeyframe::fov) + .def_readwrite("timestamp", &CameraKeyframe::timestamp) + .def("m", &CameraKeyframe::m) + .def("from_m", &CameraKeyframe::from_m, py::arg("rv")) + .def("same_pos_as", &CameraKeyframe::same_pos_as, py::arg("rhs")); + + py::enum_(m, "EditingKernel") + .value("None", EEditingKernel::None) + .value("Gaussian", EEditingKernel::Gaussian) + .value("Quartic", EEditingKernel::Quartic) + .value("Hat", EEditingKernel::Hat) + .value("Box", EEditingKernel::Box); + + py::class_(m, "CameraPathRenderSettings") + .def_readwrite("resolution", &CameraPath::RenderSettings::resolution) + .def_readwrite("spp", &CameraPath::RenderSettings::spp) + .def_readwrite("fps", &CameraPath::RenderSettings::fps) + .def_readwrite("shutter_fraction", &CameraPath::RenderSettings::shutter_fraction) + .def_readwrite("quality", &CameraPath::RenderSettings::quality); + + py::class_(m, "CameraPathPos").def_readwrite("kfidx", &CameraPath::Pos::kfidx).def_readwrite("t", &CameraPath::Pos::t); + + py::class_(m, "CameraPath") + .def_readwrite("keyframes", &CameraPath::keyframes) + .def_readwrite("update_cam_from_path", &CameraPath::update_cam_from_path) + .def_readwrite("play_time", &CameraPath::play_time) + .def_readwrite("auto_play_speed", &CameraPath::auto_play_speed) + .def_readwrite("default_duration_seconds", &CameraPath::default_duration_seconds) + .def_readwrite("loop", &CameraPath::loop) + .def_readwrite("keyframe_subsampling", &CameraPath::keyframe_subsampling) + .def_property("duration_seconds", &CameraPath::duration_seconds, &CameraPath::set_duration_seconds) + .def_readwrite("editing_kernel_type", &CameraPath::editing_kernel_type) + .def_readwrite("editing_kernel_radius", &CameraPath::editing_kernel_radius) + .def_readwrite("spline_order", &CameraPath::spline_order) + .def_readwrite("render_settings", &CameraPath::render_settings) + .def_readwrite("rendering", &CameraPath::rendering) + .def_readwrite("render_frame_idx", &CameraPath::render_frame_idx) + .def_readwrite("render_start_time", &CameraPath::render_start_time) + .def_readwrite("render_frame_end_camera", &CameraPath::render_frame_end_camera) + .def("clear", &CameraPath::clear) + .def("has_valid_timestamps", &CameraPath::has_valid_timestamps) + .def("make_keyframe_timestamps_equidistant", &CameraPath::make_keyframe_timestamps_equidistant) + .def("sanitize_keyframes", &CameraPath::sanitize_keyframes) + .def("get_pos", &CameraPath::get_pos, py::arg("playtime")) + .def("get_playtime", &CameraPath::get_playtime, py::arg("i")) + .def("get_keyframe", &CameraPath::get_keyframe, py::arg("i")) + .def("eval_camera_path", &CameraPath::eval_camera_path, py::arg("t")) + .def("save", &CameraPath::save, py::arg("path")) + .def("load", &CameraPath::load, py::arg("path"), py::arg("first_xform")) + .def( + "add_camera", + &CameraPath::add_camera, + py::arg("camera"), + py::arg("fov"), + py::arg("timestamp") + ); + + // Minimal logging framework (tlog) + // https://github.com/Tom94/tinylogger/ + py::module_ tlog = m.def_submodule("tlog", "Tiny logging framework"); + tlog.def("none", [](const std::string &s) { tlog::none() << s; }) + .def("info", [](const std::string &s) { tlog::info() << s; }) + .def("debug", [](const std::string &s) { tlog::debug() << s; }) + .def("warning", [](const std::string &s) { tlog::warning() << s; }) + .def("error", [](const std::string &s) { tlog::error() << s; }) + 
.def("success", [](const std::string &s) { tlog::success() << s; }); +} + +} // namespace ngp diff --git a/gui/src/render_buffer.cu b/gui/src/render_buffer.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e1812b5b0a7b1832c81dcf0e69012d26093cc39 --- /dev/null +++ b/gui/src/render_buffer.cu @@ -0,0 +1,842 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file render_buffer.cu + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#include +#include +#include + +#include + +#include + +#ifdef NGP_GUI +# ifdef _WIN32 +# include +# else +# include +# endif +# include +# include +#endif + +#include + +namespace ngp { + +extern std::atomic g_total_n_bytes_allocated; + +void CudaSurface2D::free() { + if (m_surface) { + cudaDestroySurfaceObject(m_surface); + } + m_surface = 0; + if (m_array) { + cudaFreeArray(m_array); + g_total_n_bytes_allocated -= product(m_size) * sizeof(float) * m_n_channels; + } + m_array = nullptr; + m_size = ivec2(0); + m_n_channels = 0; +} + +void CudaSurface2D::resize(const ivec2& size, int n_channels) { + if (size == m_size && n_channels == m_n_channels) { + return; + } + + free(); + + cudaChannelFormatDesc desc; + switch (n_channels) { + case 1: desc = cudaCreateChannelDesc(); break; + case 2: desc = cudaCreateChannelDesc(); break; + case 3: desc = cudaCreateChannelDesc(); break; + case 4: desc = cudaCreateChannelDesc(); break; + default: throw std::runtime_error{fmt::format("CudaSurface2D: unsupported number of channels {}", n_channels)}; + } + CUDA_CHECK_THROW(cudaMallocArray(&m_array, &desc, size.x, size.y, cudaArraySurfaceLoadStore)); + + g_total_n_bytes_allocated += product(m_size) * sizeof(float) * n_channels; + + struct cudaResourceDesc resource_desc; + memset(&resource_desc, 0, sizeof(resource_desc)); + resource_desc.resType = cudaResourceTypeArray; + resource_desc.res.array.array = m_array; + CUDA_CHECK_THROW(cudaCreateSurfaceObject(&m_surface, &resource_desc)); + + m_size = size; + m_n_channels = n_channels; +} + +#ifdef NGP_GUI +GLTexture::~GLTexture() { + m_cuda_mapping.reset(); + if (m_texture_id) { + glDeleteTextures(1, &m_texture_id); + } +} + +GLuint GLTexture::texture() { + if (!m_texture_id) { + glGenTextures(1, &m_texture_id); + } + + return m_texture_id; +} + +cudaSurfaceObject_t GLTexture::surface() { + if (!m_cuda_mapping) { + m_cuda_mapping = std::make_unique(texture(), m_size, m_n_channels); + } + return m_cuda_mapping->surface(); +} + +cudaArray_t GLTexture::array() { + if (!m_cuda_mapping) { + m_cuda_mapping = std::make_unique(texture(), m_size, m_n_channels); + } + return m_cuda_mapping->array(); +} + +void GLTexture::blit_from_cuda_mapping() { + if (!m_cuda_mapping || m_cuda_mapping->is_interop()) { + return; + } + + if (m_is_8bit) { + throw std::runtime_error{"Can only blit from CUDA mapping if the texture is float."}; + } + + const float* data_cpu = 
m_cuda_mapping->data_cpu(); + + glBindTexture(GL_TEXTURE_2D, m_texture_id); + glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, m_size.x, m_size.y, 0, m_format, GL_FLOAT, data_cpu); +} + +void GLTexture::load(const fs::path& path) { + uint8_t* out; // width * height * RGBA + int comp, width, height; + out = load_stbi(path, &width, &height, &comp, 4); + if (!out) { + throw std::runtime_error{std::string{stbi_failure_reason()}}; + } + ScopeGuard mem_guard{[&]() { stbi_image_free(out); }}; + load(out, {width, height}, 4); +} + +void GLTexture::load(const float* data, ivec2 new_size, int n_channels) { + resize(new_size, n_channels, false); + + glBindTexture(GL_TEXTURE_2D, m_texture_id); + glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, new_size.x, new_size.y, 0, m_format, GL_FLOAT, data); +} + +void GLTexture::load(const uint8_t* data, ivec2 new_size, int n_channels) { + resize(new_size, n_channels, true); + + glBindTexture(GL_TEXTURE_2D, m_texture_id); + glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, new_size.x, new_size.y, 0, m_format, GL_UNSIGNED_BYTE, data); +} + +void GLTexture::resize(const ivec2& new_size, int n_channels, bool is_8bit) { + if (m_size == new_size && m_n_channels == n_channels && m_is_8bit == is_8bit) { + return; + } + + if (m_texture_id) { + m_cuda_mapping.reset(); + glDeleteTextures(1, &m_texture_id); + m_texture_id = 0; + } + + glGenTextures(1, &m_texture_id); + glBindTexture(GL_TEXTURE_2D, m_texture_id); + + switch (n_channels) { + case 1: + m_internal_format = is_8bit ? GL_R8 : GL_R32F; + m_format = GL_RED; + break; + case 2: + m_internal_format = is_8bit ? GL_RG8 : GL_RG32F; + m_format = GL_RG; + break; + case 3: + m_internal_format = is_8bit ? GL_RGB8 : GL_RGB32F; + m_format = GL_RGB; + break; + case 4: + m_internal_format = is_8bit ? GL_RGBA8 : GL_RGBA32F; + m_format = GL_RGBA; + break; + default: throw std::runtime_error{fmt::format("GLTexture: unsupported number of channels {}", n_channels)}; + } + m_is_8bit = is_8bit; + m_size = new_size; + m_n_channels = n_channels; + + glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, new_size.x, new_size.y, 0, m_format, is_8bit ? 
GL_UNSIGNED_BYTE : GL_FLOAT, nullptr); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); +} + +GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const ivec2& size, int n_channels) : m_size{size}, m_n_channels{n_channels} { + static bool s_is_cuda_interop_supported = !is_wsl(); + if (s_is_cuda_interop_supported) { + cudaError_t err = + cudaGraphicsGLRegisterImage(&m_graphics_resource, texture_id, GL_TEXTURE_2D, cudaGraphicsRegisterFlagsSurfaceLoadStore); + if (err != cudaSuccess) { + s_is_cuda_interop_supported = false; + cudaGetLastError(); // Reset error + } + } + + if (!s_is_cuda_interop_supported) { + // falling back to a regular cuda surface + CPU copy of data + m_cuda_surface = std::make_unique(); + m_cuda_surface->resize(size, n_channels); + m_data_cpu.resize(product(m_size) * n_channels); + return; + } + + CUDA_CHECK_THROW(cudaGraphicsMapResources(1, &m_graphics_resource)); + CUDA_CHECK_THROW(cudaGraphicsSubResourceGetMappedArray(&m_mapped_array, m_graphics_resource, 0, 0)); + + struct cudaResourceDesc resource_desc; + memset(&resource_desc, 0, sizeof(resource_desc)); + resource_desc.resType = cudaResourceTypeArray; + resource_desc.res.array.array = m_mapped_array; + + CUDA_CHECK_THROW(cudaCreateSurfaceObject(&m_surface, &resource_desc)); +} + +GLTexture::CUDAMapping::~CUDAMapping() { + if (m_surface) { + cudaDestroySurfaceObject(m_surface); + cudaGraphicsUnmapResources(1, &m_graphics_resource); + cudaGraphicsUnregisterResource(m_graphics_resource); + } +} + +const float* GLTexture::CUDAMapping::data_cpu() { + CUDA_CHECK_THROW(cudaMemcpy2DFromArray( + m_data_cpu.data(), m_size.x * sizeof(float) * m_n_channels, array(), 0, 0, m_size.x * sizeof(float) * m_n_channels, m_size.y, cudaMemcpyDeviceToHost + )); + return m_data_cpu.data(); +} + +bool check_shader(uint32_t handle, const char* desc, bool program) { + GLint status = 0, log_length = 0; + if (program) { + glGetProgramiv(handle, GL_LINK_STATUS, &status); + glGetProgramiv(handle, GL_INFO_LOG_LENGTH, &log_length); + } else { + glGetShaderiv(handle, GL_COMPILE_STATUS, &status); + glGetShaderiv(handle, GL_INFO_LOG_LENGTH, &log_length); + } + + if ((GLboolean)status == GL_FALSE) { + tlog::error() << "Failed to compile shader: " << desc; + } + + if (log_length > 1) { + std::vector log; + log.resize(log_length + 1); + if (program) { + glGetProgramInfoLog(handle, log_length, NULL, (GLchar*)log.data()); + } else { + glGetShaderInfoLog(handle, log_length, NULL, (GLchar*)log.data()); + } + log.back() = 0; + tlog::error() << log.data(); + } + + return (GLboolean)status == GL_TRUE; +} + +uint32_t compile_shader(bool pixel, const char* code) { + GLuint g_VertHandle = glCreateShader(pixel ? GL_FRAGMENT_SHADER : GL_VERTEX_SHADER); + const char* glsl_version = "#version 140\n"; + const GLchar* strings[2] = {glsl_version, code}; + glShaderSource(g_VertHandle, 2, strings, NULL); + glCompileShader(g_VertHandle); + + if (!check_shader(g_VertHandle, pixel ? 
"pixel" : "vertex", false)) { + glDeleteShader(g_VertHandle); + return 0; + } + + return g_VertHandle; +} +#endif // NGP_GUI + +__global__ void accumulate_kernel(ivec2 resolution, vec4* frame_buffer, vec4* accumulate_buffer, float sample_count, EColorSpace color_space) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + uint32_t idx = x + resolution.x * y; + + vec4 color = frame_buffer[idx]; + vec4 tmp = accumulate_buffer[idx]; + + switch (color_space) { + case EColorSpace::VisPosNeg: { + float val = color.x - color.y; + float tmp_val = tmp.x - tmp.y; + + tmp_val = (tmp_val * sample_count + val) / (sample_count + 1); + + tmp.x = fmaxf(tmp_val, 0.0f); + tmp.y = fmaxf(-tmp_val, 0.0f); + break; + } + case EColorSpace::SRGB: + color.rgb() = linear_to_srgb(color.rgb()); + // fallthrough is intended! + case EColorSpace::Linear: tmp.rgb() = (tmp.rgb() * sample_count + color.rgb()) / (sample_count + 1); break; + } + + tmp.a = (tmp.a * sample_count + color.a) / (sample_count + 1); + accumulate_buffer[idx] = tmp; +} + +__device__ vec3 tonemap(vec3 x, ETonemapCurve curve) { + if (curve == ETonemapCurve::Identity) { + return x; + } + + x = max(x, vec3(0.0f)); + + float k0, k1, k2, k3, k4, k5; + if (curve == ETonemapCurve::ACES) { + // Source: ACES approximation : https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/ + // Include pre - exposure cancelation in constants + k0 = 0.6f * 0.6f * 2.51f; + k1 = 0.6f * 0.03f; + k2 = 0.0f; + k3 = 0.6f * 0.6f * 2.43f; + k4 = 0.6f * 0.59f; + k5 = 0.14f; + } else if (curve == ETonemapCurve::Hable) { + // Source: https://64.github.io/tonemapping/ + const float A = 0.15f; + const float B = 0.50f; + const float C = 0.10f; + const float D = 0.20f; + const float E = 0.02f; + const float F = 0.30f; + k0 = A * F - A * E; + k1 = C * B * F - B * E; + k2 = 0.0f; + k3 = A * F; + k4 = B * F; + k5 = D * F * F; + + const float W = 11.2f; + const float nom = k0 * (W * W) + k1 * W + k2; + const float denom = k3 * (W * W) + k4 * W + k5; + const float white_scale = denom / nom; + + // Include white scale and exposure bias in rational polynomial coefficients + k0 = 4.0f * k0 * white_scale; + k1 = 2.0f * k1 * white_scale; + k2 = k2 * white_scale; + k3 = 4.0f * k3; + k4 = 2.0f * k4; + } else { // if (curve == ETonemapCurve::Reinhard) + const vec3 luminance_coefficients = {0.2126f, 0.7152f, 0.0722f}; + float Y = dot(luminance_coefficients, x); + + return x * (1.f / (Y + 1.0f)); + } + + vec3 color_sq = x * x; + vec3 nom = color_sq * k0 + k1 * x + k2; + vec3 denom = k3 * color_sq + k4 * x + k5; + + vec3 tonemapped_color = nom / denom; + + return tonemapped_color; +} + +__device__ vec3 tonemap(vec3 col, const vec3& exposure, ETonemapCurve tonemap_curve, EColorSpace color_space, EColorSpace output_color_space) { + // Conversion to output by + // 1. converting to linear. (VisPosNeg is treated as linear red/green) + if (color_space == EColorSpace::SRGB) { + col = srgb_to_linear(col); + } + + // 2. applying exposure in linear space + col *= pow(vec3(2.0f), exposure); + + // 3. tonemapping in linear space according to the specified curve + col = tonemap(col, tonemap_curve); + + // 4. converting to output color space. 
+ if (output_color_space == EColorSpace::SRGB) { + col = linear_to_srgb(col); + } + + return col; +} + +__global__ void overlay_image_kernel( + ivec2 resolution, + float alpha, + vec3 exposure, + vec4 background_color, + const void* __restrict__ image, + EImageDataType image_data_type, + ivec2 image_resolution, + ETonemapCurve tonemap_curve, + EColorSpace color_space, + EColorSpace output_color_space, + int fov_axis, + float zoom, + vec2 screen_center, + cudaSurfaceObject_t surface +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + float scale = image_resolution[fov_axis] / float(resolution[fov_axis]); + + float fx = x + 0.5f; + float fy = y + 0.5f; + + fx -= resolution.x * 0.5f; + fx /= zoom; + fx += screen_center.x * resolution.x; + fy -= resolution.y * 0.5f; + fy /= zoom; + fy += screen_center.y * resolution.y; + + float u = (fx - resolution.x * 0.5f) * scale + image_resolution.x * 0.5f; + float v = (fy - resolution.y * 0.5f) * scale + image_resolution.y * 0.5f; + + int srcx = floorf(u); + int srcy = floorf(v); + + vec4 val; + if (srcx >= image_resolution.x || srcy >= image_resolution.y || srcx < 0 || srcy < 0) { + val = vec4(0.0f); + } else { + val = read_rgba(ivec2{srcx, srcy}, image_resolution, image, image_data_type); + } + + vec4 color = {val[0], val[1], val[2], val[3]}; + + // The background color is represented in SRGB, so convert + // to linear if that's not the space in which we're rendering. + if (color_space != EColorSpace::SRGB) { + background_color.xyz() = srgb_to_linear(background_color.xyz()); + } else { + if (color.a > 0) { + color.rgb() = linear_to_srgb(color.rgb() / color.a) * color.a; + } else { + color.rgb() = vec3(0.0f); + } + } + + float weight = (1 - color.a) * background_color.a; + color.rgb() += background_color.rgb() * weight; + color.a += weight; + + color.rgb() = tonemap(color.rgb(), exposure, tonemap_curve, color_space, output_color_space); + + vec4 prev_color; + surf2Dread((float4*)&prev_color, surface, x * sizeof(float4), y); + color = color * alpha + prev_color * (1.f - alpha); + surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y); +} + +__global__ void overlay_depth_kernel( + ivec2 resolution, + float alpha, + const float* __restrict__ depth, + float depth_scale, + ivec2 image_resolution, + int fov_axis, + float zoom, + vec2 screen_center, + cudaSurfaceObject_t surface +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + float scale = image_resolution[fov_axis] / float(resolution[fov_axis]); + + float fx = x + 0.5f; + float fy = y + 0.5f; + + fx -= resolution.x * 0.5f; + fx /= zoom; + fx += screen_center.x * resolution.x; + fy -= resolution.y * 0.5f; + fy /= zoom; + fy += screen_center.y * resolution.y; + + float u = (fx - resolution.x * 0.5f) * scale + image_resolution.x * 0.5f; + float v = (fy - resolution.y * 0.5f) * scale + image_resolution.y * 0.5f; + + int srcx = floorf(u); + int srcy = floorf(v); + uint32_t srcidx = srcx + image_resolution.x * srcy; + + vec4 color; + if (srcx >= image_resolution.x || srcy >= image_resolution.y || srcx < 0 || srcy < 0) { + color = {0.0f, 0.0f, 0.0f, 0.0f}; + } else { + float depth_value = depth[srcidx] * depth_scale; + vec3 c = colormap_turbo(depth_value); + color = {c[0], c[1], c[2], 1.0f}; + } + + vec4 prev_color; + surf2Dread((float4*)&prev_color, 
surface, x * sizeof(float4), y); + color = color * alpha + prev_color * (1.f - alpha); + surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y); +} + +__device__ vec3 colormap_viridis(float x) { + const vec3 c0 = vec3{0.2777273272234177f, 0.005407344544966578f, 0.3340998053353061f}; + const vec3 c1 = vec3{0.1050930431085774f, 1.404613529898575f, 1.384590162594685f}; + const vec3 c2 = vec3{-0.3308618287255563f, 0.214847559468213f, 0.09509516302823659f}; + const vec3 c3 = vec3{-4.634230498983486f, -5.799100973351585f, -19.33244095627987f}; + const vec3 c4 = vec3{6.228269936347081f, 14.17993336680509f, 56.69055260068105f}; + const vec3 c5 = vec3{4.776384997670288f, -13.74514537774601f, -65.35303263337234f}; + const vec3 c6 = vec3{-5.435455855934631f, 4.645852612178535f, 26.3124352495832f}; + x = __saturatef(x); + return (c0 + x * (c1 + x * (c2 + x * (c3 + x * (c4 + x * (c5 + x * c6)))))); +} + +__global__ void overlay_false_color_kernel( + ivec2 resolution, + ivec2 training_resolution, + bool to_srgb, + int fov_axis, + cudaSurfaceObject_t surface, + const float* error_map, + ivec2 error_map_resolution, + const float* average, + float brightness, + bool viridis +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + float error_map_scale = brightness / (0.0000001f + average[0]); // average maps to 1/16th + + float scale = training_resolution[fov_axis] / float(resolution[fov_axis]); + float u = (x + 0.5f - resolution.x * 0.5f) * scale + training_resolution.x * 0.5f; + float v = (y + 0.5f - resolution.y * 0.5f) * scale + training_resolution.y * 0.5f; + int srcx = floor(u * error_map_resolution.x / float(max(1.f, (float)training_resolution.x))); + int srcy = floor(v * error_map_resolution.y / float(max(1.f, (float)training_resolution.y))); + + uint32_t srcidx = srcx + error_map_resolution.x * srcy; + + if (srcx >= error_map_resolution.x || srcy >= error_map_resolution.y || srcx < 0 || srcy < 0) { + return; + } + + float err = error_map[srcidx] * error_map_scale; + if (viridis) { + err *= 1.f / (1.f + err); + } + vec4 color; + surf2Dread((float4*)&color, surface, x * sizeof(float4), y); + vec3 c = viridis ? colormap_viridis(err) : colormap_turbo(err); + float grey = color.x * 0.2126f + color.y * 0.7152f + color.z * 0.0722f; + color.x = grey * __saturatef(c.x); + color.y = grey * __saturatef(c.y); + color.z = grey * __saturatef(c.z); + + surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y); +} + +__global__ void tonemap_kernel( + ivec2 resolution, + float exposure, + vec4 background_color, + vec4* accumulate_buffer, + EColorSpace color_space, + EColorSpace output_color_space, + ETonemapCurve tonemap_curve, + bool clamp_output_color, + bool unmultiply_alpha, + cudaSurfaceObject_t surface +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + uint32_t idx = x + resolution.x * y; + + // The background color is represented in SRGB, so convert + // to linear if that's not the space in which we're rendering. 
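+	// The compositing below is the standard "over" operator on premultiplied alpha:
+	// out.rgb = fg.rgb + bg.rgb * bg.a * (1 - fg.a) and out.a = fg.a + bg.a * (1 - fg.a),
+	// followed by tonemapping and, on the DLSS path, un-premultiplying alpha and clamping to [0, 1].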
+ if (color_space != EColorSpace::SRGB) { + background_color.rgb() = srgb_to_linear(background_color.rgb()); + } + + vec4 color = accumulate_buffer[idx]; + float weight = (1 - color.a) * background_color.a; + color.rgb() += background_color.rgb() * weight; + color.a += weight; + + color.rgb() = tonemap(color.rgb(), vec3(exposure), tonemap_curve, color_space, output_color_space); + + if (unmultiply_alpha && color.a > 0.0f) { + color.rgb() = color.rgb() / color.a; + } + + if (clamp_output_color) { + color = clamp(color, vec4(0.0f), vec4(1.0f)); + } + + surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y); +} + +__global__ void dlss_splat_kernel(ivec2 resolution, cudaSurfaceObject_t dlss_surface, cudaSurfaceObject_t surface) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + float4 color; + surf2Dread(&color, dlss_surface, x * sizeof(float4), y); + + // DLSS operates on non-premultiplied alpha, so multiply it back in + color.x *= color.w; + color.y *= color.w; + color.z *= color.w; + surf2Dwrite(color, surface, x * sizeof(float4), y); +} + +__global__ void depth_splat_kernel(ivec2 resolution, float znear, float zfar, float* __restrict__ depth_buffer, cudaSurfaceObject_t surface) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + uint32_t idx = x + resolution.x * y; + surf2Dwrite(to_ndc_depth(depth_buffer[idx], znear, zfar), surface, x * sizeof(float), y); +} + +void CudaRenderBufferView::clear(cudaStream_t stream) const { + size_t n_pixels = product(resolution); + CUDA_CHECK_THROW(cudaMemsetAsync(frame_buffer, 0, n_pixels * sizeof(vec4), stream)); + CUDA_CHECK_THROW(cudaMemsetAsync(depth_buffer, 0, n_pixels * sizeof(float), stream)); +} + +void CudaRenderBuffer::resize(const ivec2& res) { + m_in_resolution = res; + m_frame_buffer.enlarge(res.x * res.y); + m_depth_buffer.enlarge(res.x * res.y); + if (m_depth_target) { + m_depth_target->resize(res, 1); + } + m_accumulate_buffer.enlarge(res.x * res.y); + + ivec2 out_res = m_dlss ? m_dlss->out_resolution() : res; + auto prev_out_res = out_resolution(); + m_rgba_target->resize(out_res, 4); + + if (out_resolution() != prev_out_res) { + reset_accumulation(); + } +} + +void CudaRenderBuffer::clear_frame(cudaStream_t stream) { view().clear(stream); } + +void CudaRenderBuffer::accumulate(float exposure, cudaStream_t stream) { + ivec2 res = in_resolution(); + + uint32_t accum_spp = m_dlss ? 
0 : m_spp; + + if (accum_spp == 0) { + CUDA_CHECK_THROW(cudaMemsetAsync(m_accumulate_buffer.data(), 0, m_accumulate_buffer.bytes(), stream)); + } + + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + accumulate_kernel<<>>(res, frame_buffer(), accumulate_buffer(), (float)accum_spp, m_color_space); + + ++m_spp; +} + +void CudaRenderBuffer::tonemap( + float exposure, const vec4& background_color, EColorSpace output_color_space, float znear, float zfar, bool snap_to_pixel_centers, cudaStream_t stream +) { + assert(m_dlss || out_resolution() == in_resolution()); + + auto res = in_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + tonemap_kernel<<>>( + res, + exposure, + background_color, + accumulate_buffer(), + m_color_space, + output_color_space, + m_tonemap_curve, + m_dlss && output_color_space == EColorSpace::SRGB, + (bool)m_dlss, // DLSS seems to perform best with non-premultiplied alpha (probably trained on such data) + m_dlss ? m_dlss->frame() : surface() + ); + + if (m_dlss) { + assert(out_resolution() == m_dlss->out_resolution()); + + assert(m_spp >= 1); + uint32_t sample_index = m_spp - 1; + + m_dlss->run( + res, + output_color_space == EColorSpace::Linear, /* HDR mode */ + m_dlss_sharpening, + vec2(0.5f) - ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index), /* jitter offset in [-0.5, 0.5] */ + sample_index == 0 /* reset history */ + ); + + auto out_res = out_resolution(); + const dim3 out_blocks = {div_round_up((uint32_t)out_res.x, threads.x), div_round_up((uint32_t)out_res.y, threads.y), 1}; + dlss_splat_kernel<<>>(out_res, m_dlss->output(), surface()); + } + + if (m_depth_target) { + depth_splat_kernel<<>>(res, znear, zfar, depth_buffer(), m_depth_target->surface()); + } +} + +void CudaRenderBuffer::overlay_image( + float alpha, + const vec3& exposure, + const vec4& background_color, + EColorSpace output_color_space, + const void* __restrict__ image, + EImageDataType image_data_type, + const ivec2& image_resolution, + int fov_axis, + float zoom, + const vec2& screen_center, + cudaStream_t stream +) { + auto res = out_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + overlay_image_kernel<<>>( + res, + alpha, + exposure, + background_color, + image, + image_data_type, + image_resolution, + m_tonemap_curve, + m_color_space, + output_color_space, + fov_axis, + zoom, + screen_center, + surface() + ); +} + +void CudaRenderBuffer::overlay_depth( + float alpha, + const float* __restrict__ depth, + float depth_scale, + const ivec2& image_resolution, + int fov_axis, + float zoom, + const vec2& screen_center, + cudaStream_t stream +) { + auto res = out_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + overlay_depth_kernel<<>>( + res, alpha, depth, depth_scale, image_resolution, fov_axis, zoom, screen_center, surface() + ); +} + +void CudaRenderBuffer::overlay_false_color( + ivec2 training_resolution, + bool to_srgb, + int fov_axis, + cudaStream_t stream, + const float* error_map, + ivec2 error_map_resolution, + const float* average, + float brightness, + bool viridis +) { + auto res = out_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks 
= {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + overlay_false_color_kernel<<>>( + res, training_resolution, to_srgb, fov_axis, surface(), error_map, error_map_resolution, average, brightness, viridis + ); +} + +void CudaRenderBuffer::enable_dlss(IDlssProvider& dlss_provider, const ivec2& max_out_res) { +#ifdef NGP_VULKAN + if (!m_dlss || m_dlss->max_out_resolution() != max_out_res) { + m_dlss = dlss_provider.init_dlss(max_out_res); + } + + if (m_dlss) { + resize(m_dlss->clamp_resolution(in_resolution())); + } +#else + throw std::runtime_error{"NGP was compiled without Vulkan/NGX/DLSS support."}; +#endif +} + +void CudaRenderBuffer::disable_dlss() { m_dlss = nullptr; } + +} // namespace ngp diff --git a/gui/src/testbed.cu b/gui/src/testbed.cu new file mode 100644 index 0000000000000000000000000000000000000000..3401254a0ec47dd476bde928538ea30ac53e2225 --- /dev/null +++ b/gui/src/testbed.cu @@ -0,0 +1,3843 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file testbed.cu + * @author Thomas Müller & Alex Evans, NVIDIA + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include + +#include +#include + +#ifdef NGP_GUI +# include +# include +# include +# include +# include +# ifdef _WIN32 +# include +# else +# include +# endif +# include +# include +# include +#endif + +// Windows.h is evil +#undef min +#undef max +#undef near +#undef far + +using namespace std::literals::chrono_literals; +using nlohmann::json; + +namespace ngp { + +int do_system(const std::string& cmd) { +#ifdef _WIN32 + tlog::info() << "> " << cmd; + return _wsystem(utf8_to_utf16(cmd).c_str()); +#else + tlog::info() << "$ " << cmd; + return system(cmd.c_str()); +#endif +} + +std::atomic g_total_n_bytes_allocated{0}; + +void Testbed::update_imgui_paths() { + snprintf(m_imgui.cam_path_path, sizeof(m_imgui.cam_path_path), "%s", (root_dir() / "cam.json").str().c_str()); + snprintf(m_imgui.video_path, sizeof(m_imgui.video_path), "%s", (root_dir() / "video.json").str().c_str()); + snprintf(m_imgui.cam_export_path, sizeof(m_imgui.cam_export_path), "%s", (root_dir() / "cam_export.json").str().c_str()); +} + +void Testbed::set_mode(ETestbedMode mode) { + if (mode == m_testbed_mode) { + return; + } + + // Clear device-owned data that might be mode-specific + for (auto&& device : m_devices) { + device.clear(); + } + + m_testbed_mode = mode; + + // Set various defaults depending on mode + m_use_aux_devices = false; + + if (m_testbed_mode == ETestbedMode::Gen3c) { + if (m_dlss_provider && m_aperture_size == 0.0f) { + m_dlss = true; + } + } else { + m_dlss = false; + } + + m_reproject_enable = m_testbed_mode == ETestbedMode::Gen3c; + + reset_camera(); + +#ifdef NGP_GUI + update_vr_performance_settings(); +#endif +} + +void Testbed::load_file(const fs::path& path) { + if 
(!path.exists()) { + tlog::error() << "File '" << path.str() << "' does not exist."; + return; + } + + // If we get a json file, we need to parse it to determine its purpose. + if (equals_case_insensitive(path.extension(), "json")) { + json file; + { + std::ifstream f{native_string(path)}; + file = json::parse(f, nullptr, true, true); + } + + // Camera path + if (file.contains("path")) { + load_camera_path(path); + return; + } + } + + tlog::error() << "File '" << path.str() << "' is not a valid file to load."; +} + +void Testbed::reset_accumulation(bool due_to_camera_movement, bool immediate_redraw, bool reset_pip) { + if (immediate_redraw) { + redraw_next_frame(); + } + + if (!due_to_camera_movement || !reprojection_available()) { + m_windowless_render_surface.reset_accumulation(); + for (auto& view : m_views) { + view.render_buffer->reset_accumulation(); + } + } + + if (reset_pip) { + m_pip_render_buffer->reset_accumulation(); + } +} + +void Testbed::translate_camera(const vec3& rel, const mat3& rot, bool allow_up_down) { + vec3 movement = rot * rel; + if (!allow_up_down) { + movement -= dot(movement, m_up_dir) * m_up_dir; + } + + m_camera[3] += movement; + reset_accumulation(true); +} + +vec3 Testbed::look_at() const { return view_pos() + view_dir() * m_scale; } + +void Testbed::set_look_at(const vec3& pos) { m_camera[3] += pos - look_at(); } + +void Testbed::set_scale(float scale) { + auto prev_look_at = look_at(); + m_camera[3] = (view_pos() - prev_look_at) * (scale / m_scale) + prev_look_at; + m_scale = scale; +} + +void Testbed::set_view_dir(const vec3& dir) { + auto old_look_at = look_at(); + m_camera[0] = normalize(cross(dir, m_up_dir)); + m_camera[1] = normalize(cross(dir, m_camera[0])); + m_camera[2] = normalize(dir); + set_look_at(old_look_at); +} + +void Testbed::reset_camera() { + m_fov_axis = 1; + m_zoom = 1.0f; + m_screen_center = vec2(0.5f); + + set_fov(50.625f); + m_scale = 1.5f; + + m_camera = m_default_camera; + m_camera[3] -= m_scale * view_dir(); + + m_smoothed_camera = m_camera; + m_sun_dir = normalize(vec3(1.0f)); + + reset_accumulation(); +} + +fs::path Testbed::root_dir() { + if (m_root_dir.empty()) { + set_root_dir(discover_root_dir()); + } + + return m_root_dir; +} + +void Testbed::set_root_dir(const fs::path& dir) { m_root_dir = dir; } + +inline float linear_to_db(float x) { return -10.f * logf(x) / logf(10.f); } + + +#ifdef NGP_GUI +bool imgui_colored_button(const char* name, float hue) { + ImGui::PushStyleColor(ImGuiCol_Button, (ImVec4)ImColor::HSV(hue, 0.6f, 0.6f)); + ImGui::PushStyleColor(ImGuiCol_ButtonHovered, (ImVec4)ImColor::HSV(hue, 0.7f, 0.7f)); + ImGui::PushStyleColor(ImGuiCol_ButtonActive, (ImVec4)ImColor::HSV(hue, 0.8f, 0.8f)); + bool rv = ImGui::Button(name); + ImGui::PopStyleColor(3); + return rv; +} + +void Testbed::overlay_fps() { + ImGui::PushFont((ImFont*)m_imgui.overlay_font); + ImGui::SetNextWindowPos({10.0f, 10.0f}, ImGuiCond_Always, {0.0f, 0.0f}); + ImGui::SetNextWindowBgAlpha(0.35f); + if (ImGui::Begin( + "Overlay", + nullptr, + ImGuiWindowFlags_NoDecoration | ImGuiWindowFlags_AlwaysAutoResize | ImGuiWindowFlags_NoSavedSettings | + ImGuiWindowFlags_NoFocusOnAppearing | ImGuiWindowFlags_NoNav | ImGuiWindowFlags_NoMove + )) { + ImGui::Text("%.1f FPS", 1000.0f / m_render_ms.ema_val()); + } + ImGui::PopFont(); +} + +void Testbed::imgui() { + // If a GUI interaction causes an error, write that error to the following string and call + // ImGui::OpenPopup("Error"); + static std::string imgui_error_string = ""; + + m_picture_in_picture_res = 0; 
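+	// m_picture_in_picture_res is re-derived every frame further below: it stays 0 unless the
+	// camera-path window requests a reduced-resolution preview render into the PiP texture.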
+ + // Good default position and size for the camera path editing window + ImGui::SetNextWindowPos({10.0f, 10.0f}, ImGuiCond_FirstUseEver); + int window_width, window_height; + glfwGetWindowSize(m_glfw_window, &window_width, &window_height); + ImGui::SetNextWindowSize({420.0f, window_height - 20.0f}, ImGuiCond_FirstUseEver); + + if (ImGui::Begin("Camera path & video generation", 0, ImGuiWindowFlags_NoScrollbar)) { + if (ImGui::CollapsingHeader("Path manipulation", ImGuiTreeNodeFlags_DefaultOpen)) { + ImGui::Checkbox("Record camera path", &m_record_camera_path); + ImGui::SameLine(); + if (ImGui::Button("Clear")) { + m_camera_path.clear(); + } + + if (m_reproject_enable) { + ImGui::SameLine(); + if (ImGui::Button("Init from views")) { + init_camera_path_from_reproject_src_cameras(); + } + } + + if (int read = m_camera_path.imgui(m_imgui.cam_path_path, m_frame_ms.val(), m_camera, fov(), mat4x3::identity())) { + if (!m_camera_path.rendering || m_gen3c_render_with_gen3c) { + reset_accumulation(true); + + if (m_camera_path.update_cam_from_path) { + set_camera_from_time(m_camera_path.play_time); + + // A value of larger than 1 indicates that the camera path wants + // to override camera smoothing. + if (read > 1) { + m_smoothed_camera = m_camera; + } + } else { + m_pip_render_buffer->reset_accumulation(); + } + } + } + + if (!m_camera_path.keyframes.empty()) { + float w = ImGui::GetContentRegionAvail().x; + if (m_camera_path.update_cam_from_path) { + m_picture_in_picture_res = 0; + ImGui::Image((ImTextureID)(size_t)m_rgba_render_textures.front()->texture(), ImVec2(w, w * 9.0f / 16.0f)); + } else { + m_picture_in_picture_res = (float)std::min((int(w) + 31) & (~31), 1920 / 4); + ImGui::Image((ImTextureID)(size_t)m_pip_render_texture->texture(), ImVec2(w, w * 9.0f / 16.0f)); + } + } + } + + if (!m_camera_path.keyframes.empty() && ImGui::CollapsingHeader("Video generation", ImGuiTreeNodeFlags_DefaultOpen)) { + // Render a video + // TODO: simplify this (only allow rendering with Gen3C). + ImGui::BeginDisabled(m_camera_path.rendering); + if (imgui_colored_button(m_camera_path.rendering ? "Waiting for model..." : "Generate video", 0.4)) { + bool was_rendering = m_camera_path.rendering; + m_camera_path.rendering = !m_camera_path.rendering; + + if (m_gen3c_render_with_gen3c) { + if (m_gen3c_cb) { + m_gen3c_cb(was_rendering ? 
"abort_inference" : "request_inference"); + } + } else { + if (!clear_tmp_dir()) { + imgui_error_string = "Failed to clear temporary directory 'tmp' to hold rendered images."; + ImGui::OpenPopup("Error"); + + m_camera_path.rendering = false; + } + + if (m_camera_path.rendering) { + m_camera_path.render_start_time = std::chrono::steady_clock::now(); + m_camera_path.update_cam_from_path = true; + m_camera_path.play_time = 0.0f; + m_camera_path.auto_play_speed = 1.0f; + m_camera_path.render_frame_idx = 0; + + m_dlss = false; + + reset_accumulation(true); + set_camera_from_time(m_camera_path.play_time); + m_smoothed_camera = m_camera; + } else { + m_camera_path.update_cam_from_path = false; + m_camera_path.play_time = 0.0f; + m_camera_path.auto_play_speed = 0.0f; + } + } + } + ImGui::EndDisabled(); + + ImGui::SameLine(); + ImGui::BeginDisabled(!m_gen3c_inference_is_connected || !m_gen3c_cb); + ImGui::Checkbox("Gen3C inference", &m_gen3c_render_with_gen3c); + ImGui::EndDisabled(); + + if (m_camera_path.rendering) { + const auto elapsed = std::chrono::steady_clock::now() - m_camera_path.render_start_time; + + const float duration = m_camera_path.duration_seconds(); + const uint32_t progress = m_camera_path.render_frame_idx * m_camera_path.render_settings.spp + m_views.front().render_buffer->spp(); + const uint32_t goal = m_camera_path.render_settings.n_frames(duration) * m_camera_path.render_settings.spp; + const auto est_remaining = elapsed * (float)(goal - progress) / std::max(progress, 1u); + + if (m_gen3c_render_with_gen3c) { + if (!m_gen3c_inference_info.empty()) { + ImGui::TextWrapped("%s", m_gen3c_inference_info.c_str()); + } + + if (m_gen3c_inference_progress > 0) { + ImGui::ProgressBar(m_gen3c_inference_progress); + } + } else { + ImGui::Text( + "%s", + fmt::format( + "Frame {}/{}, Elapsed: {}, Remaining: {}", + m_camera_path.render_frame_idx + 1, + m_camera_path.render_settings.n_frames(duration), + tlog::durationToString(std::chrono::steady_clock::now() - m_camera_path.render_start_time), + tlog::durationToString(est_remaining) + ) + .c_str() + ); + + ImGui::ProgressBar((float)progress / goal); + } + } + + ImGui::BeginDisabled(m_camera_path.rendering); + + ImGui::Checkbox("Show rendered Gen3C cache in video", &m_gen3c_show_cache_renderings); + // Note: 3D cache visualization is incompatible with adding Gen3C frames to the viewport. 
+ if (m_gen3c_show_cache_renderings) + m_gen3c_display_frames = false; + ImGui::BeginDisabled(m_gen3c_show_cache_renderings); + ImGui::Checkbox("Add Gen3C keyframes to viewport after inference", &m_gen3c_display_frames); + ImGui::EndDisabled(); // m_gen3c_show_cache_renderings + + ImGui::InputText("Video file##Video file path", m_imgui.video_path, sizeof(m_imgui.video_path)); + m_camera_path.render_settings.filename = m_imgui.video_path; + ImGui::SliderInt("MP4 quality", &m_camera_path.render_settings.quality, 0, 10); + + float duration_seconds = m_camera_path.duration_seconds(); + if (ImGui::InputFloat("Duration (seconds)", &duration_seconds) && duration_seconds > 0.0f) { + m_camera_path.set_duration_seconds(duration_seconds); + } + + ImGui::InputFloat("FPS (frames/second)", &m_camera_path.render_settings.fps); + + ImGui::BeginDisabled(m_gen3c_render_with_gen3c); + ImGui::InputInt2("Resolution", &m_camera_path.render_settings.resolution.x); + // ImGui::InputInt("SPP (samples/pixel)", &m_camera_path.render_settings.spp); + if (m_gen3c_render_with_gen3c) { + m_camera_path.render_settings.spp = 1; + } + // ImGui::SliderFloat("Shutter fraction", &m_camera_path.render_settings.shutter_fraction, 0.0f, 1.0f); + ImGui::EndDisabled(); // end m_gen3c_render_with_gen3c + + ImGui::EndDisabled(); // end m_camera_path.rendering + + ImGui::Spacing(); + bool export_cameras = imgui_colored_button("Export cameras", 0.7); + + ImGui::SameLine(); + + static bool w2c = false; + ImGui::Checkbox("W2C", &w2c); + + ImGui::InputText("Cameras file##Camera export path", m_imgui.cam_export_path, sizeof(m_imgui.cam_export_path)); + m_camera_path.render_settings.filename = m_imgui.video_path; + + if (export_cameras) { + std::vector cameras; + const float duration = m_camera_path.duration_seconds(); + for (uint32_t i = 0; i < m_camera_path.render_settings.n_frames(duration); ++i) { + mat4x3 start_cam = m_camera_path.eval_camera_path((float)i / (m_camera_path.render_settings.n_frames(duration))).m(); + mat4x3 end_cam = m_camera_path + .eval_camera_path( + ((float)i + m_camera_path.render_settings.shutter_fraction) / + (m_camera_path.render_settings.n_frames(duration)) + ) + .m(); + if (w2c) { + start_cam = inverse(mat4x4(start_cam)); + end_cam = inverse(mat4x4(end_cam)); + } + + cameras.push_back({ + {"start", start_cam}, + {"end", end_cam }, + }); + } + + json j; + j["cameras"] = cameras; + j["resolution"] = m_camera_path.render_settings.resolution; + j["duration_seconds"] = m_camera_path.duration_seconds(); + j["fps"] = m_camera_path.render_settings.fps; + j["spp"] = m_camera_path.render_settings.spp; + j["quality"] = m_camera_path.render_settings.quality; + j["shutter_fraction"] = m_camera_path.render_settings.shutter_fraction; + + std::ofstream f(native_string(m_imgui.cam_export_path)); + f << j; + } + } + } + ImGui::End(); + + // Good default position and size for the right-hand side window + int pane_width = 350; + ImGui::SetNextWindowPos({window_width - pane_width - 10.0f, 10.0f}, ImGuiCond_FirstUseEver); + ImGui::SetNextWindowSize({(float)pane_width, window_height - 20.0f}, ImGuiCond_FirstUseEver); + + ImGui::Begin("Gen3C v" NGP_VERSION); + + size_t n_bytes = tcnn::total_n_bytes_allocated() + g_total_n_bytes_allocated; + if (m_dlss_provider) { + n_bytes += m_dlss_provider->allocated_bytes(); + } + + ImGui::Text("Frame: %.2f ms (%.1f FPS); Mem: %s", m_frame_ms.ema_val(), 1000.0f / m_frame_ms.ema_val(), bytes_to_string(n_bytes).c_str()); + bool accum_reset = false; + + if (m_testbed_mode == ETestbedMode::Gen3c && 
ImGui::CollapsingHeader("Video generation server", ImGuiTreeNodeFlags_DefaultOpen)) { + ImGui::TextWrapped("%s", m_gen3c_info.c_str()); + ImGui::Spacing(); + + // Create a child box with a title and borders + if (ImGui::TreeNodeEx("Seeding", ImGuiTreeNodeFlags_DefaultOpen)) { + ImGui::TextWrapped("Enter the path to an image or a pre-processed video directory."); + ImGui::InputText("Path", &m_gen3c_seed_path); + + ImGui::BeginDisabled(m_gen3c_seed_path.empty()); + if (ImGui::Button("Seed") && m_gen3c_cb) { + m_gen3c_cb("seed_model"); + } + if (m_gen3c_seeding_progress > 0) { + ImGui::ProgressBar(m_gen3c_seeding_progress); + } + ImGui::EndDisabled(); + + ImGui::Spacing(); + ImGui::TreePop(); + } + + // ImGui::Separator(); + + // We need this to be executed even if the panel below is collapsed. + switch (m_gen3c_camera_source) { + case EGen3cCameraSource::Fake: { + m_gen3c_auto_inference = false; + break; + } + case EGen3cCameraSource::Viewpoint: { + break; + } + case EGen3cCameraSource::Authored: { + m_gen3c_auto_inference = false; + break; + } + default: throw std::runtime_error("Unsupported Gen3C camera source."); + } + + } + + if (ImGui::CollapsingHeader("Point cloud", ImGuiTreeNodeFlags_DefaultOpen)) { + // accum_reset |= ImGui::Checkbox("Enable reprojection", &m_reproject_enable); + if (m_reproject_enable) { + int max_views = (int)m_reproject_src_views.size(); + + int prev_min_src_view_index = m_reproject_min_src_view_index; + int prev_max_src_view_index = m_reproject_max_src_view_index; + int prev_n_frames_shown = std::max(0, prev_max_src_view_index - prev_min_src_view_index); + + if (ImGui::SliderInt("Min view index", &m_reproject_min_src_view_index, 0, max_views)) { + // If shift, move the range synchronously. + if (ImGui::GetIO().KeyShift) { + m_reproject_max_src_view_index = + std::min(m_reproject_max_src_view_index + m_reproject_min_src_view_index - prev_min_src_view_index, max_views); + // Keep the number of frames shown constant. + m_reproject_min_src_view_index = m_reproject_max_src_view_index - prev_n_frames_shown; + } + + // Ensure that range remains valid (max index >= min index). + m_reproject_max_src_view_index = std::max(m_reproject_max_src_view_index, m_reproject_min_src_view_index); + accum_reset = true; + } + + if (ImGui::SliderInt("Max view index", &m_reproject_max_src_view_index, 0, max_views)) { + // If shift, move the range synchronously. + if (ImGui::GetIO().KeyShift) { + m_reproject_min_src_view_index = + std::max(m_reproject_min_src_view_index + m_reproject_max_src_view_index - prev_max_src_view_index, 0); + // Keep the number of frames shown constant. + m_reproject_max_src_view_index = m_reproject_min_src_view_index + prev_n_frames_shown; + } + // Ensure that range remains valid (max index >= min index). 
+ m_reproject_min_src_view_index = std::min(m_reproject_max_src_view_index, m_reproject_min_src_view_index); + accum_reset = true; + } + + if (max_views > 0 && ImGui::SliderInt("Snap to view", (int*)&m_reproject_selected_src_view, 0, max_views - 1)) { + m_camera = m_smoothed_camera = + m_reproject_src_views[std::min((size_t)m_reproject_selected_src_view, m_reproject_src_views.size() - 1)].camera0; + accum_reset = true; + } + + accum_reset |= ImGui::Checkbox("Visualize views", &m_reproject_visualize_src_views); + ImGui::SameLine(); + if (ImGui::Button("Delete views")) { + clear_src_views(); + } + + if (ImGui::TreeNodeEx("Advanced reprojection settings")) { + accum_reset |= ImGui::SliderFloat( + "Reproject min t", &m_reproject_min_t, 0.01f, 16.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat + ); + accum_reset |= ImGui::SliderFloat( + "Reproject scaling", &m_reproject_step_factor, 1.003f, 1.5f, "%.001f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat + ); + + accum_reset |= ImGui::Combo("Reproject render mode", (int*)&m_pm_viz_mode, PmVizModeStr); + + ImGui::TreePop(); + } + + } + } + + if (ImGui::CollapsingHeader("Rendering", m_testbed_mode == ETestbedMode::Gen3c ? 0 : ImGuiTreeNodeFlags_DefaultOpen)) { + + ImGui::Checkbox("Render", &m_render); + ImGui::SameLine(); + + const auto& render_buffer = m_views.front().render_buffer; + std::string spp_string = m_dlss ? std::string{""} : fmt::format("({} spp)", std::max(render_buffer->spp(), 1u)); + ImGui::Text( + ": %.01fms for %dx%d %s", + m_render_ms.ema_val(), + render_buffer->in_resolution().x, + render_buffer->in_resolution().y, + spp_string.c_str() + ); + + ImGui::SameLine(); + if (ImGui::Checkbox("VSync", &m_vsync)) { + glfwSwapInterval(m_vsync ? 1 : 0); + } + + + ImGui::Checkbox("Dynamic resolution", &m_dynamic_res); + ImGui::SameLine(); + ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f); + if (m_dynamic_res) { + ImGui::SliderFloat( + "Target FPS", &m_dynamic_res_target_fps, 2.0f, 144.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat + ); + } else { + ImGui::SliderInt("Resolution factor", &m_fixed_res_factor, 8, 64); + } + ImGui::PopItemWidth(); + + if (ImGui::TreeNode("Advanced rendering options")) { + accum_reset |= ImGui::Combo("Render mode", (int*)&m_render_mode, RenderModeStr); + accum_reset |= ImGui::Combo("Tonemap curve", (int*)&m_tonemap_curve, TonemapCurveStr); + accum_reset |= ImGui::ColorEdit4("Background", &m_background_color[0]); + + if (ImGui::SliderFloat("Exposure", &m_exposure, -5.f, 5.f)) { + set_exposure(m_exposure); + } + + ImGui::SliderInt("Max spp", &m_max_spp, 0, 1024, "%d", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat); + accum_reset |= ImGui::Checkbox("Render transparency as checkerboard", &m_render_transparency_as_checkerboard); + accum_reset |= ImGui::Combo("Color space", (int*)&m_color_space, ColorSpaceStr); + accum_reset |= ImGui::Checkbox("Snap to pixel centers", &m_snap_to_pixel_centers); + + ImGui::TreePop(); + } + } + + if (ImGui::CollapsingHeader("Camera")) { + ImGui::Checkbox("First person controls", &m_fps_camera); + ImGui::SameLine(); + ImGui::Checkbox("Smooth motion", &m_camera_smoothing); + + float local_fov = fov(); + if (ImGui::SliderFloat("Field of view", &local_fov, 0.0f, 120.0f)) { + set_fov(local_fov); + accum_reset = true; + } + + if (ImGui::TreeNode("Advanced camera settings")) { + accum_reset |= ImGui::SliderFloat2("Screen center", &m_screen_center.x, 0.f, 1.f); + accum_reset |= 
ImGui::SliderFloat2("Parallax shift", &m_parallax_shift.x, -1.f, 1.f); + accum_reset |= ImGui::SliderFloat("Slice / focus depth", &m_slice_plane_z, -m_bounding_radius, m_bounding_radius); + accum_reset |= ImGui::SliderFloat( + "Render near distance", &m_render_near_distance, 0.0f, 1.0f, "%.3f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat + ); + + bool lens_changed = ImGui::Checkbox("Apply lens distortion", &m_render_with_lens_distortion); + if (m_render_with_lens_distortion) { + lens_changed |= ImGui::Combo("Lens mode", (int*)&m_render_lens.mode, LensModeStr); + if (m_render_lens.mode == ELensMode::OpenCV) { + accum_reset |= ImGui::InputFloat("k1", &m_render_lens.params[0], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("k2", &m_render_lens.params[1], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("p1", &m_render_lens.params[2], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("p2", &m_render_lens.params[3], 0.f, 0.f, "%.5f"); + } else if (m_render_lens.mode == ELensMode::OpenCVFisheye) { + accum_reset |= ImGui::InputFloat("k1", &m_render_lens.params[0], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("k2", &m_render_lens.params[1], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("k3", &m_render_lens.params[2], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("k4", &m_render_lens.params[3], 0.f, 0.f, "%.5f"); + } else if (m_render_lens.mode == ELensMode::FTheta) { + accum_reset |= ImGui::InputFloat("width", &m_render_lens.params[5], 0.f, 0.f, "%.0f"); + accum_reset |= ImGui::InputFloat("height", &m_render_lens.params[6], 0.f, 0.f, "%.0f"); + accum_reset |= ImGui::InputFloat("f_theta p0", &m_render_lens.params[0], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("f_theta p1", &m_render_lens.params[1], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("f_theta p2", &m_render_lens.params[2], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("f_theta p3", &m_render_lens.params[3], 0.f, 0.f, "%.5f"); + accum_reset |= ImGui::InputFloat("f_theta p4", &m_render_lens.params[4], 0.f, 0.f, "%.5f"); + } + + if (lens_changed && !m_render_lens.supports_dlss()) { + m_dlss = false; + } + } + ImGui::Spacing(); + + accum_reset |= lens_changed; + + char buf[2048]; + vec3 v = view_dir(); + vec3 p = look_at(); + vec3 s = m_sun_dir; + vec3 u = m_up_dir; + vec4 b = m_background_color; + snprintf( + buf, + sizeof(buf), + "testbed.background_color = [%0.3f, %0.3f, %0.3f, %0.3f]\n" + "testbed.exposure = %0.3f\n" + "testbed.sun_dir = [%0.3f,%0.3f,%0.3f]\n" + "testbed.up_dir = [%0.3f,%0.3f,%0.3f]\n" + "testbed.view_dir = [%0.3f,%0.3f,%0.3f]\n" + "testbed.look_at = [%0.3f,%0.3f,%0.3f]\n" + "testbed.scale = %0.3f\n" + "testbed.fov,testbed.aperture_size,testbed.slice_plane_z = %0.3f,%0.3f,%0.3f\n" + "testbed.autofocus_target = [%0.3f,%0.3f,%0.3f]\n" + "testbed.autofocus = %s\n\n", + b.r, + b.g, + b.b, + b.a, + m_exposure, + s.x, + s.y, + s.z, + u.x, + u.y, + u.z, + v.x, + v.y, + v.z, + p.x, + p.y, + p.z, + scale(), + fov(), + m_aperture_size, + m_slice_plane_z, + m_autofocus_target.x, + m_autofocus_target.y, + m_autofocus_target.z, + m_autofocus ? 
"True" : "False" + ); + + ImGui::InputTextMultiline("Params", buf, sizeof(buf)); + ImGui::TreePop(); + } + } + + if (ImGui::BeginPopupModal("Error", NULL, ImGuiWindowFlags_AlwaysAutoResize)) { + ImGui::Text("%s", imgui_error_string.c_str()); + if (ImGui::Button("OK", ImVec2(120, 0))) { + ImGui::CloseCurrentPopup(); + } + ImGui::EndPopup(); + } + + if (accum_reset) { + reset_accumulation(); + } + + if (ImGui::Button("Go to Python REPL")) { + m_want_repl = true; + } + + ImGui::End(); +} + +void Testbed::init_camera_path_from_reproject_src_cameras() { + m_camera_path.clear(); + + for (int i = m_reproject_min_src_view_index; i < std::min(m_reproject_max_src_view_index, (int)m_reproject_src_views.size()); ++i) { + const auto& view = m_reproject_src_views[i]; + m_camera_path.add_camera( + view.camera0, + view.fov()[m_fov_axis], + 0.0f // timestamp set to zero: camera path treats keyframes as temporally equidistant + ); + } + + m_camera_path.keyframe_subsampling = (int)m_camera_path.keyframes.size(); + m_camera_path.editing_kernel_type = EEditingKernel::Gaussian; +} + +void Testbed::visualize_reproject_src_cameras(ImDrawList* list, const mat4& world2proj) { + for (size_t i = (size_t)m_reproject_min_src_view_index; + i < std::min((size_t)m_reproject_max_src_view_index, m_reproject_src_views.size()); + ++i) { + const auto& view = m_reproject_src_views[i]; + auto res = view.full_resolution; + float aspect = float(res.x) / float(res.y); + + visualize_camera(list, world2proj, view.camera0, aspect, 0xffffffff); + } +} + +void Testbed::clear_src_views() { + m_reproject_src_views.clear(); + reset_accumulation(); +} + +void Testbed::draw_visualizations(ImDrawList* list, const mat4x3& camera_matrix) { + mat4 view2world = camera_matrix; + mat4 world2view = inverse(view2world); + + auto focal = calc_focal_length(ivec2(1), m_relative_focal_length, m_fov_axis, m_zoom); + float zscale = 1.0f / focal[m_fov_axis]; + + float xyscale = (float)m_window_res[m_fov_axis]; + vec2 screen_center = render_screen_center(m_screen_center); + mat4 view2proj = transpose( + mat4{ + xyscale, + 0.0f, + (float)m_window_res.x * screen_center.x * zscale, + 0.0f, + 0.0f, + xyscale, + (float)m_window_res.y * screen_center.y * zscale, + 0.0f, + 0.0f, + 0.0f, + 1.0f, + 0.0f, + 0.0f, + 0.0f, + zscale, + 0.0f, + } + ); + + mat4 world2proj = view2proj * world2view; + float aspect = (float)m_window_res.x / (float)m_window_res.y; + + if (m_reproject_visualize_src_views) { + visualize_reproject_src_cameras(list, world2proj); + } + + if (m_visualize_unit_cube) { + visualize_cube(list, world2proj, vec3(0.f), vec3(1.f), mat3::identity()); + } + + if (m_edit_render_aabb) { + ImGuiIO& io = ImGui::GetIO(); + // float flx = focal.x; + float fly = focal.y; + float zfar = m_ndc_zfar; + float znear = m_ndc_znear; + mat4 view2proj_guizmo = transpose( + mat4{ + fly * 2.0f / aspect, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + -fly * 2.f, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + (zfar + znear) / (zfar - znear), + -(2.0f * zfar * znear) / (zfar - znear), + 0.0f, + 0.0f, + 1.0f, + 0.0f, + } + ); + + ImGuizmo::SetRect(0, 0, io.DisplaySize.x, io.DisplaySize.y); + + static mat4 matrix = mat4::identity(); + static mat4 world2view_guizmo = mat4::identity(); + + vec3 cen = transpose(m_render_aabb_to_local) * m_render_aabb.center(); + if (!ImGuizmo::IsUsing()) { + // The the guizmo is being used, it handles updating its matrix on its own. + // Outside interference can only lead to trouble. 
+ auto rot = transpose(m_render_aabb_to_local); + matrix = mat4(mat4x3(rot[0], rot[1], rot[2], cen)); + + // Additionally, the world2view transform must stay fixed, else the guizmo will incorrectly + // interpret the state from past frames. Special handling is necessary here, because below + // we emulate world translation and rotation through (inverse) camera movement. + world2view_guizmo = world2view; + } + + auto prev_matrix = matrix; + + if (ImGuizmo::Manipulate( + (const float*)&world2view_guizmo, (const float*)&view2proj_guizmo, m_camera_path.m_gizmo_op, ImGuizmo::LOCAL, (float*)&matrix, NULL, NULL + )) { + if (m_edit_world_transform) { + // We transform the world by transforming the camera in the opposite direction. + auto rel = prev_matrix * inverse(matrix); + m_camera = mat3(rel) * m_camera; + m_camera[3] += rel[3].xyz(); + + m_up_dir = mat3(rel) * m_up_dir; + } else { + m_render_aabb_to_local = transpose(mat3(matrix)); + vec3 new_cen = m_render_aabb_to_local * matrix[3].xyz(); + vec3 old_cen = m_render_aabb.center(); + m_render_aabb.min += new_cen - old_cen; + m_render_aabb.max += new_cen - old_cen; + } + + reset_accumulation(); + } + } + + + if (m_camera_path.imgui_viz( + list, + view2proj, + world2proj, + world2view, + focal, + aspect, + m_ndc_znear, + m_ndc_zfar + )) { + m_pip_render_buffer->reset_accumulation(); + } +} + +void glfw_error_callback(int error, const char* description) { tlog::error() << "GLFW error #" << error << ": " << description; } + +bool Testbed::keyboard_event() { + if (ImGui::GetIO().WantCaptureKeyboard) { + return false; + } + + if (m_keyboard_event_callback && m_keyboard_event_callback()) { + return false; + } + + if (ImGui::IsKeyPressed('Q') && ImGui::GetIO().KeyCtrl) { + glfwSetWindowShouldClose(m_glfw_window, GLFW_TRUE); + } + + if ((ImGui::IsKeyPressed(GLFW_KEY_TAB) || ImGui::IsKeyPressed(GLFW_KEY_GRAVE_ACCENT)) && !ImGui::GetIO().KeyCtrl) { + m_imgui.mode = (ImGuiMode)(((uint32_t)m_imgui.mode + 1) % (uint32_t)ImGuiMode::NumModes); + } + + for (int idx = 0; idx < std::min((int)ERenderMode::NumRenderModes, 10); ++idx) { + char c[] = {"1234567890"}; + if (ImGui::IsKeyPressed(c[idx])) { + m_render_mode = (ERenderMode)idx; + reset_accumulation(); + } + } + + bool ctrl = ImGui::GetIO().KeyCtrl; + bool shift = ImGui::GetIO().KeyShift; + + if (ImGui::IsKeyPressed('Z')) { + m_camera_path.m_gizmo_op = ImGuizmo::TRANSLATE; + } + + if (ImGui::IsKeyPressed('X')) { + m_camera_path.m_gizmo_op = ImGuizmo::ROTATE; + } + + if (ImGui::IsKeyPressed('E')) { + set_exposure(m_exposure + (shift ? 
-0.5f : 0.5f)); + redraw_next_frame(); + } + + if (ImGui::IsKeyPressed('R')) { + reset_camera(); + } + + if (ImGui::IsKeyPressed('=') || ImGui::IsKeyPressed('+')) { + if (m_fps_camera) { + m_camera_velocity *= 1.5f; + } else { + set_scale(m_scale * 1.1f); + } + } + + if (ImGui::IsKeyPressed('-') || ImGui::IsKeyPressed('_')) { + if (m_fps_camera) { + m_camera_velocity /= 1.5f; + } else { + set_scale(m_scale / 1.1f); + } + } + + // WASD camera movement + vec3 translate_vec = vec3(0.0f); + if (ImGui::IsKeyDown('W')) { + translate_vec.z += 1.0f; + } + + if (ImGui::IsKeyDown('A')) { + translate_vec.x += -1.0f; + } + + if (ImGui::IsKeyDown('S')) { + translate_vec.z += -1.0f; + } + + if (ImGui::IsKeyDown('D')) { + translate_vec.x += 1.0f; + } + + if (ImGui::IsKeyDown(' ')) { + translate_vec.y += -1.0f; + } + + if (ImGui::IsKeyDown('C')) { + translate_vec.y += 1.0f; + } + + translate_vec *= m_camera_velocity * m_frame_ms.val() / 1000.0f; + if (shift) { + translate_vec *= 5.0f; + } + + if (translate_vec != vec3(0.0f)) { + m_fps_camera = true; + + // If VR is active, movement that isn't aligned with the current view + // direction is _very_ jarring to the user, so make keyboard-based + // movement aligned with the VR view, even though it is not an intended + // movement mechanism. (Users should use controllers.) + translate_camera(translate_vec, m_hmd && m_hmd->is_visible() ? mat3(m_views.front().camera0) : mat3(m_camera)); + } + + return false; +} + +void Testbed::mouse_wheel() { + float delta = ImGui::GetIO().MouseWheel; + if (delta == 0) { + return; + } + + float scale_factor = pow(1.1f, -delta); + set_scale(m_scale * scale_factor); + + reset_accumulation(true); +} + +mat3 Testbed::rotation_from_angles(const vec2& angles) const { + vec3 up = m_up_dir; + vec3 side = m_camera[0]; + return rotmat(angles.x, up) * rotmat(angles.y, side); +} + +void Testbed::mouse_drag() { + vec2 rel = vec2{ImGui::GetIO().MouseDelta.x, ImGui::GetIO().MouseDelta.y} / (float)m_window_res[m_fov_axis]; + vec2 mouse = {ImGui::GetMousePos().x, ImGui::GetMousePos().y}; + + vec3 side = m_camera[0]; + + bool shift = ImGui::GetIO().KeyShift; + + // Left pressed + if (ImGui::GetIO().MouseClicked[0] && shift) { + m_autofocus_target = get_3d_pos_from_pixel(*m_views.front().render_buffer, mouse); + m_autofocus = true; + + reset_accumulation(); + } + + // Left held + if (ImGui::GetIO().MouseDown[0]) { + float rot_sensitivity = m_fps_camera ? 
0.35f : 1.0f; + mat3 rot = rotation_from_angles(-rel * 2.0f * PI() * rot_sensitivity); + + if (m_fps_camera) { + rot *= mat3(m_camera); + m_camera = mat4x3(rot[0], rot[1], rot[2], m_camera[3]); + } else { + // Turntable + auto old_look_at = look_at(); + set_look_at({0.0f, 0.0f, 0.0f}); + m_camera = rot * m_camera; + set_look_at(old_look_at); + } + + reset_accumulation(true); + } + + // Right held + if (ImGui::GetIO().MouseDown[1]) { + mat3 rot = rotation_from_angles(-rel * 2.0f * PI()); + if (m_render_mode == ERenderMode::Shade) { + m_sun_dir = transpose(rot) * m_sun_dir; + } + + m_slice_plane_z += -rel.y * m_bounding_radius; + reset_accumulation(); + } + + // Middle pressed + if (ImGui::GetIO().MouseClicked[2]) { + m_drag_depth = get_depth_from_renderbuffer(*m_views.front().render_buffer, mouse / vec2(m_window_res)); + } + + // Middle held + if (ImGui::GetIO().MouseDown[2]) { + vec3 translation = vec3{-rel.x, -rel.y, 0.0f} / m_zoom; + bool is_orthographic = m_render_with_lens_distortion && m_render_lens.mode == ELensMode::Orthographic; + + translation /= m_relative_focal_length[m_fov_axis]; + + // If we have a valid depth value, scale the scene translation by it such that the + // hovered point in 3D space stays under the cursor. + if (m_drag_depth < 256.0f && !is_orthographic) { + translation *= m_drag_depth; + } + + translate_camera(translation, mat3(m_camera)); + } +} + +bool Testbed::begin_frame() { + if (glfwWindowShouldClose(m_glfw_window)) { + destroy_window(); + return false; + } + + { + auto now = std::chrono::steady_clock::now(); + auto elapsed = now - m_last_frame_time_point; + m_last_frame_time_point = now; + m_frame_ms.update(std::chrono::duration(elapsed).count()); + } + + glfwPollEvents(); + glfwGetFramebufferSize(m_glfw_window, &m_window_res.x, &m_window_res.y); + + ImGui_ImplOpenGL3_NewFrame(); + ImGui_ImplGlfw_NewFrame(); + ImGui::NewFrame(); + ImGuizmo::BeginFrame(); + + return true; +} + +void Testbed::handle_user_input() { + // Only respond to mouse inputs when not interacting with ImGui + if (!ImGui::IsAnyItemActive() && !ImGuizmo::IsUsing() && !ImGui::GetIO().WantCaptureMouse) { + mouse_wheel(); + mouse_drag(); + } + + keyboard_event(); + + switch (m_imgui.mode) { + case ImGuiMode::Enabled: imgui(); break; + case ImGuiMode::FpsOverlay: overlay_fps(); break; + case ImGuiMode::Disabled: break; + default: throw std::runtime_error{fmt::format("Invalid imgui mode: {}", (uint32_t)m_imgui.mode)}; + } +} + +vec3 Testbed::vr_to_world(const vec3& pos) const { return mat3(m_camera) * pos * m_scale + m_camera[3]; } + +void Testbed::begin_vr_frame_and_handle_vr_input() { + if (!m_hmd) { + m_vr_frame_info = nullptr; + return; + } + + m_hmd->poll_events(); + if (!m_hmd->must_run_frame_loop()) { + m_vr_frame_info = nullptr; + return; + } + + m_vr_frame_info = m_hmd->begin_frame(); + + const auto& views = m_vr_frame_info->views; + size_t n_views = views.size(); + size_t n_devices = m_devices.size(); + if (n_views > 0) { + set_n_views(n_views); + + ivec2 total_size = 0; + for (size_t i = 0; i < n_views; ++i) { + ivec2 view_resolution = {views[i].view.subImage.imageRect.extent.width, views[i].view.subImage.imageRect.extent.height}; + total_size += view_resolution; + + m_views[i].full_resolution = view_resolution; + + // Apply the VR pose relative to the world camera transform. 
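+ // The per-view rotation composes the HMD pose with the world camera's rotation, while the
+ // translation goes through vr_to_world(), i.e. the tracked position scaled by m_scale and
+ // offset by the world camera origin (see vr_to_world() above).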
+ m_views[i].camera0 = mat3(m_camera) * views[i].pose; + m_views[i].camera0[3] = vr_to_world(views[i].pose[3]); + m_views[i].camera1 = m_views[i].camera0; + + m_views[i].visualized_dimension = m_visualized_dimension; + + const auto& xr_fov = views[i].view.fov; + + // Compute the distance on the image plane (1 unit away from the camera) that an angle of the respective FOV spans + vec2 rel_focal_length_left_down = 0.5f * + fov_to_focal_length(ivec2(1), vec2{360.0f * xr_fov.angleLeft / PI(), 360.0f * xr_fov.angleDown / PI()}); + vec2 rel_focal_length_right_up = 0.5f * + fov_to_focal_length(ivec2(1), vec2{360.0f * xr_fov.angleRight / PI(), 360.0f * xr_fov.angleUp / PI()}); + + // Compute total distance (for X and Y) that is spanned on the image plane. + m_views[i].relative_focal_length = rel_focal_length_right_up - rel_focal_length_left_down; + + // Compute fraction of that distance that is spanned by the right-up part and set screen center accordingly. + vec2 ratio = rel_focal_length_right_up / m_views[i].relative_focal_length; + m_views[i].screen_center = {1.0f - ratio.x, ratio.y}; + + // Fix up weirdness in the rendering pipeline + m_views[i].relative_focal_length[(m_fov_axis + 1) % 2] *= (float)view_resolution[(m_fov_axis + 1) % 2] / + (float)view_resolution[m_fov_axis]; + m_views[i].render_buffer->set_hidden_area_mask(m_vr_use_hidden_area_mask ? views[i].hidden_area_mask : nullptr); + + // Render each view on a different GPU (if available) + m_views[i].device = m_use_aux_devices ? &m_devices.at(i % m_devices.size()) : &primary_device(); + } + + // Put all the views next to each other, but at half size + glfwSetWindowSize(m_glfw_window, total_size.x / 2, (total_size.y / 2) / n_views); + + // VR controller input + const auto& hands = m_vr_frame_info->hands; + m_fps_camera = true; + + // TRANSLATE BY STICK (if not pressing the stick) + if (!hands[0].pressing) { + vec3 translate_vec = vec3{hands[0].thumbstick.x, 0.0f, hands[0].thumbstick.y} * m_camera_velocity * m_frame_ms.val() / 1000.0f; + if (translate_vec != vec3(0.0f)) { + translate_camera(translate_vec, mat3(m_views.front().camera0), false); + } + } + + // TURN BY STICK (if not pressing the stick) + if (!hands[1].pressing) { + auto prev_camera = m_camera; + + // Turn around the up vector (equivalent to x-axis mouse drag) with right joystick left/right + float sensitivity = 0.35f; + auto rot = rotation_from_angles({-2.0f * PI() * sensitivity * hands[1].thumbstick.x * m_frame_ms.val() / 1000.0f, 0.0f}) * + mat3(m_camera); + m_camera = mat4x3(rot[0], rot[1], rot[2], m_camera[3]); + + // Translate camera such that center of rotation was about the current view + m_camera[3] += mat3(prev_camera) * views[0].pose[3] * m_scale - mat3(m_camera) * views[0].pose[3] * m_scale; + } + + // TRANSLATE, SCALE, AND ROTATE BY GRAB + { + bool both_grabbing = hands[0].grabbing && hands[1].grabbing; + float drag_factor = both_grabbing ? 0.5f : 1.0f; + + if (both_grabbing) { + drag_factor = 0.5f; + + vec3 prev_diff = hands[0].prev_grab_pos - hands[1].prev_grab_pos; + vec3 diff = hands[0].grab_pos - hands[1].grab_pos; + vec3 center = 0.5f * (hands[0].grab_pos + hands[1].grab_pos); + + vec3 center_world = vr_to_world(0.5f * (hands[0].grab_pos + hands[1].grab_pos)); + + // Scale around center position of the two dragging hands. 
Makes the scaling feel similar to phone pinch-to-zoom + float scale = m_scale * length(prev_diff) / length(diff); + m_camera[3] = (view_pos() - center_world) * (scale / m_scale) + center_world; + m_scale = scale; + + // Take rotational component and project it to the nearest rotation about the up vector. + // We don't want to rotate the scene about any other axis. + vec3 rot = cross(normalize(prev_diff), normalize(diff)); + float rot_radians = std::asin(dot(m_up_dir, rot)); + + auto prev_camera = m_camera; + auto rotcam = rotmat(rot_radians, m_up_dir) * mat3(m_camera); + m_camera = mat4x3(rotcam[0], rotcam[1], rotcam[2], m_camera[3]); + m_camera[3] += mat3(prev_camera) * center * m_scale - mat3(m_camera) * center * m_scale; + } + + for (const auto& hand : hands) { + if (hand.grabbing) { + m_camera[3] -= drag_factor * mat3(m_camera) * hand.drag() * m_scale; + } + } + } + } +} + +void Testbed::SecondWindow::draw(GLuint texture) { + if (!window) { + return; + } + int display_w, display_h; + GLFWwindow* old_context = glfwGetCurrentContext(); + glfwMakeContextCurrent(window); + glfwGetFramebufferSize(window, &display_w, &display_h); + glViewport(0, 0, display_w, display_h); + glClearColor(0.0f, 0.0f, 0.0f, 1.0f); + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + glEnable(GL_TEXTURE_2D); + glBindTexture(GL_TEXTURE_2D, texture); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glBindVertexArray(vao); + if (program) { + glUseProgram(program); + } + glDrawArrays(GL_TRIANGLES, 0, 6); + glBindVertexArray(0); + glUseProgram(0); + glfwSwapBuffers(window); + glfwMakeContextCurrent(old_context); +} + +void Testbed::init_opengl_shaders() { + static const char* shader_vert = R"glsl(#version 140 + out vec2 UVs; + void main() { + UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2); + gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0); + })glsl"; + + static const char* shader_frag = R"glsl(#version 140 + in vec2 UVs; + out vec4 frag_color; + uniform sampler2D rgba_texture; + uniform sampler2D depth_texture; + + struct FoveationWarp { + float al, bl, cl; + float am, bm; + float ar, br, cr; + float switch_left, switch_right; + float inv_switch_left, inv_switch_right; + }; + + uniform FoveationWarp warp_x; + uniform FoveationWarp warp_y; + + float unwarp(in FoveationWarp warp, float y) { + y = clamp(y, 0.0, 1.0); + if (y < warp.inv_switch_left) { + return (sqrt(-4.0 * warp.al * warp.cl + 4.0 * warp.al * y + warp.bl * warp.bl) - warp.bl) / (2.0 * warp.al); + } else if (y > warp.inv_switch_right) { + return (sqrt(-4.0 * warp.ar * warp.cr + 4.0 * warp.ar * y + warp.br * warp.br) - warp.br) / (2.0 * warp.ar); + } else { + return (y - warp.bm) / warp.am; + } + } + + vec2 unwarp(in vec2 pos) { + return vec2(unwarp(warp_x, pos.x), unwarp(warp_y, pos.y)); + } + + void main() { + vec2 tex_coords = UVs; + tex_coords.y = 1.0 - tex_coords.y; + tex_coords = unwarp(tex_coords); + frag_color = texture(rgba_texture, tex_coords.xy); + //Uncomment the following line of code to visualize debug the depth buffer for debugging. 
+ // frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0); + gl_FragDepth = texture(depth_texture, tex_coords.xy).r; + })glsl"; + + GLuint vert = glCreateShader(GL_VERTEX_SHADER); + glShaderSource(vert, 1, &shader_vert, NULL); + glCompileShader(vert); + check_shader(vert, "Blit vertex shader", false); + + GLuint frag = glCreateShader(GL_FRAGMENT_SHADER); + glShaderSource(frag, 1, &shader_frag, NULL); + glCompileShader(frag); + check_shader(frag, "Blit fragment shader", false); + + m_blit_program = glCreateProgram(); + glAttachShader(m_blit_program, vert); + glAttachShader(m_blit_program, frag); + glLinkProgram(m_blit_program); + check_shader(m_blit_program, "Blit shader program", true); + + glDeleteShader(vert); + glDeleteShader(frag); + + glGenVertexArrays(1, &m_blit_vao); +} + +void Testbed::blit_texture( + const Foveation& foveation, + GLint rgba_texture, + GLint rgba_filter_mode, + GLint depth_texture, + GLint framebuffer, + const ivec2& offset, + const ivec2& resolution +) { + if (m_blit_program == 0) { + return; + } + + // Blit image to OpenXR swapchain. + // Note that the OpenXR swapchain is 8bit while the rendering is in a float texture. + // As some XR runtimes do not support float swapchains, we can't render into it directly. + + bool tex = glIsEnabled(GL_TEXTURE_2D); + bool depth = glIsEnabled(GL_DEPTH_TEST); + bool cull = glIsEnabled(GL_CULL_FACE); + + if (!tex) { + glEnable(GL_TEXTURE_2D); + } + if (!depth) { + glEnable(GL_DEPTH_TEST); + } + if (cull) { + glDisable(GL_CULL_FACE); + } + + glDepthFunc(GL_ALWAYS); + glDepthMask(GL_TRUE); + + glBindVertexArray(m_blit_vao); + glUseProgram(m_blit_program); + glUniform1i(glGetUniformLocation(m_blit_program, "rgba_texture"), 0); + glUniform1i(glGetUniformLocation(m_blit_program, "depth_texture"), 1); + + auto bind_warp = [&](const FoveationPiecewiseQuadratic& warp, const std::string& uniform_name) { + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".al").c_str()), warp.al); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bl").c_str()), warp.bl); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cl").c_str()), warp.cl); + + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".am").c_str()), warp.am); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bm").c_str()), warp.bm); + + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".ar").c_str()), warp.ar); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".br").c_str()), warp.br); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cr").c_str()), warp.cr); + + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_left").c_str()), warp.switch_left); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_right").c_str()), warp.switch_right); + + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_left").c_str()), warp.inv_switch_left); + glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_right").c_str()), warp.inv_switch_right); + }; + + bind_warp(foveation.warp_x, "warp_x"); + bind_warp(foveation.warp_y, "warp_y"); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, depth_texture); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, 
GL_NEAREST); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(GL_TEXTURE_2D, rgba_texture); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, rgba_filter_mode); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, rgba_filter_mode); + + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer); + glViewport(offset.x, offset.y, resolution.x, resolution.y); + + glDrawArrays(GL_TRIANGLES, 0, 3); + + glBindVertexArray(0); + glUseProgram(0); + + glDepthFunc(GL_LESS); + + // restore old state + if (!tex) { + glDisable(GL_TEXTURE_2D); + } + if (!depth) { + glDisable(GL_DEPTH_TEST); + } + if (cull) { + glEnable(GL_CULL_FACE); + } + glBindFramebuffer(GL_FRAMEBUFFER, 0); +} + +void Testbed::draw_gui() { + // Make sure all the cuda code finished its business here + CUDA_CHECK_THROW(cudaDeviceSynchronize()); + + if (!m_rgba_render_textures.empty()) { + m_second_window.draw((GLuint)m_rgba_render_textures.front()->texture()); + } + + glfwMakeContextCurrent(m_glfw_window); + int display_w, display_h; + glfwGetFramebufferSize(m_glfw_window, &display_w, &display_h); + glViewport(0, 0, display_w, display_h); + glClearColor(0.f, 0.f, 0.f, 0.f); + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + + glEnable(GL_BLEND); + glBlendEquationSeparate(GL_FUNC_ADD, GL_FUNC_ADD); + glBlendFuncSeparate(GL_ONE, GL_ONE_MINUS_SRC_ALPHA, GL_ONE, GL_ONE_MINUS_SRC_ALPHA); + + ivec2 extent = {(int)((float)display_w / m_n_views.x), (int)((float)display_h / m_n_views.y)}; + + int i = 0; + for (int y = 0; y < m_n_views.y; ++y) { + for (int x = 0; x < m_n_views.x; ++x) { + if (i >= m_views.size()) { + break; + } + + auto& view = m_views[i]; + ivec2 top_left{x * extent.x, display_h - (y + 1) * extent.y}; + blit_texture( + m_foveated_rendering_visualize ? Foveation{} : view.foveation, + m_rgba_render_textures.at(i)->texture(), + m_foveated_rendering ? GL_LINEAR : GL_NEAREST, + m_depth_render_textures.at(i)->texture(), + 0, + top_left, + extent + ); + + ++i; + } + } + glFinish(); + glViewport(0, 0, display_w, display_h); + + ImDrawList* list = ImGui::GetBackgroundDrawList(); + list->AddCallback(ImDrawCallback_ResetRenderState, nullptr); + + // Visualizations are only meaningful when rendering a single view + if (m_views.size() == 1) { + draw_visualizations(list, m_smoothed_camera); + } + + if (m_render_ground_truth) { + list->AddText(ImVec2(4.f, 4.f), 0xffffffff, "Ground Truth"); + } + + ImGui::Render(); + ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); + + glfwSwapBuffers(m_glfw_window); + + // Make sure all the OGL code finished its business here. + // Any code outside of this function needs to be able to freely write to + // textures without being worried about interfering with rendering. 
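+ // glFinish() blocks until all previously issued GL commands have completed, so that CUDA
+ // kernels running after this function can write the shared render textures without racing
+ // the blits above.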
+ glFinish(); +} +#endif // NGP_GUI + +__global__ void to_8bit_color_kernel(ivec2 resolution, EColorSpace output_color_space, cudaSurfaceObject_t surface, uint8_t* result) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + vec4 color; + surf2Dread((float4*)&color, surface, x * sizeof(float4), y); + + if (output_color_space == EColorSpace::Linear) { + color.rgb() = linear_to_srgb(color.rgb()); + } + + for (uint32_t i = 0; i < 3; ++i) { + result[(x + resolution.x * y) * 3 + i] = (uint8_t)(clamp(color[i], 0.0f, 1.0f) * 255.0f + 0.5f); + } +} + +void Testbed::prepare_next_camera_path_frame() { + if (!m_camera_path.rendering) { + return; + } + + // If we're rendering a video, we'd like to accumulate multiple spp + // for motion blur. Hence dump the frame once the target spp has been reached + // and only reset _then_. + if (m_views.front().render_buffer->spp() == m_camera_path.render_settings.spp) { + auto tmp_dir = fs::path{"tmp"}; + if (!tmp_dir.exists()) { + if (!fs::create_directory(tmp_dir)) { + m_camera_path.rendering = false; + tlog::error() << "Failed to create temporary directory 'tmp' to hold rendered images."; + return; + } + } + + ivec2 res = m_views.front().render_buffer->out_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + + GPUMemory image_data(product(res) * 3); + to_8bit_color_kernel<<>>( + res, + EColorSpace::SRGB, // the GUI always renders in SRGB + m_views.front().render_buffer->surface(), + image_data.data() + ); + + m_render_futures.emplace_back( + m_thread_pool.enqueue_task([image_data = std::move(image_data), frame_idx = m_camera_path.render_frame_idx++, res, tmp_dir] { + std::vector cpu_image_data(image_data.size()); + CUDA_CHECK_THROW(cudaMemcpy(cpu_image_data.data(), image_data.data(), image_data.bytes(), cudaMemcpyDeviceToHost)); + write_stbi(tmp_dir / fmt::format("{:06d}.jpg", frame_idx), res.x, res.y, 3, cpu_image_data.data(), 100); + }) + ); + + reset_accumulation(true); + + if (m_camera_path.render_frame_idx == m_camera_path.render_settings.n_frames(m_camera_path.duration_seconds())) { + m_camera_path.rendering = false; + + wait_all(m_render_futures); + m_render_futures.clear(); + + tlog::success() << "Finished rendering '.jpg' video frames to '" << tmp_dir << "'. Assembling them into a video next."; + + fs::path ffmpeg = "ffmpeg"; + +#ifdef _WIN32 + // Under Windows, try automatically downloading FFmpeg binaries if they don't exist + if (system(fmt::format("where {} >nul 2>nul", ffmpeg.str()).c_str()) != 0) { + fs::path dir = root_dir(); + if ((dir / "external" / "ffmpeg").exists()) { + for (const auto& path : fs::directory{dir / "external" / "ffmpeg"}) { + ffmpeg = path / "bin" / "ffmpeg.exe"; + } + } + + if (!ffmpeg.exists()) { + tlog::info() << "FFmpeg not found. Downloading FFmpeg..."; + do_system((dir / "scripts" / "download_ffmpeg.bat").str()); + } + + for (const auto& path : fs::directory{dir / "external" / "ffmpeg"}) { + ffmpeg = path / "bin" / "ffmpeg.exe"; + } + + if (!ffmpeg.exists()) { + tlog::warning() << "FFmpeg download failed. Trying system-wide FFmpeg."; + } + } +#endif + + auto ffmpeg_command = fmt::format( + "{} -loglevel error -y -framerate {} -i tmp/%06d.jpg -c:v libx264 -preset slow -crf {} -pix_fmt yuv420p \"{}\"", + ffmpeg.str(), + m_camera_path.render_settings.fps, + // Quality goes from 0 to 10. 
This conversion to CRF means a quality of 10 + // is a CRF of 17 and a quality of 0 a CRF of 27, which covers the "sane" + // range of x264 quality settings according to the FFmpeg docs: + // https://trac.ffmpeg.org/wiki/Encode/H.264 + 27 - m_camera_path.render_settings.quality, + m_camera_path.render_settings.filename + ); + int ffmpeg_result = do_system(ffmpeg_command); + if (ffmpeg_result == 0) { + tlog::success() << "Saved video '" << m_camera_path.render_settings.filename << "'"; + } else if (ffmpeg_result == -1) { + tlog::error() << "Video could not be assembled: FFmpeg not found."; + } else { + tlog::error() << "Video could not be assembled: FFmpeg failed"; + } + + clear_tmp_dir(); + } + } + + const auto& rs = m_camera_path.render_settings; + const float duration = m_camera_path.duration_seconds(); + m_camera_path.play_time = (float)((double)m_camera_path.render_frame_idx / (double)rs.n_frames(duration)); + + if (m_views.front().render_buffer->spp() == 0) { + set_camera_from_time(m_camera_path.play_time); + apply_camera_smoothing(rs.frame_milliseconds(duration)); + + auto smoothed_camera_backup = m_smoothed_camera; + + // Compute the camera for the next frame in order to be able to compute motion blur + // between it and the current one. + set_camera_from_time(m_camera_path.play_time + 1.0f / rs.n_frames(duration)); + apply_camera_smoothing(rs.frame_milliseconds(duration)); + + m_camera_path.render_frame_end_camera = m_smoothed_camera; + + // Revert camera such that the next frame will be computed correctly + // (Start camera of next frame should be the same as end camera of this frame) + set_camera_from_time(m_camera_path.play_time); + m_smoothed_camera = smoothed_camera_backup; + } +} + +__global__ void reproject_kernel( + BoundingBox render_aabb, + mat3 render_aabb_to_local, + default_rng_t rng, + float near_t, + float step_factor, + uint32_t spp, + uint32_t view_idx, + mat4x3 src_camera, + vec2 src_screen_center, + vec2 src_focal_length, + ivec2 src_resolution, + Foveation src_foveation, + Lens src_lens, + MatrixView src_depth_buffer, + mat4x3 dst_camera, + vec2 dst_screen_center, + vec2 dst_focal_length, + ivec2 dst_resolution, + Foveation dst_foveation, + Lens dst_lens, + vec4* __restrict__ dst_frame_buffer, + MatrixView dst_depth_buffer, + MatrixView dst_hole_mask, + MatrixView dst_index_field, + MatrixView src_hole_mask = {}, + MatrixView src_index_field = {} +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + uint32_t is_hole = dst_hole_mask(y, x); + if (x >= dst_resolution.x || y >= dst_resolution.y || (src_hole_mask && !is_hole)) { + return; + } + + auto ray = pixel_to_ray( + spp, + {(int)x, (int)y}, + dst_resolution, + dst_focal_length, + dst_camera, + dst_screen_center, + vec3(0.0f), // parallax + false, // pixel center snap + 0.0f, // near dist + 1.0f, // focus + 0.0f, // aperture + dst_foveation, + {}, + dst_lens + ); + + uint32_t dst_idx = x + dst_resolution.x * y; + + float t = near_t; + rng.advance(dst_idx); + t *= std::pow(step_factor, rng.next_float()); + + struct Result { + ViewIdx idx; + float dist; + float t; + }; + + auto get_reprojected_dist = [&](float t) -> Result { + vec3 p = ray(t); + + vec2 src_px = pos_to_pixel(p, src_resolution, src_focal_length, src_camera, src_screen_center, vec3(0.0f), src_foveation, src_lens); + + if (src_px.x <= 0 || src_px.x >= src_resolution.x || src_px.y <= 0 || src_px.y >= src_resolution.y) { + return { + {-1, 0}, + -1.0f, -1.0f + }; + } + + ViewIdx nearest = 
{clamp(ivec2(floor(src_px)), 0, src_resolution - 1), view_idx}; + if (src_hole_mask) { + if (!src_hole_mask(nearest.px.y, nearest.px.x) || src_depth_buffer(nearest.px.y, nearest.px.x) == 0.0f) { + return { + {-1, 0}, + -1.0f, -1.0f + }; + } + } + + float d = src_depth_buffer(nearest.px.y, nearest.px.x); + Ray src_ray = { + src_camera[3], + p - src_camera[3], + }; + + src_ray.d /= src_lens.is_360() ? length(src_ray.d) : dot(src_ray.d, src_camera[2]); + + vec3 src_p = src_ray(d); + if (src_index_field) { + nearest = src_index_field(nearest.px.y, nearest.px.x); + } + + return {nearest, distance(p, src_p), t}; + }; + + auto refine_match = [&](Result match) -> Result { + static const uint32_t N_STEPS_PER_REFINEMENT = 10; + static const uint32_t N_REFINEMENTS = 3; + + float prev_t = match.t / step_factor; + float next_t = match.t * step_factor; + + NGP_PRAGMA_UNROLL + for (uint32_t j = 0; j < N_REFINEMENTS; ++j) { + float step_size = (next_t - prev_t) / (N_STEPS_PER_REFINEMENT - 1); + float t = prev_t; + + NGP_PRAGMA_UNROLL + for (uint32_t i = 0; i < N_STEPS_PER_REFINEMENT; ++i) { + auto res = get_reprojected_dist(t); + if (res.idx.px.x >= 0 && res.dist < match.dist) { + match = res; + prev_t = t - step_size; + next_t = t + step_size; + } + + t += step_size; + } + } + + return match; + }; + + Result final = { + {-1, 0}, + std::numeric_limits::infinity(), 0 + }; + Result fallback = final; + + float mint = fmaxf(render_aabb.ray_intersect(render_aabb_to_local * ray.o, render_aabb_to_local * ray.d).x, 0.0f) + 1e-6f; + if (mint < MAX_DEPTH()) { + while (t <= mint) { + t *= step_factor; + } + } + + // float last_dist = std::numeric_limits::infinity(); + for (; render_aabb.contains(render_aabb_to_local * ray(t)); t *= step_factor) { + auto res = get_reprojected_dist(t); + if (res.idx.px.x >= 0) { + if (res.dist < t * (step_factor - 1.0f)) { + res = refine_match(res); + if (res.dist < final.dist) { + if (res.dist / res.t < 4.0f / dst_focal_length.x) { + final = res; + break; + } + } + } + + // if (res.dist < last_dist) { + // fallback = res; + // } + + // last_dist = res.dist; + } + } + + if (final.idx.px.x == -1) { + final = fallback; + } + + float prev_depth = dst_depth_buffer(y, x); + + dst_frame_buffer[dst_idx] = vec4::zero(); + if (final.idx.px.x == -1) { + if (is_hole) { + dst_depth_buffer(y, x) = MAX_DEPTH(); + dst_hole_mask(y, x) = 1; + dst_index_field(y, x) = {-1, 0}; + } + } else { + if (is_hole || final.t * step_factor < prev_depth) { + dst_depth_buffer(y, x) = final.t; + dst_hole_mask(y, x) = src_index_field ? 
2 : 0; + dst_index_field(y, x) = final.idx; + } + } +} + +__global__ void dilate_holes_kernel(ivec2 res, MatrixView old_hole_mask, MatrixView hole_mask) { + int32_t x = threadIdx.x + blockDim.x * blockIdx.x; + int32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= res.x || y >= res.y) { + return; + } + + auto is_hole = [&](const ivec2& offset) { + auto clamped = clamp(ivec2{x, y} + offset, 0, res - 1); + return old_hole_mask(clamped.y, clamped.x); + }; + + hole_mask(y, x) = is_hole({1, 0}) || is_hole({-1, 0}) || is_hole({1, 1}) || is_hole({-1, 1}) || is_hole({1, -1}) || is_hole({-1, -1}) || + is_hole({0, 1}) || is_hole({0, -1}); +} + +__global__ void generate_alt_depth_kernel( + mat4x3 src_camera, + vec2 src_screen_center, + vec2 src_focal_length, + ivec2 src_resolution, + const vec4* __restrict__ src_frame_buffer, + const float* __restrict__ src_depth_buffer, + Foveation src_foveation, + Lens src_lens, + mat4x3 dst_camera, + Lens dst_lens, + MatrixView alt_depth_buffer +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= src_resolution.x || y >= src_resolution.y) { + return; + } + + auto ray = pixel_to_ray( + 0, + {(int)x, (int)y}, + src_resolution, + src_focal_length, + src_camera, + src_screen_center, + vec3(0.0f), // parallax + false, // pixel center snap + 0.0f, // near dist + 1.0f, // focus + 0.0f, // aperture + src_foveation, + {}, + src_lens + ); + + uint32_t src_idx = x + src_resolution.x * y; + vec3 p = ray(src_depth_buffer[src_idx]); + + alt_depth_buffer(y, x) = dst_lens.is_360() ? distance(p, dst_camera[3]) : dot(p - dst_camera[3], dst_camera[2]); +} + +__global__ void copy_depth_buffer_kernel(ivec2 dst_resolution, const float* __restrict__ src_depth_buffer, MatrixView dst_depth_buffer) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= dst_resolution.x || y >= dst_resolution.y) { + return; + } + + uint32_t idx = x + dst_resolution.x * y; + dst_depth_buffer(y, x) = src_depth_buffer[idx]; +} + +static constexpr float Z_NEAR = 0.1f; +static constexpr float Z_BASE = 1.03f; + +inline NGP_HOST_DEVICE float to_log_depth(float d) { return logf(d / Z_NEAR) * logf(Z_BASE); } + +inline NGP_HOST_DEVICE float from_log_depth(float d) { return expf(d / logf(Z_BASE)) * Z_NEAR; } + +inline NGP_HOST_DEVICE vec4 from_rgbd32(uint32_t val) { + vec4 result = rgba32_to_rgba(val); + result.a = from_log_depth(result.a); + return result; +} + +inline NGP_HOST_DEVICE uint32_t to_rgbd32(vec4 rgbd) { + rgbd.a = to_log_depth(rgbd.a); + return rgba_to_rgba32(rgbd); +} + +__global__ void reproject_viz_kernel( + ivec2 dst_res, + const ivec2* src_res, + bool pm_enable, + MatrixView hole_labels, + MatrixView state, + MatrixView index_field, + MatrixView dst_rgbd, + MatrixView dst_depth, + const MatrixView* src_rgba, + const MatrixView* src_depth, + MatrixView frame, + MatrixView depth, + EPmVizMode viz_mode, + float depth_scale +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= dst_res.x || y >= dst_res.y) { + return; + } + + if (!pm_enable && state(y, x) == EPmPixelState::Hole) { + if (viz_mode == EPmVizMode::Depth) { + frame(y, x).rgb() = vec3(depth(y, x) * depth_scale); + } else { + frame(y, x).rgb() = vec3(0.0f); + } + + depth(y, x) = MAX_DEPTH(); + return; + } + + auto src_idx = index_field(y, x); + + if (viz_mode == EPmVizMode::Depth) { + frame(y, x).rgb() = vec3(dst_depth(y, x) * 
depth_scale); + } else if (viz_mode == EPmVizMode::Offset) { + vec2 diff = vec2(x, y) / vec2(dst_res) - vec2(src_idx.px) / vec2(src_res[src_idx.view]); + float l = length(diff); + frame(y, x).rgb() = hsv_to_rgb({atan2(diff.y / l, diff.x / l) / (PI() * 2.0f) + 0.5f, 1.0f, l}); + } else if (viz_mode == EPmVizMode::Holes) { + if (state(y, x) == EPmPixelState::Hole) { + frame(y, x).rgb() = colormap_turbo(hole_labels(y, x) / (float)product(dst_res)); + } + } else { + vec4 rgbd = rgba32_to_rgba(src_rgba[src_idx.view](src_idx.px.y, src_idx.px.x)); + rgbd.rgb() = srgb_to_linear(rgbd.rgb()); + frame(y, x) = rgbd; + depth(y, x) = src_depth[src_idx.view](src_idx.px.y, src_idx.px.x); + } +} + +static constexpr int32_t PM_PATCH_RADIUS = 4; + +inline NGP_HOST_DEVICE ivec2 mirror(const ivec2& v, const ivec2& res) { return abs(res - abs(res - v - 1) - 1); } + +__global__ void pm_prepare_padded_src_buffers( + ivec2 padded_res, + ivec2 res, + MatrixView src_rgba, + MatrixView src_depth, + MatrixView dst_rgbd, + MatrixView dst_depth +) { + int32_t x = threadIdx.x + blockDim.x * blockIdx.x; + int32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= padded_res.x || y >= padded_res.y) { + return; + } + + ivec2 padding = (padded_res - res) / 2; + ivec2 idx = {(int16_t)(x - padding.x), (int16_t)(y - padding.y)}; + + // auto clamped_idx = clamp(idx, i16vec2((int16_t)0), i16vec2(res - 1)); + auto clamped_idx = mirror(idx, i16vec2(res)); + + vec4 rgba = src_rgba(clamped_idx.y, clamped_idx.x); + rgba.rgb() = linear_to_srgb(rgba.rgb()); + dst_rgbd(idx.y, idx.x) = rgba_to_rgba32(rgba); + dst_depth(idx.y, idx.x) = src_depth(clamped_idx.y, clamped_idx.x); +} + +__global__ void pm_prepare_padded_dst_buffers( + ivec2 padded_dst_res, + ivec2 dst_res, + uint32_t n_src_views, + const ivec2* src_res, + default_rng_t fixed_seed_rng, + const MatrixView* src_rgbd, + const MatrixView* src_depth, + MatrixView dst_state, + MatrixView dst_index_field, + MatrixView dst_rgbd, + MatrixView dst_depth, + MatrixView dst_depth_threshold, + MatrixView hole_mask +) { + int32_t x = threadIdx.x + blockDim.x * blockIdx.x; + int32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= padded_dst_res.x || y >= padded_dst_res.y) { + return; + } + + ivec2 padding = (padded_dst_res - dst_res) / 2; + ivec2 idx = {x - padding.x, y - padding.y}; + + // auto clamped_idx = clamp(idx, i16vec2((int16_t)0), i16vec2(res - 1)); + auto clamped_idx = mirror(idx, dst_res); + + ViewIdx src_idx; + uint8_t is_hole = hole_mask(clamped_idx.y, clamped_idx.x); + if (is_hole == 1) { + fixed_seed_rng.advance((x + y * padded_dst_res.x) * 3); + + // uint32_t random_view = fixed_seed_rng.next_uint(n_src_views); + uint32_t random_view = 0; + auto res = src_res[random_view]; + src_idx = { + i16vec2{(int16_t)fixed_seed_rng.next_uint(res.y), (int16_t)fixed_seed_rng.next_uint(res.x)}, + random_view + }; + } else { + src_idx = dst_index_field(clamped_idx.y, clamped_idx.x); + } + + dst_index_field(idx.y, idx.x) = src_idx; + + if (is_hole == 0) { + dst_state(idx.y, idx.x) = EPmPixelState::Reprojected; + dst_rgbd(idx.y, idx.x) = src_rgbd[src_idx.view](src_idx.px.y, src_idx.px.x); + + float depth = src_depth[src_idx.view](src_idx.px.y, src_idx.px.x); + dst_depth(idx.y, idx.x) = depth; + dst_depth_threshold(idx.y, idx.x) = depth; + } else if (is_hole == 1) { + dst_state(idx.y, idx.x) = EPmPixelState::Hole; + dst_rgbd(idx.y, idx.x) = 0x00FF00FF; + dst_depth(idx.y, idx.x) = 0.0f; + dst_depth_threshold(idx.y, idx.x) = 0.0f; + } else { + dst_state(idx.y, idx.x) = 
EPmPixelState::Reprojected; + dst_rgbd(idx.y, idx.x) = src_rgbd[src_idx.view](src_idx.px.y, src_idx.px.x); + dst_depth_threshold(idx.y, idx.x) = dst_depth(idx.y, idx.x); + } +} + + +void Testbed::reproject_views(const std::vector src_views, View& dst_view) { + if (src_views.empty()) { + dst_view.render_buffer->clear_frame(m_stream.get()); + return; + } + + auto dst_res = dst_view.render_buffer->in_resolution(); + + std::vector src_res(src_views.size()); + std::vector src_screen_center(src_views.size()); + std::vector src_focal_length(src_views.size()); + std::vector> tmp_src_depth_buffer(src_views.size()); + + for (size_t i = 0; i < src_views.size(); ++i) { + src_res[i] = src_views[i]->render_buffer->in_resolution(); + + src_screen_center[i] = render_screen_center(src_views[i]->screen_center); + src_focal_length[i] = + calc_focal_length(src_views[i]->render_buffer->in_resolution(), src_views[i]->relative_focal_length, m_fov_axis, m_zoom); + + // Compute the depth of every pixel in the src_view when reprojected into the dst_view. + // This could in principle happen in parallel with the reprojection step happening below. + tmp_src_depth_buffer[i] = GPUImage(src_res[i], m_stream.get()); + + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)dst_res.x, threads.x), div_round_up((uint32_t)dst_res.y, threads.y), 1}; + + generate_alt_depth_kernel<<>>( + src_views[i]->camera0, + src_screen_center[i], + src_focal_length[i], + src_res[i], + src_views[i]->render_buffer->frame_buffer(), + src_views[i]->render_buffer->depth_buffer(), + src_views[i]->foveation, + src_views[i]->lens, + dst_view.camera0, + dst_view.lens, + tmp_src_depth_buffer[i].view() + ); + } + + dst_view.render_buffer->clear_frame(m_stream.get()); + + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)dst_res.x, threads.x), div_round_up((uint32_t)dst_res.y, threads.y), 1}; + + auto prev_index_field = std::move(dst_view.index_field); + dst_view.index_field = GPUImage(dst_res, PM_PATCH_RADIUS, m_stream.get()); + + auto prev_hole_mask = std::move(dst_view.hole_mask); + dst_view.hole_mask = GPUImage(dst_res, m_stream.get()); + dst_view.hole_mask.image.memset_async(m_stream.get(), 1); + + auto prev_depth_buffer = std::move(dst_view.depth_buffer); + dst_view.depth_buffer = GPUImage(dst_res, PM_PATCH_RADIUS, m_stream.get()); + + auto dst_screen_center = render_screen_center(dst_view.screen_center); + auto dst_focal_length = calc_focal_length(dst_res, dst_view.relative_focal_length, m_fov_axis, m_zoom); + + // First reproject from the source images as much as possible + for (size_t i = 0; i < src_views.size(); ++i) { + reproject_kernel<<>>( + m_render_aabb, + m_render_aabb_to_local, + m_rng, + m_reproject_min_t, + m_reproject_step_factor, + dst_view.render_buffer->spp(), + i, + src_views[i]->camera0, + src_screen_center[i], + src_focal_length[i], + src_res[i], + src_views[i]->foveation, + src_views[i]->lens, + MatrixView(src_views[i]->render_buffer->depth_buffer(), src_res[i].x, 1), + dst_view.camera0, + dst_screen_center, + dst_focal_length, + dst_res, + dst_view.foveation, + dst_view.lens, + dst_view.render_buffer->frame_buffer(), + dst_view.depth_buffer.view(), + dst_view.hole_mask.view(), + dst_view.index_field.view() + ); + } + + // auto old_holes_mask = std::move(dst_view.hole_mask); + // dst_view.hole_mask = GPUImage(dst_res, m_stream.get()); + // dilate_holes_kernel<<>>(dst_res, old_holes_mask.view(), dst_view.hole_mask.view()); + + // Then try reprojecting into the remaining 
holes from the previous rendering + if (m_reproject_reuse_last_frame && prev_depth_buffer.data()) { + reproject_kernel<<>>( + m_render_aabb, + m_render_aabb_to_local, + m_rng, + m_reproject_min_t, + m_reproject_step_factor, + dst_view.render_buffer->spp(), + 0, // Reprojecting from the most recent view will copy the previous index anyway. + dst_view.prev_camera, + render_screen_center(dst_view.screen_center), + calc_focal_length(prev_hole_mask.resolution(), dst_view.relative_focal_length, m_fov_axis, m_zoom), + prev_hole_mask.resolution(), + dst_view.prev_foveation, + dst_view.lens, + prev_depth_buffer.view(), + dst_view.camera0, + dst_screen_center, + dst_focal_length, + dst_res, + dst_view.foveation, + dst_view.lens, + dst_view.render_buffer->frame_buffer(), + dst_view.depth_buffer.view(), + dst_view.hole_mask.view(), + dst_view.index_field.view(), + prev_hole_mask.view(), + prev_index_field.view() + ); + } + + m_rng.advance(); + + auto hole_labels = GPUImage(dst_res, m_stream.get()); + + // Detect holes and label them + { + init_labels<<>>( + dst_res.x, dst_res.y, hole_labels.n_elements(), hole_labels.data(), dst_view.hole_mask.data() + ); + resolve_labels<<>>(dst_res.x, dst_res.y, hole_labels.n_elements(), hole_labels.data()); + label_reduction<<>>( + dst_res.x, dst_res.y, hole_labels.n_elements(), hole_labels.data(), dst_view.hole_mask.data() + ); + resolve_labels<<>>(dst_res.x, dst_res.y, hole_labels.n_elements(), hole_labels.data()); + } + + auto dst_state_buffer = GPUImage(dst_res, PM_PATCH_RADIUS, m_stream.get()); + + std::vector> src_rgbd_buffer(src_views.size()); + std::vector> src_depth_buffer(src_views.size()); + std::vector padded_src_res(src_views.size()); + + std::vector> src_rgbd_views(src_views.size()); + std::vector> src_depth_views(src_views.size()); + + for (size_t i = 0; i < src_views.size(); ++i) { + src_rgbd_buffer[i] = GPUImage(src_res[i], PM_PATCH_RADIUS, m_stream.get()); + src_depth_buffer[i] = GPUImage(src_res[i], PM_PATCH_RADIUS, m_stream.get()); + padded_src_res[i] = src_rgbd_buffer[i].resolution_padded(); + + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)padded_src_res[i].x, threads.x), div_round_up((uint32_t)padded_src_res[i].y, threads.y), 1}; + + pm_prepare_padded_src_buffers<<>>( + padded_src_res[i], + src_res[i], + MatrixView(src_views[i]->render_buffer->frame_buffer(), src_res[i].x, 1), + tmp_src_depth_buffer[i].view(), + src_rgbd_buffer[i].view(), + src_depth_buffer[i].view() + ); + + src_rgbd_views[i] = src_rgbd_buffer[i].view(); + src_depth_views[i] = src_depth_buffer[i].view(); + } + + GPUMemoryArena::Allocation views_alloc; + auto views_scratch = allocate_workspace_and_distribute, MatrixView, ivec2>( + m_stream.get(), &views_alloc, src_views.size(), src_views.size(), src_views.size() + ); + + auto* src_rgba_views_device = std::get<0>(views_scratch); + auto* src_depth_views_device = std::get<1>(views_scratch); + auto* src_res_device = std::get<2>(views_scratch); + + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_rgba_views_device, + src_rgbd_views.data(), + src_views.size() * sizeof(MatrixView), + cudaMemcpyHostToDevice, + m_stream.get() + )); + CUDA_CHECK_THROW(cudaMemcpyAsync( + src_depth_views_device, src_depth_views.data(), src_views.size() * sizeof(MatrixView), cudaMemcpyHostToDevice, m_stream.get() + )); + CUDA_CHECK_THROW(cudaMemcpyAsync(src_res_device, src_res.data(), src_views.size() * sizeof(ivec2), cudaMemcpyHostToDevice, m_stream.get()) + ); + + auto dst_rgba_buffer = GPUImage(dst_res, PM_PATCH_RADIUS, 
m_stream.get()); + auto dst_depth_threshold_buffer = GPUImage(dst_res, PM_PATCH_RADIUS, m_stream.get()); + ivec2 padded_dst_res = dst_rgba_buffer.resolution_padded(); + + default_rng_t fixed_seed_rng{0x1337}; + + { + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)padded_dst_res.x, threads.x), div_round_up((uint32_t)padded_dst_res.y, threads.y), 1}; + + pm_prepare_padded_dst_buffers<<>>( + padded_dst_res, + dst_res, + (uint32_t)src_views.size(), + src_res_device, + fixed_seed_rng, + src_rgba_views_device, + src_depth_views_device, + dst_state_buffer.view(), + dst_view.index_field.view(), + dst_rgba_buffer.view(), + dst_view.depth_buffer.view(), + dst_depth_threshold_buffer.view(), + dst_view.hole_mask.view() + ); + + fixed_seed_rng.advance(); + } + + + reproject_viz_kernel<<>>( + dst_res, + src_res_device, + m_pm_enable, + hole_labels.view(), + dst_state_buffer.view(), + dst_view.index_field.view(), + dst_rgba_buffer.view(), + dst_view.depth_buffer.view(), + src_rgba_views_device, + src_depth_views_device, + MatrixView(dst_view.render_buffer->frame_buffer(), dst_res.x, 1), + MatrixView(dst_view.render_buffer->depth_buffer(), dst_res.x, 1), + m_pm_viz_mode, + 1.0f + ); +} + +void Testbed::render(bool skip_rendering) { + // Don't do any smoothing here if a camera path is being rendered. It'll take care + // of the smoothing on its own. + float frame_ms = m_camera_path.rendering ? 0.0f : m_frame_ms.val(); + apply_camera_smoothing(frame_ms); + + if (!m_render_window || !m_render || skip_rendering) { + return; + } + + auto start = std::chrono::steady_clock::now(); + ScopeGuard timing_guard{[&]() { + m_render_ms.update(std::chrono::duration(std::chrono::steady_clock::now() - start).count()); + }}; + + if (frobenius_norm(m_smoothed_camera - m_camera) < 0.001f) { + m_smoothed_camera = m_camera; + } else if (!m_camera_path.rendering) { + reset_accumulation(true); + } + + if (m_autofocus) { + autofocus(); + } + + Lens lens = m_render_with_lens_distortion ? m_render_lens : Lens{}; + +#ifdef NGP_GUI + if (m_hmd && m_hmd->is_visible()) { + for (auto& view : m_views) { + view.visualized_dimension = m_visualized_dimension; + } + + m_n_views = {(int)m_views.size(), 1}; + + m_render_with_lens_distortion = false; + reset_accumulation(true); + } else { + set_n_views(1); + m_n_views = {1, 1}; + + auto& view = m_views.front(); + + view.full_resolution = m_window_res; + + view.camera0 = m_smoothed_camera; + + // Motion blur over the fraction of time that the shutter is open. Interpolate in log-space to preserve rotations. + view.camera1 = (m_camera_path.rendering && !m_gen3c_render_with_gen3c) ? + camera_log_lerp(m_smoothed_camera, m_camera_path.render_frame_end_camera, m_camera_path.render_settings.shutter_fraction) : + view.camera0; + + view.visualized_dimension = m_visualized_dimension; + view.relative_focal_length = m_relative_focal_length; + view.screen_center = m_screen_center; + view.render_buffer->set_hidden_area_mask(nullptr); + view.foveation = {}; + view.lens = lens; + view.device = &primary_device(); + } + + if (m_dlss) { + m_aperture_size = 0.0f; + if (!m_render_lens.supports_dlss()) { + m_render_with_lens_distortion = false; + } + } + + // Update dynamic res and DLSS + { + // Don't count the time being spent allocating buffers and resetting DLSS as part of the frame time. + // Otherwise the dynamic resolution calculations for following frames will be thrown out of whack + // and may even start oscillating. 
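+ // The ScopeGuard below adds the time spent in this block back onto `start`, so the render-time
+ // average that drives dynamic resolution excludes buffer allocation and DLSS (re)initialization.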
+ auto skip_start = std::chrono::steady_clock::now(); + ScopeGuard skip_timing_guard{[&]() { start += std::chrono::steady_clock::now() - skip_start; }}; + + size_t n_pixels = 0, n_pixels_full_res = 0; + for (const auto& view : m_views) { + n_pixels += product(view.render_buffer->in_resolution()); + n_pixels_full_res += product(view.full_resolution); + } + + float pixel_ratio = n_pixels == 0 ? (1.0f / 256.0f) : ((float)n_pixels / (float)n_pixels_full_res); + + float last_factor = std::sqrt(pixel_ratio); + float factor = std::sqrt(pixel_ratio / m_render_ms.val() * 1000.0f / m_dynamic_res_target_fps); + if (!m_dynamic_res) { + factor = 8.f / (float)m_fixed_res_factor; + } + + factor = clamp(factor, 1.0f / 16.0f, 1.0f); + + vec2 avg_screen_center = vec2(0.0f); + for (size_t i = 0; i < m_views.size(); ++i) { + avg_screen_center += m_views[i].screen_center; + } + + avg_screen_center /= (float)m_views.size(); + + for (auto&& view : m_views) { + if (m_dlss) { + view.render_buffer->enable_dlss(*m_dlss_provider, view.full_resolution); + } else { + view.render_buffer->disable_dlss(); + } + + ivec2 render_res = view.render_buffer->in_resolution(); + ivec2 new_render_res = clamp(ivec2(vec2(view.full_resolution) * factor), view.full_resolution / 16, view.full_resolution); + + if (m_camera_path.rendering && !m_gen3c_render_with_gen3c) { + new_render_res = m_camera_path.render_settings.resolution; + } + + float ratio = std::sqrt((float)product(render_res) / (float)product(new_render_res)); + if (ratio > 1.2f || ratio < 0.8f || factor == 1.0f || !m_dynamic_res || (m_camera_path.rendering && !m_gen3c_render_with_gen3c)) { + render_res = new_render_res; + } + + if (view.render_buffer->dlss()) { + render_res = view.render_buffer->dlss()->clamp_resolution(render_res); + view.render_buffer->dlss()->update_feature( + render_res, view.render_buffer->dlss()->is_hdr(), view.render_buffer->dlss()->sharpen() + ); + } + + view.render_buffer->resize(render_res); + + if (m_foveated_rendering) { + if (m_dynamic_foveated_rendering) { + vec2 resolution_scale = vec2(render_res) / vec2(view.full_resolution); + + // Only start foveation when DLSS if off or if DLSS is asked to do more than 1.5x upscaling. + // The reason for the 1.5x threshold is that DLSS can do up to 3x upscaling, at which point a + // foveation factor of 2x = 3.0x/1.5x corresponds exactly to bilinear super sampling, which is + // helpful in suppressing DLSS's artifacts. + float foveation_begin_factor = m_dlss ? 
1.5f : 1.0f; + + resolution_scale = + clamp(resolution_scale * foveation_begin_factor, vec2(1.0f / m_foveated_rendering_max_scaling), vec2(1.0f)); + view.foveation = {resolution_scale, vec2(1.0f) - view.screen_center, vec2(m_foveated_rendering_full_res_diameter * 0.5f)}; + + m_foveated_rendering_scaling = 2.0f / sum(resolution_scale); + } else { + view.foveation = { + vec2(1.0f / m_foveated_rendering_scaling), + vec2(1.0f) - view.screen_center, + vec2(m_foveated_rendering_full_res_diameter * 0.5f) + }; + } + } else { + view.foveation = {}; + } + } + } + + // Make sure all in-use auxiliary GPUs have the latest model and bitfield + std::unordered_set devices_in_use; + for (auto& view : m_views) { + if (!view.device || devices_in_use.count(view.device) != 0) { + continue; + } + + devices_in_use.insert(view.device); + sync_device(*view.render_buffer, *view.device); + } + + if (m_reproject_enable) { + render_by_reprojection(m_stream.get(), m_views); + } else { + SyncedMultiStream synced_streams{m_stream.get(), m_views.size()}; + + std::vector> futures(m_views.size()); + for (size_t i = 0; i < m_views.size(); ++i) { + auto& view = m_views[i]; + futures[i] = view.device->enqueue_task([this, &view, stream = synced_streams.get(i)]() { + auto device_guard = use_device(stream, *view.render_buffer, *view.device); + render_frame_main( + *view.device, view.camera0, view.camera1, view.screen_center, view.relative_focal_length, view.foveation, view.lens, view.visualized_dimension + ); + }); + } + + for (size_t i = 0; i < m_views.size(); ++i) { + auto& view = m_views[i]; + + if (futures[i].valid()) { + futures[i].get(); + } + + render_frame_epilogue( + synced_streams.get(i), + view.camera0, + view.prev_camera, + view.screen_center, + view.relative_focal_length, + view.foveation, + view.prev_foveation, + view.lens, + *view.render_buffer, + true + ); + + view.prev_camera = view.camera0; + view.prev_foveation = view.foveation; + } + } + + for (size_t i = 0; i < m_views.size(); ++i) { + m_rgba_render_textures.at(i)->blit_from_cuda_mapping(); + m_depth_render_textures.at(i)->blit_from_cuda_mapping(); + } + + if (m_picture_in_picture_res > 0) { + ivec2 res{(int)m_picture_in_picture_res, (int)(m_picture_in_picture_res * 9.0f / 16.0f)}; + m_pip_render_buffer->resize(res); + if (m_pip_render_buffer->spp() < 8) { + // a bit gross, but let's copy the keyframe's state into the global state in order to not have to plumb + // through the fov etc to render_frame. 
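+ // Back up the current camera as a keyframe, temporarily apply the camera-path keyframe to the
+ // global state for the picture-in-picture render, then restore the backup once it is done.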
+ CameraKeyframe backup = copy_camera_to_keyframe(); + CameraKeyframe pip_kf = m_camera_path.eval_camera_path(m_camera_path.play_time); + set_camera_from_keyframe(pip_kf); + + if (m_reproject_enable) { + std::vector views(1); + auto& view = views.front(); + view.camera0 = pip_kf.m(); + view.camera1 = pip_kf.m(); + view.prev_camera = pip_kf.m(); + view.screen_center = m_screen_center; + view.relative_focal_length = m_relative_focal_length; + view.foveation = {}; + view.prev_foveation = {}; + view.lens = lens; + view.visualized_dimension = m_visualized_dimension; + view.render_buffer = m_pip_render_buffer; + + render_by_reprojection(m_stream.get(), views); + } else { + render_frame( + m_stream.get(), + pip_kf.m(), + pip_kf.m(), + pip_kf.m(), + m_screen_center, + m_relative_focal_length, + {}, // foveation + {}, // prev foveation + lens, + m_visualized_dimension, + *m_pip_render_buffer + ); + } + + set_camera_from_keyframe(backup); + m_pip_render_texture->blit_from_cuda_mapping(); + } + } +#endif + + CUDA_CHECK_THROW(cudaStreamSynchronize(m_stream.get())); +} + +mat4x3 Testbed::view_camera(size_t view_idx) const { + if (m_views.size() <= view_idx) { + throw std::runtime_error{fmt::format("View #{} does not exist.", view_idx)}; + } + + auto& view = m_views.at(view_idx); + return view.camera0; +} + + +#ifdef NGP_GUI +void Testbed::create_second_window() { + if (m_second_window.window) { + return; + } + + bool frameless = false; + glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE); + glfwWindowHint(GLFW_RESIZABLE, !frameless); + glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); + glfwWindowHint(GLFW_CENTER_CURSOR, false); + glfwWindowHint(GLFW_DECORATED, !frameless); + glfwWindowHint(GLFW_SCALE_TO_MONITOR, frameless); + glfwWindowHint(GLFW_TRANSPARENT_FRAMEBUFFER, true); + // get the window size / coordinates + int win_w = 0, win_h = 0, win_x = 0, win_y = 0; + GLuint ps = 0, vs = 0; + + { + win_w = 1920; + win_h = 1080; + win_x = 0x40000000; + win_y = 0x40000000; + static const char* copy_shader_vert = + "\ + in vec2 vertPos_data;\n\ + out vec2 texCoords;\n\ + void main(){\n\ + gl_Position = vec4(vertPos_data.xy, 0.0, 1.0);\n\ + texCoords = (vertPos_data.xy + 1.0) * 0.5; texCoords.y=1.0-texCoords.y;\n\ + }"; + static const char* copy_shader_frag = + "\ + in vec2 texCoords;\n\ + out vec4 fragColor;\n\ + uniform sampler2D screenTex;\n\ + void main(){\n\ + fragColor = texture(screenTex, texCoords.xy);\n\ + }"; + vs = compile_shader(false, copy_shader_vert); + ps = compile_shader(true, copy_shader_frag); + } + + m_second_window.window = glfwCreateWindow(win_w, win_h, "Fullscreen Output", NULL, m_glfw_window); + if (win_x != 0x40000000) { + glfwSetWindowPos(m_second_window.window, win_x, win_y); + } + + glfwMakeContextCurrent(m_second_window.window); + m_second_window.program = glCreateProgram(); + glAttachShader(m_second_window.program, vs); + glAttachShader(m_second_window.program, ps); + glLinkProgram(m_second_window.program); + if (!check_shader(m_second_window.program, "shader program", true)) { + glDeleteProgram(m_second_window.program); + m_second_window.program = 0; + } + + // vbo and vao + glGenVertexArrays(1, &m_second_window.vao); + glGenBuffers(1, &m_second_window.vbo); + glBindVertexArray(m_second_window.vao); + const float fsquadVerts[] = {-1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f}; + glBindBuffer(GL_ARRAY_BUFFER, m_second_window.vbo); + glBufferData(GL_ARRAY_BUFFER, sizeof(fsquadVerts), fsquadVerts, GL_STATIC_DRAW); + 
glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 2 * sizeof(float), (void*)0); + glEnableVertexAttribArray(0); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); +} + +void Testbed::set_n_views(size_t n_views) { + bool changed_views = n_views != m_views.size(); + + while (m_views.size() > n_views) { + m_views.pop_back(); + } + + m_rgba_render_textures.resize(n_views); + m_depth_render_textures.resize(n_views); + + while (m_views.size() < n_views) { + size_t idx = m_views.size(); + m_rgba_render_textures[idx] = std::make_shared(); + m_depth_render_textures[idx] = std::make_shared(); + m_views.emplace_back(View{std::make_shared(m_rgba_render_textures[idx], m_depth_render_textures[idx])}); + } + +}; +#endif // NGP_GUI + +void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) { +#ifndef NGP_GUI + throw std::runtime_error{"init_window failed: NGP was built without GUI support"}; +#else + m_window_res = {resw, resh}; + + glfwSetErrorCallback(glfw_error_callback); + if (!glfwInit()) { + throw std::runtime_error{"GLFW could not be initialized."}; + } + +# ifdef NGP_VULKAN + // Only try to initialize DLSS (Vulkan+NGX) if the + // GPU is sufficiently new. Older GPUs don't support + // DLSS, so it is preferable to not make a futile + // attempt and emit a warning that confuses users. + if (primary_device().compute_capability() >= 70) { + try { + m_dlss_provider = init_vulkan_and_ngx(); + } catch (const std::runtime_error& e) { + tlog::warning() << "Could not initialize Vulkan and NGX. DLSS not supported. (" << e.what() << ")"; + } + } +# endif + + glfwWindowHint(GLFW_VISIBLE, hidden ? GLFW_FALSE : GLFW_TRUE); + std::string title = "Gen3C GUI"; + m_glfw_window = glfwCreateWindow(m_window_res.x, m_window_res.y, title.c_str(), NULL, NULL); + if (m_glfw_window == NULL) { + throw std::runtime_error{"GLFW window could not be created."}; + } + glfwMakeContextCurrent(m_glfw_window); +# ifdef _WIN32 + if (gl3wInit()) { + throw std::runtime_error{"GL3W could not be initialized."}; + } +# else + glewExperimental = 1; + if (glewInit()) { + throw std::runtime_error{"GLEW could not be initialized."}; + } +# endif + glfwSwapInterval(m_vsync ? 1 : 0); // Disable vsync + + GLint gl_version_minor, gl_version_major; + glGetIntegerv(GL_MINOR_VERSION, &gl_version_minor); + glGetIntegerv(GL_MAJOR_VERSION, &gl_version_major); + + if (gl_version_major < 3 || (gl_version_major == 3 && gl_version_minor < 1)) { + throw std::runtime_error{ + fmt::format("Unsupported OpenGL version {}.{}. Gen3C requires at least OpenGL 3.1", gl_version_major, gl_version_minor) + }; + } + + tlog::success() << "Initialized OpenGL version " << glGetString(GL_VERSION); + + glfwSetWindowUserPointer(m_glfw_window, this); + glfwSetDropCallback(m_glfw_window, [](GLFWwindow* window, int count, const char** paths) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (!testbed) { + return; + } + + if (testbed->m_file_drop_callback) { + if (testbed->m_file_drop_callback(std::vector(paths, paths + count))) { + // Files were handled by the callback. 
+ return; + } + } + + for (int i = 0; i < count; i++) { + testbed->load_file(paths[i]); + } + }); + + glfwSetKeyCallback(m_glfw_window, [](GLFWwindow* window, int key, int scancode, int action, int mods) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed) { + testbed->redraw_gui_next_frame(); + } + }); + + glfwSetCursorPosCallback(m_glfw_window, [](GLFWwindow* window, double xpos, double ypos) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed && (ImGui::IsAnyItemActive() || ImGui::GetIO().WantCaptureMouse || ImGuizmo::IsUsing()) && + (ImGui::GetIO().MouseDown[0] || ImGui::GetIO().MouseDown[1] || ImGui::GetIO().MouseDown[2])) { + testbed->redraw_gui_next_frame(); + } + }); + + glfwSetMouseButtonCallback(m_glfw_window, [](GLFWwindow* window, int button, int action, int mods) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed) { + testbed->redraw_gui_next_frame(); + } + }); + + glfwSetScrollCallback(m_glfw_window, [](GLFWwindow* window, double xoffset, double yoffset) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed) { + testbed->redraw_gui_next_frame(); + } + }); + + glfwSetWindowSizeCallback(m_glfw_window, [](GLFWwindow* window, int width, int height) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed) { + testbed->redraw_next_frame(); + } + }); + + glfwSetFramebufferSizeCallback(m_glfw_window, [](GLFWwindow* window, int width, int height) { + Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window); + if (testbed) { + testbed->redraw_next_frame(); + } + }); + + float xscale, yscale; + glfwGetWindowContentScale(m_glfw_window, &xscale, &yscale); + + // IMGUI init + IMGUI_CHECKVERSION(); + ImGui::CreateContext(); + ImGuiIO& io = ImGui::GetIO(); + (void)io; + + // By default, imgui places its configuration (state of the GUI -- size of windows, which regions are expanded, etc.) in ./imgui.ini + // relative to the working directory. Instead, we would like to place imgui.ini in the directory that Gen3C project resides in. + static std::string ini_filename; + ini_filename = (root_dir() / "imgui.ini").str(); + io.IniFilename = ini_filename.c_str(); + + // New ImGui event handling seems to make camera controls laggy if input trickling is true. So disable input trickling. 
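+	// (With trickling enabled, Dear ImGui spreads same-type events that arrive within one frame --
+	// e.g. several mouse-move events -- across multiple GUI frames. That helps very-low-framerate
+	// UIs register brief inputs, but it adds perceptible latency to continuous camera drags,
+	// hence the opt-out below.)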
+ io.ConfigInputTrickleEventQueue = false; + ImGui::StyleColorsDark(); + ImGui_ImplGlfw_InitForOpenGL(m_glfw_window, true); + ImGui_ImplOpenGL3_Init("#version 140"); + + ImGui::GetStyle().ScaleAllSizes(xscale); + ImFontConfig font_cfg; + font_cfg.SizePixels = 13.0f * xscale; + io.Fonts->AddFontDefault(&font_cfg); + ImFontConfig overlay_font_cfg; + overlay_font_cfg.SizePixels = 128.0f * xscale; + m_imgui.overlay_font = io.Fonts->AddFontDefault(&overlay_font_cfg); + + init_opengl_shaders(); + + // Make sure there's at least one usable render texture + set_n_views(1); + m_views.front().full_resolution = m_window_res; + m_views.front().render_buffer->resize(m_views.front().full_resolution); + + m_pip_render_texture = std::make_shared(); + m_pip_render_buffer = std::make_shared(m_pip_render_texture); + + m_render_window = true; + + if (m_second_window.window == nullptr && second_window) { + create_second_window(); + } +#endif // NGP_GUI +} + +void Testbed::destroy_window() { +#ifndef NGP_GUI + throw std::runtime_error{"destroy_window failed: NGP was built without GUI support"}; +#else + if (!m_render_window) { + throw std::runtime_error{"Window must be initialized to be destroyed."}; + } + + m_hmd.reset(); + + m_views.clear(); + m_rgba_render_textures.clear(); + m_depth_render_textures.clear(); + + m_pip_render_buffer.reset(); + m_pip_render_texture.reset(); + + m_dlss = false; + m_dlss_provider.reset(); + + ImGui_ImplOpenGL3_Shutdown(); + ImGui_ImplGlfw_Shutdown(); + ImGui::DestroyContext(); + glfwDestroyWindow(m_glfw_window); + glfwTerminate(); + + m_blit_program = 0; + m_blit_vao = 0; + + m_glfw_window = nullptr; + m_render_window = false; +#endif // NGP_GUI +} + +void Testbed::init_vr() { +#ifndef NGP_GUI + throw std::runtime_error{"init_vr failed: NGP was built without GUI support"}; +#else + try { + if (!m_glfw_window) { + throw std::runtime_error{"`init_window` must be called before `init_vr`"}; + } + +# if defined(XR_USE_PLATFORM_WIN32) + m_hmd = std::make_unique(wglGetCurrentDC(), glfwGetWGLContext(m_glfw_window)); +# elif defined(XR_USE_PLATFORM_XLIB) + Display* xDisplay = glfwGetX11Display(); + GLXContext glxContext = glfwGetGLXContext(m_glfw_window); + + int glxFBConfigXID = 0; + glXQueryContext(xDisplay, glxContext, GLX_FBCONFIG_ID, &glxFBConfigXID); + int attributes[3] = {GLX_FBCONFIG_ID, glxFBConfigXID, 0}; + int nelements = 1; + GLXFBConfig* pglxFBConfig = glXChooseFBConfig(xDisplay, 0, attributes, &nelements); + if (nelements != 1 || !pglxFBConfig) { + throw std::runtime_error{"init_vr(): Couldn't obtain GLXFBConfig"}; + } + + GLXFBConfig glxFBConfig = *pglxFBConfig; + + XVisualInfo* visualInfo = glXGetVisualFromFBConfig(xDisplay, glxFBConfig); + if (!visualInfo) { + throw std::runtime_error{"init_vr(): Couldn't obtain XVisualInfo"}; + } + + m_hmd = std::make_unique(xDisplay, visualInfo->visualid, glxFBConfig, glXGetCurrentDrawable(), glxContext); +# elif defined(XR_USE_PLATFORM_WAYLAND) + m_hmd = std::make_unique(glfwGetWaylandDisplay()); +# endif + + // Enable aggressive optimizations to make the VR experience smooth. + update_vr_performance_settings(); + + // If multiple GPUs are available, shoot for 60 fps in VR. + // Otherwise, it wouldn't be realistic to expect more than 30. + m_dynamic_res_target_fps = m_devices.size() > 1 ? 60 : 30; + m_background_color = {0.0f, 0.0f, 0.0f, 0.0f}; + } catch (const std::runtime_error& e) { + if (std::string{e.what()}.find("XR_ERROR_FORM_FACTOR_UNAVAILABLE") != std::string::npos) { + throw std::runtime_error{ + "Could not initialize VR. 
Ensure that SteamVR, OculusVR, or any other OpenXR-compatible runtime is running. Also set it as the active OpenXR runtime." + }; + } else { + throw std::runtime_error{fmt::format("Could not initialize VR: {}", e.what())}; + } + } +#endif // NGP_GUI +} + +void Testbed::update_vr_performance_settings() { +#ifdef NGP_GUI + if (m_hmd) { + auto blend_mode = m_hmd->environment_blend_mode(); + + // DLSS is instrumental in getting VR to look good. Enable if possible. + // If the environment is blended in (such as in XR/AR applications), + // DLSS causes jittering at object sillhouettes (doesn't deal well with alpha), + // and hence stays disabled. + m_dlss = (blend_mode == EEnvironmentBlendMode::Opaque) && m_dlss_provider; + + // Foveated rendering is similarly vital in getting high performance without losing + // resolution in the middle of the view. + m_foveated_rendering = true; + + // Many VR runtimes perform optical flow for automatic reprojection / motion smoothing. + // This breaks down for solid-color background, sometimes leading to artifacts. Hence: + // set background color to transparent and, in spherical_checkerboard_kernel(...), + // blend a checkerboard. If the user desires a solid background nonetheless, they can + // set the background color to have an alpha value of 1.0 manually via the GUI or via Python. + m_render_transparency_as_checkerboard = (blend_mode == EEnvironmentBlendMode::Opaque); + } else { + m_foveated_rendering = false; + m_render_transparency_as_checkerboard = false; + } +#endif // NGP_GUI +} + +bool Testbed::frame() { +#ifdef NGP_GUI + if (m_render_window) { + if (!begin_frame()) { + return false; + } + + handle_user_input(); + begin_vr_frame_and_handle_vr_input(); + } +#endif + + bool skip_rendering = false; + if (!m_dlss && m_max_spp > 0 && !m_views.empty() && m_views.front().render_buffer->spp() >= m_max_spp) { + skip_rendering = true; + } + + if (m_camera_path.rendering && !m_gen3c_render_with_gen3c) { + prepare_next_camera_path_frame(); + skip_rendering = false; + } + + if (m_record_camera_path && !m_views.empty()) { + m_camera_path.spline_order = 1; + const float timestamp = m_camera_path.duration_seconds() + m_frame_ms.val() / 1000.0f; + m_camera_path.add_camera(m_views[0].camera0, focal_length_to_fov(1.0f, m_views[0].relative_focal_length[m_fov_axis]), timestamp); + + m_camera_path.keyframe_subsampling = (int)m_camera_path.keyframes.size(); + m_camera_path.editing_kernel_type = EEditingKernel::Gaussian; + } + +#ifdef NGP_GUI + if (m_hmd && m_hmd->is_visible()) { + skip_rendering = false; + } +#endif + + if (!skip_rendering || std::chrono::steady_clock::now() - m_last_gui_draw_time_point > 50ms) { + redraw_gui_next_frame(); + } + + try { + while (true) { + (*m_task_queue.tryPop())(); + } + } catch (const SharedQueueEmptyException&) {} + + render(skip_rendering); + +#ifdef NGP_GUI + if (m_render_window) { + if (m_gui_redraw) { + draw_gui(); + m_gui_redraw = false; + + m_last_gui_draw_time_point = std::chrono::steady_clock::now(); + } + + ImGui::EndFrame(); + } + + if (m_hmd && m_vr_frame_info) { + // If HMD is visible to the user, splat rendered images to the HMD + if (m_hmd->is_visible()) { + size_t n_views = std::min(m_views.size(), m_vr_frame_info->views.size()); + + // Blit textures to the OpenXR-owned framebuffers (each corresponding to one eye) + for (size_t i = 0; i < n_views; ++i) { + const auto& vr_view = m_vr_frame_info->views.at(i); + + ivec2 resolution = { + vr_view.view.subImage.imageRect.extent.width, + 
vr_view.view.subImage.imageRect.extent.height, + }; + + blit_texture( + m_views.at(i).foveation, + m_rgba_render_textures.at(i)->texture(), + GL_LINEAR, + m_depth_render_textures.at(i)->texture(), + vr_view.framebuffer, + ivec2(0), + resolution + ); + } + + glFinish(); + } + + // Far and near planes are intentionally reversed, because we map depth inversely + // to z. I.e. a window-space depth of 1 refers to the near plane and a depth of 0 + // to the far plane. This results in much better numeric precision. + m_hmd->end_frame(m_vr_frame_info, m_ndc_zfar / m_scale, m_ndc_znear / m_scale, m_vr_use_depth_reproject); + } +#endif + + return true; +} + +bool Testbed::want_repl() { + bool b = m_want_repl; + m_want_repl = false; + return b; +} + +void Testbed::apply_camera_smoothing(float elapsed_ms) { + // Ensure our camera rotation remains an orthogonal matrix as numeric + // errors accumulate across frames. + m_camera = orthogonalize(m_camera); + + if (m_camera_smoothing) { + float decay = std::pow(0.02f, elapsed_ms / 1000.0f); + m_smoothed_camera = orthogonalize(camera_log_lerp(m_smoothed_camera, m_camera, 1.0f - decay)); + } else { + m_smoothed_camera = m_camera; + } +} + +CameraKeyframe Testbed::copy_camera_to_keyframe() const { return CameraKeyframe(m_camera, fov(), 0.0f); } + +void Testbed::set_camera_from_keyframe(const CameraKeyframe& k) { + m_camera = k.m(); + set_fov(k.fov); +} + +void Testbed::set_camera_from_time(float t) { + if (m_camera_path.keyframes.empty()) { + return; + } + + set_camera_from_keyframe(m_camera_path.eval_camera_path(t)); +} + +float Testbed::fov() const { return focal_length_to_fov(1.0f, m_relative_focal_length[m_fov_axis]); } + +void Testbed::set_fov(float val) { m_relative_focal_length = vec2(fov_to_focal_length(1, val)); } + +vec2 Testbed::fov_xy() const { return focal_length_to_fov(ivec2(1), m_relative_focal_length); } + +void Testbed::set_fov_xy(const vec2& val) { m_relative_focal_length = fov_to_focal_length(ivec2(1), val); } + +Testbed::Testbed(ETestbedMode mode) { + tcnn::set_log_callback([](LogSeverity severity, const std::string& msg) { + tlog::ESeverity s = tlog::ESeverity::Info; + switch (severity) { + case LogSeverity::Info: s = tlog::ESeverity::Info; break; + case LogSeverity::Debug: s = tlog::ESeverity::Debug; break; + case LogSeverity::Warning: s = tlog::ESeverity::Warning; break; + case LogSeverity::Error: s = tlog::ESeverity::Error; break; + case LogSeverity::Success: s = tlog::ESeverity::Success; break; + default: break; + } + tlog::log(s) << msg; + }); + + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + throw std::runtime_error{"Testbed requires CUDA 10.2 or later."}; + } + +#ifdef NGP_GUI + // Ensure we're running on the GPU that'll host our GUI. To do so, try creating a dummy + // OpenGL context, figure out the GPU it's running on, and then kill that context again. 
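+	// cudaGLGetDevices() reports which CUDA device(s) back the current OpenGL context. Making that
+	// device the active CUDA device keeps the render buffers that get blitted into the GUI on the
+	// same GPU as the window's GL context, which avoids cross-GPU transfers on every displayed frame.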
+ if (!is_wsl() && glfwInit()) { + glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); + GLFWwindow* offscreen_context = glfwCreateWindow(640, 480, "", NULL, NULL); + + if (offscreen_context) { + glfwMakeContextCurrent(offscreen_context); + + int gl_device = -1; + unsigned int device_count = 0; + if (cudaGLGetDevices(&device_count, &gl_device, 1, cudaGLDeviceListAll) == cudaSuccess) { + if (device_count > 0 && gl_device >= 0) { + set_cuda_device(gl_device); + } + } + + glfwDestroyWindow(offscreen_context); + } + + glfwTerminate(); + } +#endif + + // Reset our stream, which was allocated on the originally active device, + // to make sure it corresponds to the now active device. + m_stream = {}; + + int active_device = cuda_device(); + int active_compute_capability = cuda_compute_capability(); + tlog::success() << fmt::format( + "Initialized CUDA {}. Active GPU is #{}: {} [{}]", cuda_runtime_version_string(), active_device, cuda_device_name(), active_compute_capability + ); + + if (active_compute_capability < MIN_GPU_ARCH) { + tlog::warning() << "Insufficient compute capability " << active_compute_capability << " detected."; + tlog::warning() << "This program was compiled for >=" << MIN_GPU_ARCH << " and may thus behave unexpectedly."; + } + + m_devices.emplace_back(active_device, true); + + int n_devices = cuda_device_count(); + for (int i = 0; i < n_devices; ++i) { + if (i == active_device) { + continue; + } + + if (cuda_compute_capability(i) >= MIN_GPU_ARCH) { + m_devices.emplace_back(i, false); + } + } + + if (m_devices.size() > 1) { + tlog::success() << "Detected auxiliary GPUs:"; + for (size_t i = 1; i < m_devices.size(); ++i) { + const auto& device = m_devices[i]; + tlog::success() << " #" << device.id() << ": " << device.name() << " [" << device.compute_capability() << "]"; + } + } + + set_mode(mode); + set_exposure(0); + + reset_camera(); +} + +Testbed::~Testbed() { + // If any temporary file was created, make sure it's deleted + clear_tmp_dir(); + + if (m_render_window) { + destroy_window(); + } +} + +bool Testbed::clear_tmp_dir() { + wait_all(m_render_futures); + m_render_futures.clear(); + + bool success = true; + auto tmp_dir = fs::path{"tmp"}; + if (tmp_dir.exists()) { + if (tmp_dir.is_directory()) { + for (const auto& path : fs::directory{tmp_dir}) { + if (path.is_file()) { + success &= path.remove_file(); + } + } + } + + success &= tmp_dir.remove_file(); + } + + return success; +} + +vec2 Testbed::calc_focal_length(const ivec2& resolution, const vec2& relative_focal_length, int fov_axis, float zoom) const { + return relative_focal_length * (float)resolution[fov_axis] * zoom; +} + +vec2 Testbed::render_screen_center(const vec2& screen_center) const { + // see pixel_to_ray for how screen center is used; 0.5, 0.5 is 'normal'. we flip so that it becomes the point in the + // original image we want to center on. 
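+	// Example: with m_zoom == 1 and a requested screen_center of (0.6, 0.5), this returns
+	// (0.5 - 0.6) * 1 + 0.5 = 0.4 for x (and 0.5 for y), i.e. the principal point is mirrored
+	// about the image center so that the requested point ends up in the middle of the rendered image.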
+ return (0.5f - screen_center) * m_zoom + 0.5f; +} + +__global__ void dlss_prep_kernel( + ivec2 resolution, + uint32_t sample_index, + vec2 focal_length, + vec2 screen_center, + vec3 parallax_shift, + bool snap_to_pixel_centers, + float* depth_buffer, + const float znear, + const float zfar, + mat4x3 camera, + mat4x3 prev_camera, + cudaSurfaceObject_t depth_surface, + cudaSurfaceObject_t mvec_surface, + cudaSurfaceObject_t exposure_surface, + Foveation foveation, + Foveation prev_foveation, + Lens lens +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + uint32_t idx = x + resolution.x * y; + + uint32_t x_orig = x; + uint32_t y_orig = y; + + const float depth = depth_buffer[idx]; + vec2 mvec = motion_vector( + sample_index, + {(int)x, (int)y}, + resolution, + focal_length, + camera, + prev_camera, + screen_center, + parallax_shift, + snap_to_pixel_centers, + depth, + foveation, + prev_foveation, + lens + ); + + surf2Dwrite(make_float2(mvec.x, mvec.y), mvec_surface, x_orig * sizeof(float2), y_orig); + + // DLSS was trained on games, which presumably used standard normalized device coordinates (ndc) + // depth buffers. So: convert depth to NDC with reasonable near- and far planes. + surf2Dwrite(to_ndc_depth(depth, znear, zfar), depth_surface, x_orig * sizeof(float), y_orig); + + // First thread write an exposure factor of 1. Since DLSS will run on tonemapped data, + // exposure is assumed to already have been applied to DLSS' inputs. + if (x_orig == 0 && y_orig == 0) { + surf2Dwrite(1.0f, exposure_surface, 0, 0); + } +} + +__global__ void spherical_checkerboard_kernel( + ivec2 resolution, + vec2 focal_length, + mat4x3 camera, + vec2 screen_center, + vec3 parallax_shift, + Foveation foveation, + Lens lens, + vec4 background_color, + vec4* frame_buffer +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + Ray ray = pixel_to_ray( + 0, + {(int)x, (int)y}, + resolution, + focal_length, + camera, + screen_center, + parallax_shift, + false, + 0.0f, + 1.0f, + 0.0f, + foveation, + {}, // No need for hidden area mask + lens + ); + + // Blend with checkerboard to break up reprojection weirdness in some VR runtimes + host_device_swap(ray.d.z, ray.d.y); + vec2 spherical = dir_to_spherical(normalize(ray.d)) * 32.0f / PI(); + const vec4 dark_gray = {0.5f, 0.5f, 0.5f, 1.0f}; + const vec4 light_gray = {0.55f, 0.55f, 0.55f, 1.0f}; + vec4 checker = fabsf(fmodf(floorf(spherical.x) + floorf(spherical.y), 2.0f)) < 0.5f ? dark_gray : light_gray; + + // Blend background color on top of checkerboard first (checkerboard is meant to be "behind" the background, + // representing transparency), and then blend the result behind the frame buffer. 
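+	// Both blends below are the premultiplied-alpha "over" operator applied front-to-back:
+	//   out = front + (1 - front.a) * back
+	// first with the background color in front of the checkerboard, then with the (already
+	// accumulated) frame buffer in front of that combined background.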
+ background_color.rgb() = srgb_to_linear(background_color.rgb()); + background_color += (1.0f - background_color.a) * checker; + + uint32_t idx = x + resolution.x * y; + frame_buffer[idx] += (1.0f - frame_buffer[idx].a) * background_color; +} + +__global__ void vr_overlay_hands_kernel( + ivec2 resolution, + vec2 focal_length, + mat4x3 camera, + vec2 screen_center, + vec3 parallax_shift, + Foveation foveation, + Lens lens, + vec3 left_hand_pos, + float left_grab_strength, + vec4 left_hand_color, + vec3 right_hand_pos, + float right_grab_strength, + vec4 right_hand_color, + float hand_radius, + EColorSpace output_color_space, + cudaSurfaceObject_t surface + // TODO: overwrite depth buffer +) { + uint32_t x = threadIdx.x + blockDim.x * blockIdx.x; + uint32_t y = threadIdx.y + blockDim.y * blockIdx.y; + + if (x >= resolution.x || y >= resolution.y) { + return; + } + + Ray ray = pixel_to_ray( + 0, + {(int)x, (int)y}, + resolution, + focal_length, + camera, + screen_center, + parallax_shift, + false, + 0.0f, + 1.0f, + 0.0f, + foveation, + {}, // No need for hidden area mask + lens + ); + + vec4 color = vec4(0.0f); + auto composit_hand = [&](vec3 hand_pos, float grab_strength, vec4 hand_color) { + // Don't render the hand indicator if it's behind the ray origin. + if (dot(ray.d, hand_pos - ray.o) < 0.0f) { + return; + } + + float distance = ray.distance_to(hand_pos); + + vec4 base_color = vec4(0.0f); + const vec4 border_color = {0.4f, 0.4f, 0.4f, 0.4f}; + + // Divide hand radius into an inner part (4/5ths) and a border (1/5th). + float radius = hand_radius * 0.8f; + float border_width = hand_radius * 0.2f; + + // When grabbing, shrink the inner part as a visual indicator. + radius *= 0.5f + 0.5f * (1.0f - grab_strength); + + if (distance < radius) { + base_color = hand_color; + } else if (distance < radius + border_width) { + base_color = border_color; + } else { + return; + } + + // Make hand color opaque when grabbing. + base_color.a = grab_strength + (1.0f - grab_strength) * base_color.a; + color += base_color * (1.0f - color.a); + }; + + if (dot(ray.d, left_hand_pos - ray.o) < dot(ray.d, right_hand_pos - ray.o)) { + composit_hand(left_hand_pos, left_grab_strength, left_hand_color); + composit_hand(right_hand_pos, right_grab_strength, right_hand_color); + } else { + composit_hand(right_hand_pos, right_grab_strength, right_hand_color); + composit_hand(left_hand_pos, left_grab_strength, left_hand_color); + } + + // Blend with existing color of pixel + vec4 prev_color; + surf2Dread((float4*)&prev_color, surface, x * sizeof(float4), y); + if (output_color_space == EColorSpace::SRGB) { + prev_color.rgb() = srgb_to_linear(prev_color.rgb()); + } + + color += (1.0f - color.a) * prev_color; + + if (output_color_space == EColorSpace::SRGB) { + color.rgb() = linear_to_srgb(color.rgb()); + } + + surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y); +} + +void Testbed::render_by_reprojection(cudaStream_t stream, std::vector& views) { + // Reprojection from view cache + int n_src_views = std::max(std::min(m_reproject_max_src_view_index, (int)m_reproject_src_views.size()) - m_reproject_min_src_view_index, 0); + + std::vector src_views(n_src_views); + for (int i = 0; i < n_src_views; ++i) { + // Invert order of src views to reproject from the most recent one first and fill in the holes / closer content with older views. 
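+		// Example: with m_reproject_min_src_view_index == 0 and n_src_views == 3, the loop below yields
+		//   src_views = { &m_reproject_src_views[2], &m_reproject_src_views[1], &m_reproject_src_views[0] },
+		// so the newest cached view is reprojected first and older views only contribute where it left holes.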
+ src_views[n_src_views - i - 1] = &m_reproject_src_views[i + m_reproject_min_src_view_index]; + } + + for (size_t i = 0; i < views.size(); ++i) { + auto& view = views[i]; + + reproject_views(src_views, view); + + render_frame_epilogue( + stream, + view.camera0, + view.prev_camera, + view.screen_center, + view.relative_focal_length, + view.foveation, + view.prev_foveation, + view.lens, + *view.render_buffer, + true + ); + + view.prev_camera = view.camera0; + view.prev_foveation = view.foveation; + } +} + +void Testbed::render_frame( + cudaStream_t stream, + const mat4x3& camera_matrix0, + const mat4x3& camera_matrix1, + const mat4x3& prev_camera_matrix, + const vec2& orig_screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Foveation& prev_foveation, + const Lens& lens, + int visualized_dimension, + CudaRenderBuffer& render_buffer, + bool to_srgb, + CudaDevice* device +) { + if (!device) { + device = &primary_device(); + } + + sync_device(render_buffer, *device); + + { + auto device_guard = use_device(stream, render_buffer, *device); + render_frame_main( + *device, camera_matrix0, camera_matrix1, orig_screen_center, relative_focal_length, foveation, lens, visualized_dimension + ); + } + + render_frame_epilogue( + stream, camera_matrix0, prev_camera_matrix, orig_screen_center, relative_focal_length, foveation, prev_foveation, lens, render_buffer, to_srgb + ); +} + +void Testbed::render_frame_main( + CudaDevice& device, + const mat4x3& camera_matrix0, + const mat4x3& camera_matrix1, + const vec2& orig_screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Lens& lens, + int visualized_dimension +) { + device.render_buffer_view().clear(device.stream()); + + vec2 focal_length = calc_focal_length(device.render_buffer_view().resolution, relative_focal_length, m_fov_axis, m_zoom); + vec2 screen_center = render_screen_center(orig_screen_center); +} + +void Testbed::render_frame_epilogue( + cudaStream_t stream, + const mat4x3& camera_matrix0, + const mat4x3& prev_camera_matrix, + const vec2& orig_screen_center, + const vec2& relative_focal_length, + const Foveation& foveation, + const Foveation& prev_foveation, + const Lens& lens, + CudaRenderBuffer& render_buffer, + bool to_srgb +) { + vec2 focal_length = calc_focal_length(render_buffer.in_resolution(), relative_focal_length, m_fov_axis, m_zoom); + vec2 screen_center = render_screen_center(orig_screen_center); + + render_buffer.set_color_space(m_color_space); + render_buffer.set_tonemap_curve(m_tonemap_curve); + + // Prepare DLSS data: motion vectors, scaled depth, exposure + if (render_buffer.dlss()) { + auto res = render_buffer.in_resolution(); + + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + + dlss_prep_kernel<<>>( + res, + render_buffer.spp(), + focal_length, + screen_center, + m_parallax_shift, + m_snap_to_pixel_centers, + render_buffer.depth_buffer(), + m_ndc_znear, + m_ndc_zfar, + camera_matrix0, + prev_camera_matrix, + render_buffer.dlss()->depth(), + render_buffer.dlss()->mvec(), + render_buffer.dlss()->exposure(), + foveation, + prev_foveation, + lens + ); + + render_buffer.set_dlss_sharpening(m_dlss_sharpening); + } + + EColorSpace output_color_space = to_srgb ? 
EColorSpace::SRGB : EColorSpace::Linear; + + if (m_render_transparency_as_checkerboard) { + mat4x3 checkerboard_transform = mat4x3::identity(); + +#ifdef NGP_GUI + if (m_hmd && m_vr_frame_info && !m_vr_frame_info->views.empty()) { + checkerboard_transform = m_vr_frame_info->views[0].pose; + } +#endif + + auto res = render_buffer.in_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + spherical_checkerboard_kernel<<>>( + res, + focal_length, + checkerboard_transform, + screen_center, + m_parallax_shift, + foveation, + lens, + m_background_color, + render_buffer.frame_buffer() + ); + } + + render_buffer.accumulate(m_exposure, stream); + render_buffer.tonemap(m_exposure, m_background_color, output_color_space, m_ndc_znear, m_ndc_zfar, m_snap_to_pixel_centers, stream); + +#ifdef NGP_GUI + // If in VR, indicate the hand position and render transparent background + if (m_hmd && m_vr_frame_info) { + auto& hands = m_vr_frame_info->hands; + + auto res = render_buffer.out_resolution(); + const dim3 threads = {16, 8, 1}; + const dim3 blocks = {div_round_up((uint32_t)res.x, threads.x), div_round_up((uint32_t)res.y, threads.y), 1}; + vr_overlay_hands_kernel<<>>( + res, + focal_length * vec2(render_buffer.out_resolution()) / vec2(render_buffer.in_resolution()), + camera_matrix0, + screen_center, + m_parallax_shift, + foveation, + lens, + vr_to_world(hands[0].pose[3]), + hands[0].grab_strength, + {hands[0].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f}, + vr_to_world(hands[1].pose[3]), + hands[1].grab_strength, + {hands[1].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f}, + 0.05f * m_scale, // Hand radius + output_color_space, + render_buffer.surface() + ); + } +#endif +} + +float Testbed::get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const vec2& uv) { + if (!render_buffer.depth_buffer()) { + return m_scale; + } + + float depth; + auto res = render_buffer.in_resolution(); + ivec2 depth_pixel = clamp(ivec2(uv * vec2(res)), 0, res - 1); + + CUDA_CHECK_THROW( + cudaMemcpy(&depth, render_buffer.depth_buffer() + depth_pixel.x + depth_pixel.y * res.x, sizeof(float), cudaMemcpyDeviceToHost) + ); + return depth; +} + +vec3 Testbed::get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const vec2& pixel) { + float depth = get_depth_from_renderbuffer(render_buffer, pixel / vec2(m_window_res)); + auto ray = pixel_to_ray_pinhole( + 0, + ivec2(pixel), + m_window_res, + calc_focal_length(m_window_res, m_relative_focal_length, m_fov_axis, m_zoom), + m_smoothed_camera, + render_screen_center(m_screen_center) + ); + return ray(depth); +} + +void Testbed::autofocus() { + float new_slice_plane_z = std::max(dot(view_dir(), m_autofocus_target - view_pos()), 0.1f) - m_scale; + if (new_slice_plane_z != m_slice_plane_z) { + m_slice_plane_z = new_slice_plane_z; + if (m_aperture_size != 0.0f) { + reset_accumulation(); + } + } +} + +Testbed::LevelStats compute_level_stats(const float* params, size_t n_params) { + Testbed::LevelStats s = {}; + for (size_t i = 0; i < n_params; ++i) { + float v = params[i]; + float av = fabsf(v); + if (av < 0.00001f) { + s.numzero++; + } else { + if (s.count == 0) { + s.min = s.max = v; + } + s.count++; + s.x += v; + s.xsquared += v * v; + s.min = min(s.min, v); + s.max = max(s.max, v); + } + } + return s; +} + +Testbed::CudaDevice::CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} { + auto guard = device_guard(); + m_stream = std::make_unique(); + m_data 
= std::make_unique(); + m_render_worker = std::make_unique(is_primary ? 0u : 1u); +} + +ScopeGuard Testbed::CudaDevice::device_guard() { + int prev_device = cuda_device(); + if (prev_device == m_id) { + return {}; + } + + set_cuda_device(m_id); + return ScopeGuard{[prev_device]() { set_cuda_device(prev_device); }}; +} + +void Testbed::sync_device(CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) { + if (!device.dirty()) { + return; + } + + if (device.is_primary()) { + device.data().hidden_area_mask = render_buffer.hidden_area_mask(); + device.set_dirty(false); + return; + } + + m_stream.signal(device.stream()); + + int active_device = cuda_device(); + auto guard = device.device_guard(); + + if (render_buffer.hidden_area_mask()) { + auto ham = std::make_shared>(render_buffer.hidden_area_mask()->resolution()); + CUDA_CHECK_THROW(cudaMemcpyPeerAsync( + ham->data(), device.id(), render_buffer.hidden_area_mask()->data(), active_device, ham->bytes(), device.stream() + )); + device.data().hidden_area_mask = ham; + } else { + device.data().hidden_area_mask = nullptr; + } + + device.set_dirty(false); + device.signal(m_stream.get()); +} + +// From https://stackoverflow.com/questions/20843271/passing-a-non-copyable-closure-object-to-stdfunction-parameter +template auto make_copyable_function(F&& f) { + using dF = std::decay_t; + auto spf = std::make_shared(std::forward(f)); + return [spf](auto&&... args) -> decltype(auto) { return (*spf)(decltype(args)(args)...); }; +} + +ScopeGuard Testbed::use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) { + device.wait_for(stream); + + if (device.is_primary()) { + device.set_render_buffer_view(render_buffer.view()); + return ScopeGuard{[&device, stream]() { + device.set_render_buffer_view({}); + device.signal(stream); + }}; + } + + int active_device = cuda_device(); + auto guard = device.device_guard(); + + size_t n_pixels = product(render_buffer.in_resolution()); + + GPUMemoryArena::Allocation alloc; + auto scratch = allocate_workspace_and_distribute(device.stream(), &alloc, n_pixels, n_pixels); + + device.set_render_buffer_view({ + std::get<0>(scratch), + std::get<1>(scratch), + render_buffer.in_resolution(), + render_buffer.spp(), + device.data().hidden_area_mask, + }); + + return ScopeGuard{ + make_copyable_function([&render_buffer, &device, guard = std::move(guard), alloc = std::move(alloc), active_device, stream]() { + // Copy device's render buffer's data onto the original render buffer + CUDA_CHECK_THROW(cudaMemcpyPeerAsync( + render_buffer.frame_buffer(), + active_device, + device.render_buffer_view().frame_buffer, + device.id(), + product(render_buffer.in_resolution()) * sizeof(vec4), + device.stream() + )); + CUDA_CHECK_THROW(cudaMemcpyPeerAsync( + render_buffer.depth_buffer(), + active_device, + device.render_buffer_view().depth_buffer, + device.id(), + product(render_buffer.in_resolution()) * sizeof(float), + device.stream() + )); + + device.set_render_buffer_view({}); + device.signal(stream); + }) + }; +} + +void Testbed::set_all_devices_dirty() { + for (auto& device : m_devices) { + device.set_dirty(true); + } +} + +void Testbed::load_camera_path(const fs::path& path) { m_camera_path.load(path, mat4x3::identity()); } + +bool Testbed::loop_animation() { return m_camera_path.loop; } + +void Testbed::set_loop_animation(bool value) { m_camera_path.loop = value; } + +} // namespace ngp diff --git a/gui/src/thread_pool.cpp b/gui/src/thread_pool.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..d175af8317781fc18edd9a916ef00e200558a10c --- /dev/null +++ b/gui/src/thread_pool.cpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file was taken from the tev image viewer and is re-released here +// under the NVIDIA Source Code License with permission from the author. + +#include +#include + +#include + +namespace ngp { + +ThreadPool::ThreadPool() +: ThreadPool{std::thread::hardware_concurrency()} {} + +ThreadPool::ThreadPool(size_t max_num_threads, bool force) { + if (!force) { + max_num_threads = std::min((size_t)std::thread::hardware_concurrency(), max_num_threads); + } + start_threads(max_num_threads); +} + +ThreadPool::~ThreadPool() { + wait_until_queue_completed(); + shutdown_threads(m_threads.size()); +} + +void ThreadPool::start_threads(size_t num) { + m_num_threads += num; + for (size_t i = m_threads.size(); i < m_num_threads; ++i) { + m_threads.emplace_back([this, i] { + while (true) { + std::unique_lock lock{m_task_queue_mutex}; + + // look for a work item + while (i < m_num_threads && m_task_queue.empty()) { + // if there are none, signal that the queue is completed + // and wait for notification of new work items. + m_task_queue_completed_condition.notify_all(); + m_worker_condition.wait(lock); + } + + if (i >= m_num_threads) { + break; + } + + std::function task{std::move(m_task_queue.front())}; + m_task_queue.pop_front(); + + // Unlock the lock, so we can process the task without blocking other threads + lock.unlock(); + + task(); + } + }); + } +} + +void ThreadPool::shutdown_threads(size_t num) { + auto num_to_close = std::min(num, m_num_threads); + + { + std::lock_guard lock{m_task_queue_mutex}; + m_num_threads -= num_to_close; + } + + // Wake up all the threads to have them quit + m_worker_condition.notify_all(); + for (auto i = 0u; i < num_to_close; ++i) { + m_threads.back().join(); + m_threads.pop_back(); + } +} + +void ThreadPool::set_n_threads(size_t num) { + if (m_num_threads > num) { + shutdown_threads(m_num_threads - num); + } else if (m_num_threads < num) { + start_threads(num - m_num_threads); + } +} + +void ThreadPool::wait_until_queue_completed() { + std::unique_lock lock{m_task_queue_mutex}; + m_task_queue_completed_condition.wait(lock, [this]() { return m_task_queue.empty(); }); +} + +void ThreadPool::flush_queue() { + std::lock_guard lock{m_task_queue_mutex}; + m_task_queue.clear(); +} + +} diff --git a/gui/src/tiny-cuda-nn/common_host.cu b/gui/src/tiny-cuda-nn/common_host.cu new file mode 100644 index 0000000000000000000000000000000000000000..e4950bc4b075025d18a01a9a00ba466794a46b81 --- /dev/null +++ b/gui/src/tiny-cuda-nn/common_host.cu @@ -0,0 +1,351 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file common_host.cu + * @author Thomas Müller and Nikolaus Binder, NVIDIA + * @brief Common utilities that are needed by pretty much every component of this framework. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace tcnn { + +static_assert( + __CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2), + "tiny-cuda-nn requires at least CUDA 10.2" +); + +std::function g_log_callback = [](LogSeverity severity, const std::string& msg) { + switch (severity) { + case LogSeverity::Warning: std::cerr << fmt::format("tiny-cuda-nn warning: {}\n", msg); break; + case LogSeverity::Error: std::cerr << fmt::format("tiny-cuda-nn error: {}\n", msg); break; + default: break; + } + + if (verbose()) { + switch (severity) { + case LogSeverity::Debug: std::cerr << fmt::format("tiny-cuda-nn debug: {}\n", msg); break; + case LogSeverity::Info: std::cerr << fmt::format("tiny-cuda-nn info: {}\n", msg); break; + case LogSeverity::Success: std::cerr << fmt::format("tiny-cuda-nn success: {}\n", msg); break; + default: break; + } + } +}; + +const std::function& log_callback() { return g_log_callback; } +void set_log_callback(const std::function& cb) { g_log_callback = cb; } + +bool g_verbose = false; +bool verbose() { return g_verbose; } +void set_verbose(bool verbose) { g_verbose = verbose; } + +Activation string_to_activation(const std::string& activation_name) { + if (equals_case_insensitive(activation_name, "None")) { + return Activation::None; + } else if (equals_case_insensitive(activation_name, "ReLU")) { + return Activation::ReLU; + } else if (equals_case_insensitive(activation_name, "LeakyReLU")) { + return Activation::LeakyReLU; + } else if (equals_case_insensitive(activation_name, "Exponential")) { + return Activation::Exponential; + } else if (equals_case_insensitive(activation_name, "Sigmoid")) { + return Activation::Sigmoid; + } else if (equals_case_insensitive(activation_name, "Sine")) { + return Activation::Sine; + } else if (equals_case_insensitive(activation_name, "Squareplus")) { + return Activation::Squareplus; + } else if (equals_case_insensitive(activation_name, "Softplus")) { + return Activation::Softplus; + } else if (equals_case_insensitive(activation_name, "Tanh")) { + return Activation::Tanh; + } + + throw std::runtime_error{fmt::format("Invalid activation name: {}", activation_name)}; +} + +std::string to_string(Activation activation) { + switch (activation) { + case Activation::None: return "None"; + case Activation::ReLU: return "ReLU"; + case Activation::LeakyReLU: return "LeakyReLU"; + case Activation::Exponential: return "Exponential"; + case Activation::Sigmoid: return "Sigmoid"; + case Activation::Sine: return "Sine"; + case Activation::Squareplus: return "Squareplus"; + case Activation::Softplus: return "Softplus"; + case Activation::Tanh: return "Tanh"; + default: throw 
std::runtime_error{"Invalid activation."}; + } +} + +GridType string_to_grid_type(const std::string& grid_type) { + if (equals_case_insensitive(grid_type, "Hash")) { + return GridType::Hash; + } else if (equals_case_insensitive(grid_type, "Dense")) { + return GridType::Dense; + } else if (equals_case_insensitive(grid_type, "Tiled") || equals_case_insensitive(grid_type, "Tile")) { + return GridType::Tiled; + } + + throw std::runtime_error{fmt::format("Invalid grid type: {}", grid_type)}; +} + +std::string to_string(GridType grid_type) { + switch (grid_type) { + case GridType::Hash: return "Hash"; + case GridType::Dense: return "Dense"; + case GridType::Tiled: return "Tiled"; + default: throw std::runtime_error{"Invalid grid type."}; + } +} + +HashType string_to_hash_type(const std::string& hash_type) { + if (equals_case_insensitive(hash_type, "Prime")) { + return HashType::Prime; + } else if (equals_case_insensitive(hash_type, "CoherentPrime")) { + return HashType::CoherentPrime; + } else if (equals_case_insensitive(hash_type, "ReversedPrime")) { + return HashType::ReversedPrime; + } else if (equals_case_insensitive(hash_type, "Rng")) { + return HashType::Rng; + } else if (equals_case_insensitive(hash_type, "BaseConvert")) { + return HashType::BaseConvert; + } + + throw std::runtime_error{fmt::format("Invalid hash type: {}", hash_type)}; +} + +std::string to_string(HashType hash_type) { + switch (hash_type) { + case HashType::Prime: return "Prime"; + case HashType::CoherentPrime: return "CoherentPrime"; + case HashType::ReversedPrime: return "ReversedPrime"; + case HashType::Rng: return "Rng"; + case HashType::BaseConvert: return "BaseConvert"; + default: throw std::runtime_error{"Invalid hash type."}; + } +} + +InterpolationType string_to_interpolation_type(const std::string& interpolation_type) { + if (equals_case_insensitive(interpolation_type, "Nearest")) { + return InterpolationType::Nearest; + } else if (equals_case_insensitive(interpolation_type, "Linear")) { + return InterpolationType::Linear; + } else if (equals_case_insensitive(interpolation_type, "Smoothstep")) { + return InterpolationType::Smoothstep; + } + + throw std::runtime_error{fmt::format("Invalid interpolation type: {}", interpolation_type)}; +} + +std::string to_string(InterpolationType interpolation_type) { + switch (interpolation_type) { + case InterpolationType::Nearest: return "Nearest"; + case InterpolationType::Linear: return "Linear"; + case InterpolationType::Smoothstep: return "Smoothstep"; + default: throw std::runtime_error{"Invalid interpolation type."}; + } +} + +ReductionType string_to_reduction_type(const std::string& reduction_type) { + if (equals_case_insensitive(reduction_type, "Concatenation")) { + return ReductionType::Concatenation; + } else if (equals_case_insensitive(reduction_type, "Sum")) { + return ReductionType::Sum; + } else if (equals_case_insensitive(reduction_type, "Product")) { + return ReductionType::Product; + } + + throw std::runtime_error{fmt::format("Invalid reduction type: {}", reduction_type)}; +} + +std::string to_string(ReductionType reduction_type) { + switch (reduction_type) { + case ReductionType::Concatenation: return "Concatenation"; + case ReductionType::Sum: return "Sum"; + case ReductionType::Product: return "Product"; + default: throw std::runtime_error{"Invalid reduction type."}; + } +} + +int cuda_runtime_version() { + int version; + CUDA_CHECK_THROW(cudaRuntimeGetVersion(&version)); + return version; +} + +int cuda_device() { + int device; + 
CUDA_CHECK_THROW(cudaGetDevice(&device)); + return device; +} + +void set_cuda_device(int device) { + CUDA_CHECK_THROW(cudaSetDevice(device)); +} + +int cuda_device_count() { + int device_count; + CUDA_CHECK_THROW(cudaGetDeviceCount(&device_count)); + return device_count; +} + +bool cuda_supports_virtual_memory(int device) { + int supports_vmm; + CU_CHECK_THROW(cuDeviceGetAttribute(&supports_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, device)); + return supports_vmm != 0; +} + +std::unordered_map& cuda_device_properties() { + static auto* cuda_device_props = new std::unordered_map{}; + return *cuda_device_props; +} + +const cudaDeviceProp& cuda_get_device_properties(int device) { + if (cuda_device_properties().count(device) == 0) { + auto& props = cuda_device_properties()[device]; + CUDA_CHECK_THROW(cudaGetDeviceProperties(&props, device)); + } + + return cuda_device_properties().at(device); +} + +std::string cuda_device_name(int device) { + return cuda_get_device_properties(device).name; +} + +uint32_t cuda_compute_capability(int device) { + const auto& props = cuda_get_device_properties(device); + return props.major * 10 + props.minor; +} + +uint32_t cuda_max_supported_compute_capability() { + int cuda_version = cuda_runtime_version(); + if (cuda_version < 11000) { + return 75; + } else if (cuda_version < 11010) { + return 80; + } else if (cuda_version < 11080) { + return 86; + } else { + return 90; + } +} + +uint32_t cuda_supported_compute_capability(int device) { + return std::min(cuda_compute_capability(device), cuda_max_supported_compute_capability()); +} + +size_t cuda_max_shmem(int device) { + return cuda_get_device_properties(device).sharedMemPerBlockOptin; +} + +uint32_t cuda_max_registers(int device) { + return (uint32_t)cuda_get_device_properties(device).regsPerBlock; +} + +size_t cuda_memory_granularity(int device) { + size_t granularity; + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = 0; + CUresult granularity_result = cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); + if (granularity_result == CUDA_ERROR_NOT_SUPPORTED) { + return 1; + } + CU_CHECK_THROW(granularity_result); + return granularity; +} + +MemoryInfo cuda_memory_info() { + MemoryInfo info; + CUDA_CHECK_THROW(cudaMemGetInfo(&info.free, &info.total)); + info.used = info.total - info.free; + return info; +} + +std::string generate_device_code_preamble() { + return dfmt(0, R"( + #include + #include + + using namespace tcnn; + )"); +} + +std::string to_snake_case(const std::string& str) { + std::stringstream result; + result << (char)std::tolower(str[0]); + for (uint32_t i = 1; i < str.length(); ++i) { + if (std::isupper(str[i])) { + result << "_" << (char)std::tolower(str[i]); + } else { + result << str[i]; + } + } + return result.str(); +} + +std::vector split(const std::string& text, const std::string& delim) { + std::vector result; + size_t begin = 0; + while (true) { + size_t end = text.find_first_of(delim, begin); + if (end == std::string::npos) { + result.emplace_back(text.substr(begin)); + return result; + } else { + result.emplace_back(text.substr(begin, end - begin)); + begin = end + 1; + } + } + + return result; +} + +std::string to_lower(std::string str) { + std::transform(std::begin(str), std::end(str), std::begin(str), [](unsigned char c) { return (char)std::tolower(c); }); + return str; +} + +std::string to_upper(std::string str) { + 
std::transform(std::begin(str), std::end(str), std::begin(str), [](unsigned char c) { return (char)std::toupper(c); }); + return str; +} + +template <> std::string type_to_string() { return "bool"; } +template <> std::string type_to_string() { return "int"; } +template <> std::string type_to_string() { return "char"; } +template <> std::string type_to_string() { return "uint8_t"; } +template <> std::string type_to_string() { return "uint16_t"; } +template <> std::string type_to_string() { return "uint32_t"; } +template <> std::string type_to_string() { return "double"; } +template <> std::string type_to_string() { return "float"; } +template <> std::string type_to_string<__half>() { return "__half"; } + +} diff --git a/scripts/check_video_links.py b/scripts/check_video_links.py new file mode 100644 index 0000000000000000000000000000000000000000..63fcf13a2fe91c0b2c0fafa23f6017ed8af49576 --- /dev/null +++ b/scripts/check_video_links.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +import requests + + +def find_md_files(root="."): + for dirpath, _, filenames in os.walk(root): + for f in filenames: + if f.endswith(".md"): + yield os.path.join(dirpath, f) + + +def extract_video_urls(md_file): + with open(md_file, "r", encoding="utf-8") as f: + content = f.read() + return re.findall(r' 0: + print(f"Checkpoint {save_path} already exists and is not empty") + return + + pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409") + os.makedirs(pixtral_ckpt_dir, exist_ok=True) + repo_id = "mistralai/Pixtral-12B-2409" + print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...") + snapshot_download( + repo_id=repo_id, + allow_patterns=["params.json", "consolidated.safetensors"], + local_dir=pixtral_ckpt_dir, + local_dir_use_symlinks=False, + ) + orig_dtype = torch.get_default_dtype() + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + + # Load checkpoint file + ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors")) + assert len(ckpt_files) == 1, "ckpt_dir should contain only one file" + ckpt_path = ckpt_files[0] + ckpt = load_file(ckpt_path) + + # Split checkpoint into weights of vision encoder, projector, and LLM + vit_key_prefix = "vision_encoder." + vit_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix): + vit_ckpt[key.lstrip(vit_key_prefix)] = value + + projector_key_prefix = "vision_language_adapter." 
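+    # Caveat on the prefix handling in this function: str.lstrip() removes a leading *set of
+    # characters*, not a literal prefix. For example,
+    #   "vision_encoder.norm.weight".lstrip("vision_encoder.")  ->  "m.weight"
+    # because 'n', 'o' and 'r' all occur in the prefix string (the key name here is purely
+    # illustrative). A plain slice, key[len(prefix):], is the safer idiom whenever a stripped
+    # key could start with a character contained in the prefix.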
+ projector_ckpt = {} + substring_replacement_map = { + "w_in.": "projector.0.", + "w_out.": "projector.2.", + } + for key, value in ckpt.items(): + if key.startswith(projector_key_prefix): + key = key.lstrip(projector_key_prefix) + for old, new in substring_replacement_map.items(): + key = key.replace(old, new) + projector_ckpt[key] = value + + llm_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix): + continue + llm_ckpt[key] = value + + vlm_ckpt = {} + for key, value in llm_ckpt.items(): + vlm_ckpt["model." + key] = value + for key, value in projector_ckpt.items(): + vlm_ckpt["mm_projector." + key] = value + for key, value in vit_ckpt.items(): + vlm_ckpt["vision_encoder." + key] = value + + # Load config + config_path = os.path.join(pixtral_ckpt_dir, "params.json") + with open(config_path, "r") as f: + pixtral_config = json.load(f) + + # Extract the vision encoder configuration + vision_encoder_config = { + "dim": pixtral_config["vision_encoder"]["hidden_size"], + "num_channels": pixtral_config["vision_encoder"]["num_channels"], + "image_size": pixtral_config["vision_encoder"]["image_size"], + "patch_size": pixtral_config["vision_encoder"]["patch_size"], + "rope_theta": pixtral_config["vision_encoder"]["rope_theta"], + "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"], + "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"], + "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "norm_type": "rmsnorm", + "norm_eps": pixtral_config["norm_eps"], + "image_token_id": pixtral_config["vision_encoder"]["image_token_id"], + } + # Configuration for the 400M ViT of Pixtral 12B VLM + vit_config = dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + # Compare the two configurations + for key, value in vit_config.items(): + assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}" + + llm_config_keys = [ + "dim", + "n_layers", + "head_dim", + "hidden_dim", + "n_heads", + "n_kv_heads", + "rope_theta", + "norm_eps", + "vocab_size", + ] + assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch" + replace_map = { + "hidden_dim": "ffn_hidden_size", + } + llm_config = {} + for k, v in pixtral_config.items(): + if k in llm_config_keys: + llm_config[replace_map.get(k, k)] = v + elif k == "vision_encoder": + llm_config["vision_encoder"] = vit_type + else: + raise ValueError(f"Unknown key: {k}") + + ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt} + torch.save(ckpt_to_save, save_path) + print(f"Model saved to {save_path}") + + # Save config + config_path = os.path.join(save_dir, "config.json") + with open(config_path, "w") as f: + json.dump(llm_config, f) + + torch.set_default_dtype(orig_dtype) # Reset the default dtype + + # Remove the original Pixtral checkpoint + shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True) + print(f"Removed {pixtral_ckpt_dir}") + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Predict1-14B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Text2World/model.pt": "c69d1c6e51dc78b959040e8c4035a29b", + 
"Cosmos-Predict1-14B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Video2World/model.pt": "eaa7aa3678f61d88108c41d7fe201b18", + "Cosmos-Predict1-7B-WorldInterpolator/model.pt": "48a0bdc99d5e41eee05ba8597c4851da", + "Cosmos-Predict1-7B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Text2World/model.pt": "fe9ed68e16cf37b10e7414c9b3ee81e1", + "Cosmos-Predict1-7B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Video2World/model.pt": "ebcdb19c4c4a6a0e1e0bb65e346f6867", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": "f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CV8x8x8-720p/image_mean_std.pt": "9f19fd3312fc1198e4905ada02e68bce", + "Cosmos-UpsamplePrompt1-12B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-UpsamplePrompt1-12B-Text2World/model.pt": "52d7a6b8b1ac44d856b4c1ea3f8c8c74", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt": "e3a6ef070deaae0678acd529dc749ea4", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt": "1653f87dce3d558ee01416593552a91c", + "google-t5/t5-11b/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3", + "google-t5/t5-11b/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + # Check if there are any expected files for this model + expected_files = [key for key in MD5_CHECKSUM_LOOKUP if key.startswith(model_name + "/")] + if not expected_files: + # No expected files in MD5_CHECKSUM_LOOKUP, check if the directory exists and has content + model_dir = checkpoints_dir / model_name + if not model_dir.exists() or not any(model_dir.iterdir()): + print(f"Directory for {model_name} does not exist or is empty. Download required.") + return False + else: + print(f"Directory for {model_name} exists and contains files. 
Assuming download is complete.") + return True + # Proceed with checksum verification for models with expected files + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name + "/"): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match given MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args): + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "7B": "Cosmos-Predict1-7B", + "14B": "Cosmos-Predict1-14B", + } + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-Tokenize1-CV8x8x8-720p", + "google-t5/t5-11b", + ] + + if "Text2World" in args.model_types: + extra_models.append("Cosmos-UpsamplePrompt1-12B-Text2World") + + # Add interpolator if 7B model is selected + if "7B" in args.model_sizes: + extra_models.append("Cosmos-Predict1-7B-WorldInterpolator") + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict( + allow_patterns=[ + "README.md", + "model.pt", + "mean_std.pt", + "image_mean_std.pt", + "config.json", + "*.jit", + "guardrail/*", + ] + ) + + # Download the requested diffusion models + for size in args.model_sizes: + for model_type in args.model_types: + suffix = f"-{model_type}" + model_name = model_map[size] + suffix + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + # Download the always-included models + for model_name in extra_models: + if model_name == "google-t5/t5-11b": + repo_id = model_name + else: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + # Download all files for Guardrail + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + ) + + if "Video2World" in args.model_types: + # Prompt Upsampler for Cosmos-Predict1-Video2World models + convert_pixtral_checkpoint( + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Pixtral-12B", + vit_type="pixtral-12b-vit", + ) + + download_guardrail_checkpoints(args.checkpoint_dir) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_diffusion_example_data.py b/scripts/download_diffusion_example_data.py new file mode 100644 index 0000000000000000000000000000000000000000..70bd4ba23a538d8f12dfa1fe7402ab9e33a7cd08 --- /dev/null +++ b/scripts/download_diffusion_example_data.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +import ffmpeg +from pytubefix import YouTube + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") + parser.add_argument("--do_download", action="store_true", help="Download the videos") + parser.add_argument("--do_clip", action="store_true", help="Clip the videos") + return parser.parse_args() + + +def convert_time_to_seconds(time_str) -> int: + h, m, s = map(float, time_str.split(":")) + ms = int(time_str.split(".")[-1]) if "." in time_str else 0 + return int(h * 3600 + m * 60 + s) + ms / 1000 + + +def download_data(args) -> None: + urls_set = set() + download_count = 0 + + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + os.makedirs(videos_orig_dir, exist_ok=True) + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + + hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") + with open(hdvila_jsonl_path, "r") as fp: + for line in fp: + json_object = json.loads(line) + url = json_object["url"] + if url not in urls_set: # download videos with unique urls + yt = YouTube(json_object["url"]) + try: + # Download a video + yt.streams.get_highest_resolution().download( + output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" + ) + download_count += 1 + urls_set.add(url) + print(f"Downloaded videos: {download_count}/{args.N_videos}") + + # Save metadata - caption and whole metadata + meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) + with open(meta_txt_name, "w") as fp: + fp.write(json_object["caption"]) + meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) + with open(meta_json_name, "w") as fp: + json.dump(json_object, fp) + except Exception as e: + print(e) + continue + + if len(urls_set) >= args.N_videos: + break + + +def clip_data(args) -> None: + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") + ] + videos_orig_list = [ + os.path.join(videos_orig_dir, filename) + for filename in sorted(os.listdir(videos_orig_dir)) + if 
filename.endswith(".mp4") + ] + + for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): + with open(meta_filename, "r") as fp: + metadata = json.load(fp) + + # Convert time strings to seconds + start_time = convert_time_to_seconds(metadata["span_start"]) + end_time = convert_time_to_seconds(metadata["span_end"]) + # Clip the video + clip_name = os.path.join(videos_dir, metadata["clip_id"]) + ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() + + +def main(args) -> None: + if args.do_download: + download_data(args) + if args.do_clip: + clip_data(args) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_gen3c_checkpoints.py b/scripts/download_gen3c_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..114efa8a9fe7880c574e52faeeaa36cf87c6c92f --- /dev/null +++ b/scripts/download_gen3c_checkpoints.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import json +import os +import shutil +from glob import glob +from pathlib import Path + +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + +from scripts.download_guardrail_checkpoints import download_guardrail_checkpoints + + +def parse_args(): + parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos Predict1 Gen3C models from Hugging Face") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints." + ) + args = parser.parse_args() + return args + + +def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str): + """ + Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint. + + Args: + checkpoint_dir (str): Path to the checkpoint directory + checkpoint_name (str): Name of the checkpoint + vit_type (str): Type of ViT used in the Pixtral model + + This function performs the following steps: + 0. Download the checkpoint from Hugging Face + 1. Loads the original Pixtral checkpoint + 2. Splits the checkpoint into vision encoder, projector, and LLM weights + 3. Reorganizes the weights to match the expected format + 4. Extracts and verifies the vision encoder configuration + 5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer + 6. 
Optionally saves the converted checkpoint and configuration + """ + + save_dir = os.path.join(checkpoint_dir, checkpoint_name) + os.makedirs(save_dir, exist_ok=True) + # Save the converted checkpoint + save_path = os.path.join(save_dir, "model.pt") + if os.path.exists(save_path) and os.path.getsize(save_path) > 0: + print(f"Checkpoint {save_path} already exists and is not empty") + return + + pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409") + os.makedirs(pixtral_ckpt_dir, exist_ok=True) + repo_id = "mistralai/Pixtral-12B-2409" + print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...") + snapshot_download( + repo_id=repo_id, + allow_patterns=["params.json", "consolidated.safetensors"], + local_dir=pixtral_ckpt_dir, + local_dir_use_symlinks=False, + ) + orig_dtype = torch.get_default_dtype() + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + + # Load checkpoint file + ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors")) + assert len(ckpt_files) == 1, "ckpt_dir should contain only one file" + ckpt_path = ckpt_files[0] + ckpt = load_file(ckpt_path) + + # Split checkpoint into weights of vision encoder, projector, and LLM + vit_key_prefix = "vision_encoder." + vit_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix): + vit_ckpt[key.lstrip(vit_key_prefix)] = value + + projector_key_prefix = "vision_language_adapter." + projector_ckpt = {} + substring_replacement_map = { + "w_in.": "projector.0.", + "w_out.": "projector.2.", + } + for key, value in ckpt.items(): + if key.startswith(projector_key_prefix): + key = key.lstrip(projector_key_prefix) + for old, new in substring_replacement_map.items(): + key = key.replace(old, new) + projector_ckpt[key] = value + + llm_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix): + continue + llm_ckpt[key] = value + + vlm_ckpt = {} + for key, value in llm_ckpt.items(): + vlm_ckpt["model." + key] = value + for key, value in projector_ckpt.items(): + vlm_ckpt["mm_projector." + key] = value + for key, value in vit_ckpt.items(): + vlm_ckpt["vision_encoder." 
+ key] = value + + # Load config + config_path = os.path.join(pixtral_ckpt_dir, "params.json") + with open(config_path, "r") as f: + pixtral_config = json.load(f) + + # Extract the vision encoder configuration + vision_encoder_config = { + "dim": pixtral_config["vision_encoder"]["hidden_size"], + "num_channels": pixtral_config["vision_encoder"]["num_channels"], + "image_size": pixtral_config["vision_encoder"]["image_size"], + "patch_size": pixtral_config["vision_encoder"]["patch_size"], + "rope_theta": pixtral_config["vision_encoder"]["rope_theta"], + "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"], + "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"], + "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "norm_type": "rmsnorm", + "norm_eps": pixtral_config["norm_eps"], + "image_token_id": pixtral_config["vision_encoder"]["image_token_id"], + } + # Configuration for the 400M ViT of Pixtral 12B VLM + vit_config = dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + # Compare the two configurations + for key, value in vit_config.items(): + assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}" + + llm_config_keys = [ + "dim", + "n_layers", + "head_dim", + "hidden_dim", + "n_heads", + "n_kv_heads", + "rope_theta", + "norm_eps", + "vocab_size", + ] + assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch" + replace_map = { + "hidden_dim": "ffn_hidden_size", + } + llm_config = {} + for k, v in pixtral_config.items(): + if k in llm_config_keys: + llm_config[replace_map.get(k, k)] = v + elif k == "vision_encoder": + llm_config["vision_encoder"] = vit_type + else: + raise ValueError(f"Unknown key: {k}") + + ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt} + torch.save(ckpt_to_save, save_path) + print(f"Model saved to {save_path}") + + # Save config + config_path = os.path.join(save_dir, "config.json") + with open(config_path, "w") as f: + json.dump(llm_config, f) + + torch.set_default_dtype(orig_dtype) # Reset the default dtype + + # Remove the original Pixtral checkpoint + shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True) + print(f"Removed {pixtral_ckpt_dir}") + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Predict1-14B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Text2World/model.pt": "c69d1c6e51dc78b959040e8c4035a29b", + "Cosmos-Predict1-14B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Video2World/model.pt": "eaa7aa3678f61d88108c41d7fe201b18", + "Cosmos-Predict1-7B-WorldInterpolator/model.pt": "48a0bdc99d5e41eee05ba8597c4851da", + "Cosmos-Predict1-7B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Text2World/model.pt": "fe9ed68e16cf37b10e7414c9b3ee81e1", + "Cosmos-Predict1-7B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Video2World/model.pt": "ebcdb19c4c4a6a0e1e0bb65e346f6867", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": 
"f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CV8x8x8-720p/image_mean_std.pt": "9f19fd3312fc1198e4905ada02e68bce", + "Cosmos-UpsamplePrompt1-12B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-UpsamplePrompt1-12B-Text2World/model.pt": "52d7a6b8b1ac44d856b4c1ea3f8c8c74", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt": "e3a6ef070deaae0678acd529dc749ea4", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt": "1653f87dce3d558ee01416593552a91c", + "Gen3C-Cosmos-7B/model.pt": "38644bf823aa5272acef60cfad8bc0f7", + "google-t5/t5-11b/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3", + "google-t5/t5-11b/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + # Check if there are any expected files for this model + expected_files = [key for key in MD5_CHECKSUM_LOOKUP if key.startswith(model_name + "/")] + if not expected_files: + # No expected files in MD5_CHECKSUM_LOOKUP, check if the directory exists and has content + model_dir = checkpoints_dir / model_name + if not model_dir.exists() or not any(model_dir.iterdir()): + print(f"Directory for {model_name} does not exist or is empty. Download required.") + return False + else: + print(f"Directory for {model_name} exists and contains files. Assuming download is complete.") + return True + # Proceed with checksum verification for models with expected files + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name + "/"): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match given MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args): + ORG_NAME = "nvidia" + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-Tokenize1-CV8x8x8-720p", + "google-t5/t5-11b", + ] + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict( + allow_patterns=[ + "README.md", + "model.pt", + "mean_std.pt", + "image_mean_std.pt", + "config.json", + "*.jit", + "guardrail/*", + ] + ) + + # Download the requested diffusion models + model_name = "Gen3C-Cosmos-7B" + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + # Download the always-included models + for model_name in extra_models: + if model_name == "google-t5/t5-11b": + repo_id = model_name + else: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + # Download all files for Guardrail + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + 
local_dir_use_symlinks=False, + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_guardrail_checkpoints.py b/scripts/download_guardrail_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..81c4dafc13033b5ca8ce71b98ac28f51da510d0a --- /dev/null +++ b/scripts/download_guardrail_checkpoints.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List + +from huggingface_hub import snapshot_download + + +def download_models(models: List[str], destination_root: str): + """ + Download models from Hugging Face Hub and save them in org/project structure. + + Args: + models: List of model IDs in format 'org/project' + destination_root: Root directory where models will be saved + """ + for model_id in models: + model_id, revision = model_id.split(":") if ":" in model_id else (model_id, None) + print(f"Downloading {model_id}...") + + # Create the full path for the model + model_path = os.path.join(destination_root, model_id) + + try: + # Download the model + snapshot_download( + repo_id=model_id, + local_dir=model_path, + revision=revision, + ) + print(f"Successfully downloaded {model_id} to {model_path}") + + except Exception as e: + raise RuntimeError(f"Error downloading {model_id}: {str(e)}. Please delete the directory and try again.") + + +def download_guardrail_checkpoints(destination_root: str): + """ + Download guardrail checkpoints from Hugging Face Hub and save them in org/project structure. + + Args: + destination_root: Root directory where checkpoints will be saved + """ + # List of models to download + models_to_download = [ + "meta-llama/Llama-Guard-3-8B", + "nvidia/Cosmos-Guardrail1", + ] + + # Create the destination directory if it doesn't exist + os.makedirs(destination_root, exist_ok=True) + + # Download the models + download_models(models_to_download, destination_root) + + +if __name__ == "__main__": + download_guardrail_checkpoints("checkpoints") diff --git a/scripts/download_tokenizer_checkpoints.py b/scripts/download_tokenizer_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..249ee29cdad2ca648b691b74ff2ddb8b18c926b3 --- /dev/null +++ b/scripts/download_tokenizer_checkpoints.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import os +from pathlib import Path + +from huggingface_hub import snapshot_download + +from scripts.download_guardrail_checkpoints import download_guardrail_checkpoints + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="A script to download NVIDIA Cosmos-Tokenizer1 models from Hugging Face" + ) + parser.add_argument( + "--tokenizer_types", + nargs="*", + default=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CI8x8-360p", + "CI16x16-360p", + "CV4x8x8-360p", + "DI8x8-360p", + "DI16x16-360p", + "DV4x8x8-360p", + ], # Download all by default + choices=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CI8x8-360p", + "CI16x16-360p", + "CV4x8x8-360p", + "DI8x8-360p", + "DI16x16-360p", + "DV4x8x8-360p", + ], + help="Which tokenizer model types to download. Possible values: CV8x8x8-720p, DV8x16x16-720p, CV4x8x8-360p, DV4x8x8-360p", + ) + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints." + ) + args = parser.parse_args() + return args + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d", + "Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8", + "Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": "f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CI16x16-360p/autoencoder.jit": "98f8fdf2ada5537705d6d1bc22c63cf1", + "Cosmos-Tokenize1-CI16x16-360p/decoder.jit": "dd31a73a8c7062bab25492401d83b473", + "Cosmos-Tokenize1-CI16x16-360p/encoder.jit": "7be1dadea5a1c283996ca1ce5b1a95a9", + "Cosmos-Tokenize1-CI8x8-360p/autoencoder.jit": "b2ff9280b12a97202641bb2a41d7b271", + "Cosmos-Tokenize1-CI8x8-360p/decoder.jit": "57fb213cd88c0a991e9d400875164571", + "Cosmos-Tokenize1-CI8x8-360p/encoder.jit": "138fe257df41d7a43c17396c23086565", + "Cosmos-Tokenize1-CV4x8x8-360p/autoencoder.jit": "0690ff725700128424d082b44a1eda08", + "Cosmos-Tokenize1-CV4x8x8-360p/decoder.jit": "7573744ec14cb1b2abdf9c80318b7224", + "Cosmos-Tokenize1-CV4x8x8-360p/encoder.jit": "fe3a7193defcb2db0b849b6df480b5e6", + "Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d", + "Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8", + "Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299", + "Cosmos-Tokenize1-DI16x16-360p/autoencoder.jit": "88195130b86c3434d3d4b0e0376def6b", + "Cosmos-Tokenize1-DI16x16-360p/decoder.jit": "bf27a567388902acbd8abcc3a5afd8dd", + "Cosmos-Tokenize1-DI16x16-360p/encoder.jit": "12bae3a56c79a7ca0beb774843ee8c58", + "Cosmos-Tokenize1-DI8x8-360p/autoencoder.jit": "1d638e6034fcd43619bc1cdb343ebe56", + "Cosmos-Tokenize1-DI8x8-360p/decoder.jit": "b9b5eccaa7ab9ffbccae3b05b3903311", + "Cosmos-Tokenize1-DI8x8-360p/encoder.jit": "2bfa3c189aacdf9dc8faf17bcc30dd82", + "Cosmos-Tokenize1-DV4x8x8-360p/autoencoder.jit": "ff8802dc4497be60dc24a8f692833eed", + "Cosmos-Tokenize1-DV4x8x8-360p/decoder.jit": "f9a7d4bd24e4d2ee210cfd5f21550ce8", + "Cosmos-Tokenize1-DV4x8x8-360p/encoder.jit": "7af30a0223b2984d9d27dd3054fcd7af", + "Cosmos-Tokenize1-DV8x16x16-720p/autoencoder.jit": "606b8585b637f06057725cbb67036ae6", + "Cosmos-Tokenize1-DV8x16x16-720p/decoder.jit": "f0c8a9d992614a43e7ce24ebfc901e26", + "Cosmos-Tokenize1-DV8x16x16-720p/encoder.jit": 
"95186b0410346a3f0cf250b76daec452", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match give MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args) -> None: + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "CV8x8x8-720p": "Cosmos-Tokenize1-CV8x8x8-720p", + "DV8x16x16-720p": "Cosmos-Tokenize1-DV8x16x16-720p", + "CI8x8-360p": "Cosmos-Tokenize1-CI8x8-360p", + "CI16x16-360p": "Cosmos-Tokenize1-CI16x16-360p", + "CV4x8x8-360p": "Cosmos-Tokenize1-CV4x8x8-360p", + "DI8x8-360p": "Cosmos-Tokenize1-DI8x8-360p", + "DI16x16-360p": "Cosmos-Tokenize1-DI16x16-360p", + "DV4x8x8-360p": "Cosmos-Tokenize1-DV4x8x8-360p", + } + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict(allow_patterns=["README.md", "model.pt", "mean_std.pt", "config.json", "*.jit"]) + + # Download the requested Tokenizer models + for tokenizer_type in args.tokenizer_types: + model_name = model_map[tokenizer_type] + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + download_guardrail_checkpoints(args.checkpoint_dir) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_tokenizer_example_data.py b/scripts/download_tokenizer_example_data.py new file mode 100644 index 0000000000000000000000000000000000000000..28b7f2257f3f272cc29c6163ddcc8cd7e46ef2d0 --- /dev/null +++ b/scripts/download_tokenizer_example_data.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import os + +import ffmpeg +from pytubefix import YouTube + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_tokenizer_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") + parser.add_argument("--do_download", action="store_true", help="Download the videos") + parser.add_argument("--do_clip", action="store_true", help="Clip the videos") + return parser.parse_args() + + +def convert_time_to_seconds(time_str) -> int: + h, m, s = map(float, time_str.split(":")) + ms = int(time_str.split(".")[-1]) if "." in time_str else 0 + return int(h * 3600 + m * 60 + s) + ms / 1000 + + +def download_data(args) -> None: + urls_set = set() + download_count = 0 + + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + os.makedirs(videos_orig_dir, exist_ok=True) + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + + hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") + with open(hdvila_jsonl_path, "r") as fp: + for line in fp: + json_object = json.loads(line) + url = json_object["url"] + if url not in urls_set: # download videos with unique urls + yt = YouTube(json_object["url"]) + try: + # Download a video + yt.streams.get_highest_resolution().download( + output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" + ) + download_count += 1 + urls_set.add(url) + print(f"Downloaded videos: {download_count}/{args.N_videos}") + + # Save metadata - caption and whole metadata + meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) + with open(meta_txt_name, "w") as fp: + fp.write(json_object["caption"]) + meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) + with open(meta_json_name, "w") as fp: + json.dump(json_object, fp) + except Exception as e: + print(e) + continue + + if len(urls_set) >= args.N_videos: + break + + +def clip_data(args) -> None: + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") + ] + videos_orig_list = [ + os.path.join(videos_orig_dir, filename) + for filename in sorted(os.listdir(videos_orig_dir)) + if filename.endswith(".mp4") + ] + + for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): + with open(meta_filename, "r") as fp: + metadata = json.load(fp) + + # Convert time strings to seconds + start_time = convert_time_to_seconds(metadata["span_start"]) + end_time = convert_time_to_seconds(metadata["span_end"]) + # Clip the video + clip_name = os.path.join(videos_dir, metadata["clip_id"]) + ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() + + +def main(args) -> None: + if args.do_download: + download_data(args) + if 
args.do_clip: + clip_data(args) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100644 index 0000000000000000000000000000000000000000..42c3e2bd41407ab284b14bf2cb0bc76d67785374 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cosmos_root=$(git rev-parse --show-toplevel) +venv_folder=$cosmos_root/.venv +scripts_folder=$cosmos_root/scripts + +echo "Formatting $cosmos_root" +if [ ! -d "$scripts_folder" ]; then + echo "script has to be called from repo root dir!" + exit -1 +fi + +if [ ! -d "$venv_folder" ]; then + mkdir -p $venv_folder + python3 -m pip install virtualenv + python3 -m venv $venv_folder +fi + +source $venv_folder/bin/activate + +dependencies=($(pip freeze | grep -E 'pre-commit==3.7.1|flake8==7.1.0|black==24.4.2|isort==5.13.2|loguru|termcolor')) +if [ "${#dependencies[@]}" -ne 6 ]; then + python3 -m pip install --upgrade pip + python3 -m pip install pre-commit==3.7.1 + python3 -m pip install flake8==7.1.0 + python3 -m pip install black==24.4.2 + python3 -m pip install isort==5.13.2 + python3 -m pip install loguru + python3 -m pip install termcolor +fi +set -e +python3 $scripts_folder/ip_header.py +pre-commit install-hooks +pre-commit run --all diff --git a/scripts/get_t5_embeddings.py b/scripts/get_t5_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..53b6ebe889340e9360a946fa3022766dcf027c36 --- /dev/null +++ b/scripts/get_t5_embeddings.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
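The clip_data step above hands ffmpeg-python a seek offset (ss) and a duration (t) computed from the metadata spans, so a quick consistency check is to compare each produced clip's duration against its recorded span. A minimal sketch of such a check, assuming ffmpeg-python's probe helper, the default datasets/hdvila layout used by these scripts, and HH:MM:SS[.fff] span strings; verify_clip_durations and to_seconds are illustrative names, not part of the scripts above:

import json
import os

import ffmpeg


def to_seconds(ts: str) -> float:
    # Simplified conversion for "HH:MM:SS[.fff]" timestamps.
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)


def verify_clip_durations(dataset_path: str = "datasets/hdvila", tolerance: float = 0.5) -> None:
    videos_dir = os.path.join(dataset_path, "videos")
    metas_dir = os.path.join(dataset_path, "metas")
    for filename in sorted(os.listdir(metas_dir)):
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(metas_dir, filename), "r") as fp:
            metadata = json.load(fp)
        clip_path = os.path.join(videos_dir, metadata["clip_id"])
        if not os.path.exists(clip_path):
            print(f"Missing clip: {clip_path}")
            continue
        expected = to_seconds(metadata["span_end"]) - to_seconds(metadata["span_start"])
        actual = float(ffmpeg.probe(clip_path)["format"]["duration"])  # ffprobe-reported duration in seconds
        if abs(actual - expected) > tolerance:
            print(f"{clip_path}: expected ~{expected:.2f}s, got {actual:.2f}s")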
+ +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. 
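+ # The remaining steps: run the frozen encoder on the padded batch, use each prompt's
+ # true token count from the attention mask to zero out hidden states at padding
+ # positions, cast to float16 on the CPU, and finally drop the padded rows so each
+ # saved embedding is only (num_tokens, hidden_dim).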
+ input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + metas_dir = os.path.join(args.dataset_path, "metas") + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".txt") + ] + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_bridge.py b/scripts/get_t5_embeddings_from_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcdfe461914b86e61dbbac9c060f41938c7ed60 --- /dev/null +++ b/scripts/get_t5_embeddings_from_bridge.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
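Each output of the script above is a pickled list holding one float16 array per prompt, shaped (num_tokens, hidden_dim) after the zero padding is trimmed. A minimal sketch for inspecting a few of the saved embeddings, assuming the default datasets/hdvila layout (the [:5] slice is only to keep the printout short):

import os
import pickle

t5_xxl_dir = "datasets/hdvila/t5_xxl"  # default output directory of the script above
for filename in sorted(os.listdir(t5_xxl_dir))[:5]:
    with open(os.path.join(t5_xxl_dir, filename), "rb") as fp:
        encoded_text = pickle.load(fp)  # list with one array per prompt; these scripts store a single prompt
    emb = encoded_text[0]
    print(f"{filename}: shape={emb.shape}, dtype={emb.dtype}")  # e.g. (num_tokens, 1024) float16 for t5-11b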
+ +import argparse +import json +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_bridge.py --dataset_path datasets/bridge +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/bridge", help="Root path to the dataset") + parser.add_argument( + "--subset", + type=str, + default="train", + choices=("train", "val", "test"), + help="Subset of the bridge dataset to process", + ) + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. 
+ input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + annotation_dir = os.path.join(args.dataset_path, "annotation", args.subset) + annotation_list = [ + os.path.join(annotation_dir, filename) + for filename in sorted(os.listdir(annotation_dir)) + if filename.endswith(".json") + ] + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for annotation_filename in annotation_list: + # Save T5 embeddings as pickle file + t5_xxl_filename = os.path.join( + annotation_dir, os.path.basename(annotation_filename).replace(".json", ".pickle") + ) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(annotation_filename, "r") as fp: + metadata = json.load(fp) + prompt = metadata["texts"][0] + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e8783c5058c8daba8f3ee6a1d93c321f532f25 --- /dev/null +++ b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
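The embedding scripts above call encode_for_batch with one prompt at a time, which is simple but leaves the GPU mostly idle; the function itself accepts a list of prompts and returns one trimmed array per prompt, so many short captions can be pushed through in batches. A minimal sketch, assuming the init_t5 and encode_for_batch helpers defined above are in scope; the prompts and batch size are illustrative:

prompts = [
    "A robot arm places a cup on a table.",
    "A car drives down a rainy street at night.",
    "A teal robot waves at the camera.",
]

tokenizer, text_encoder = init_t5(cache_dir="checkpoints")

batch_size = 16  # illustrative; tune to the available GPU memory
embeddings = []
for i in range(0, len(prompts), batch_size):
    embeddings.extend(encode_for_batch(tokenizer, text_encoder, prompts[i : i + batch_size]))

assert len(embeddings) == len(prompts)  # one (num_tokens, hidden_dim) float16 array per prompt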
+ +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_cosmos_nemo_assets.py --dataset_path datasets/cosmos_nemo_assets +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument( + "--dataset_path", type=str, default="datasets/cosmos_nemo_assets", help="Root path to the dataset" + ) + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + videos_dir = os.path.join(args.dataset_path, "videos") + + # Cosmos-NeMo-Assets come with videos only. A prompt is provided as an argument. + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + metas_list = [ + os.path.join(metas_dir, filename.replace(".mp4", ".txt")) + for filename in sorted(os.listdir(videos_dir)) + if filename.endswith(".mp4") + ] + + # Write txt files to match other dataset formats. 
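+ # Cosmos-NeMo-Assets ships videos without captions, so every video receives the same
+ # --prompt text; these .txt files exist only so the T5 step below sees the same
+ # metas/ + t5_xxl/ layout as the other dataset scripts.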
+ for meta_filename in metas_list: + if not os.path.exists(meta_filename): + with open(meta_filename, "w") as fp: + fp.write(args.prompt) + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_waymo.py b/scripts/get_t5_embeddings_from_waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..a33d80ca1d01b734daff570c12cafa72f610ad4d --- /dev/null +++ b/scripts/get_t5_embeddings_from_waymo.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_waymo.py --dataset_path datasets/waymo +""" + +PREFIX_PROMPTS = { + "pinhole_front": "The video is captured from a camera mounted on a car. The camera is facing forward.", + "pinhole_front_left": "The video is captured from a camera mounted on a car. The camera is facing forward and slightly to the left.", + "pinhole_front_right": "The video is captured from a camera mounted on a car. The camera is facing forward and slightly to the right.", + "pinhole_side_left": "The video is captured from a camera mounted on a car. The camera is facing to the left.", + "pinhole_side_right": "The video is captured from a camera mounted on a car. 
The camera is facing to the right.", +} + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/waymo", help="Root path to the dataset") + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + videos_dir = os.path.join(args.dataset_path, "videos") + + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + metas_list = [ + os.path.join(metas_dir, viewname, filename.replace(".mp4", ".txt")) + for viewname in sorted(os.listdir(videos_dir)) + for filename in sorted(os.listdir(videos_dir + "/" + viewname)) + if filename.endswith(".mp4") + ] + + # Write txt files to match other dataset formats. 
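+ # Every clip gets the same generic --prompt text here; the per-camera view context
+ # comes from PREFIX_PROMPTS, which is embedded separately below and cached under
+ # <dataset_path>/cache/ as prefix_t5_embeddings_<view>.pickle.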
+ for meta_filename in metas_list: + if not os.path.exists(meta_filename): + with open(meta_filename, "w") as fp: + fp.write(args.prompt) + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + # Extract T5 embeddings for prefix prompt + for view_name, prefix_prompt in PREFIX_PROMPTS.items(): + t5_xxl_filename = os.path.join(args.dataset_path, "cache", f"prefix_t5_embeddings_{view_name}.pickle") + os.makedirs(os.path.dirname(t5_xxl_filename), exist_ok=True) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prefix_prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join( + t5_xxl_dir, meta_filename.split("/")[-2], os.path.basename(meta_filename).replace(".txt", ".pickle") + ) + os.makedirs(os.path.dirname(t5_xxl_filename), exist_ok=True) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/ip_header.py b/scripts/ip_header.py new file mode 100644 index 0000000000000000000000000000000000000000..f139c36ed77da543f9006fe2adecd080686f118c --- /dev/null +++ b/scripts/ip_header.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys + +import termcolor + +parser = argparse.ArgumentParser(description="Cosmos IP header checker/fixer") +parser.add_argument("--fix", action="store_true", help="apply the fixes instead of checking") +args, files_to_check = parser.parse_known_args() + + +def get_header(ext: str = "py", old: str | bool = False) -> list[str]: + header = [ + "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.", + "SPDX-License-Identifier: Apache-2.0", + "", + 'Licensed under the Apache License, Version 2.0 (the "License");', + "you may not use this file except in compliance with the License.", + "You may obtain a copy of the License at", + "", + "http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing, software", + 'distributed under the License is distributed on an "AS IS" BASIS,', + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "See the License for the specific language governing permissions and", + "limitations under the License.", + ] + if ext == ".py" and old: + if old == "single": + header = ["'''"] + header + ["'''"] + elif old == "double": + header = ['"""'] + header + ['"""'] + else: + raise NotImplementedError + elif ext in (".py", ".yaml"): + header = [("# " + line if line else "#") for line in header] + elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): + header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] + else: + raise NotImplementedError + return header + + +def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None: + if file.endswith("__init__.py"): + return + ext = os.path.splitext(file)[1] + content = open(file).read().splitlines() + header = get_header(ext=ext) + if fix: + if _check_header(content, header): + return + print(f"fixing: {file}") + while len(content) > 0 and not content[0]: + content.pop(0) + content = header + [""] + content + with open(file, "w") as file_obj: + for line in content: + file_obj.write(line + "\n") + else: + if not _check_header(content, header): + bad_header = colorize("BAD HEADER", color="red", bold=True) + print(f"{bad_header}: {file}") + results[file] = 1 + else: + results[file] = 0 + + +def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: + files = os.listdir(path) + for file in files: + full_path = os.path.join(path, file) + if os.path.isdir(full_path): + traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) + elif os.path.isfile(full_path): + ext = os.path.splitext(file)[1] + to_skip = any(substr in full_path for substr in substrings_to_skip) + if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): + apply_file(full_path, results, fix=fix) + else: + raise NotImplementedError + + +def _check_header(content: list[str], header: list[str]) -> bool: + if content[: len(header)] != header: + return False + + i = len(header) + blank_line_count = 0 + + while i < len(content) and content[i].strip() == "": + blank_line_count += 1 + i += 1 + + # Allow at most two blank lines + if blank_line_count > 2: + return False + + # Must have at least one non-empty line after the blank lines + return i < len(content) + + +def colorize(x: str, color: str, bold: bool = False) -> str: + return termcolor.colored(str(x), color=color, attrs=("bold",) if bold else None) # type: ignore + + +if __name__ == "__main__": + if not files_to_check: + files_to_check = [ + "cosmos_predict1/auxiliary", + "cosmos_predict1/diffusion", + "cosmos_predict1/callbacks", + "cosmos_predict1/checkpointer", + "cosmos_predict1/autoregressive", + "cosmos_predict1/tokenizer", + "cosmos_predict1/utils", + ] + + for file in files_to_check: + assert os.path.isfile(file) or os.path.isdir(file), f"{file} is neither a directory or a file!" 
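+    # Files whose paths contain any of the substrings below are exempt from the header check.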
+ + substrings_to_skip = ["prompt_upsampler"] + results = dict() + for file in files_to_check: + if os.path.isfile(file): + apply_file(file, results, fix=args.fix) + elif os.path.isdir(file): + traverse_directory(file, results, fix=args.fix, substrings_to_skip=substrings_to_skip) + else: + raise NotImplementedError + + if any(results.values()): + sys.exit(1) diff --git a/scripts/merge_autoregressive_tp_checkpoints.py b/scripts/merge_autoregressive_tp_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..07cd6b4866c5160dc2a291cb031e4ec99e4abeb5 --- /dev/null +++ b/scripts/merge_autoregressive_tp_checkpoints.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.utils.checkpoint import merge_tensor_parallel_state_dicts +from cosmos_predict1.utils import log + + +def merge_sharded_checkpoints(checkpoint_path, output_path, tensor_parallel_size, model_size, model_family): + assert checkpoint_path.endswith(".pt"), "Checkpoint path must end with .pt" + assert model_family == "cosmos", "Only cosmos model family is currently supported" + assert model_size == "4b", "Only 4B model size is currently supported" + model_config, _ = create_video2world_model_config( + model_ckpt_path=checkpoint_path, + model_family=model_family, + model_size=model_size, + tensor_model_parallel_size=tensor_parallel_size, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + ) + log.info(f"Merging sharded checkpoints from {checkpoint_path.replace('.pt', '_model_mp_*.pt')} into {output_path}") + + checkpoint_paths = [checkpoint_path.replace(".pt", f"_model_mp_{rank}.pt") for rank in range(tensor_parallel_size)] + for path in checkpoint_paths: + assert os.path.exists(path), f"Checkpoint path {path} does not exist" + log.info(f"Found checkpoint {path}") + sharded_state_dicts = [torch.load(path, map_location="cpu") for path in checkpoint_paths] + merged_state_dict = merge_tensor_parallel_state_dicts(sharded_state_dicts, model_config) + torch.save(merged_state_dict, output_path) + log.info(f"Merged checkpoint saved to {output_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Merge Cosmos-Predict1-4B autoregressive checkpoints") + parser.add_argument( + "--checkpoint_path", + "-c", + type=str, + required=True, + help="Path to the checkpoint to merge. 
Must end with .pt and be colocated with the sharded checkpoints ending in _model_mp_{rank}.pt", + ) + parser.add_argument("--output_path", "-o", type=str, required=True, help="Path to the output merged checkpoint") + parser.add_argument("--tensor_parallel_size", "-t", type=int, required=True, help="Tensor parallel size") + parser.add_argument("--model_size", "-s", type=str, required=True, help="Model size") + parser.add_argument("--model_family", "-f", type=str, required=False, default="cosmos", help="Model family") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + merge_sharded_checkpoints( + args.checkpoint_path, args.output_path, args.tensor_parallel_size, args.model_size, args.model_family + ) diff --git a/scripts/shard_autoregressive_base_checkpoints.py b/scripts/shard_autoregressive_base_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..4e77144cd013c63cfc1993c0fda8963c6857e092 --- /dev/null +++ b/scripts/shard_autoregressive_base_checkpoints.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.utils.checkpoint import obtain_tensor_parallel_state_dict +from cosmos_predict1.utils import log + + +def shard_checkpoint(checkpoint_path, tensor_parallel_size, model_size, model_family, target_backend="pytorch"): + assert checkpoint_path.endswith(".pt"), "Checkpoint path must end with .pt" + assert model_family == "cosmos", "Only cosmos model family is currently supported" + assert model_size == "4b", "Only 4B model size is currently supported" + model_config, _ = create_video2world_model_config( + model_ckpt_path=checkpoint_path, + model_family=model_family, + model_size=model_size, + tensor_model_parallel_size=tensor_parallel_size, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + ) + log.info(f"Sharding checkpoint {checkpoint_path} with {tensor_parallel_size} ranks") + checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True) + for tensor_parallel_rank in range(tensor_parallel_size): + shard = obtain_tensor_parallel_state_dict( + checkpoint, tensor_parallel_size, tensor_parallel_rank, model_config, target_backend=target_backend + ) + shard_path = checkpoint_path.replace(".pt", f"_model_mp_{tensor_parallel_rank}.pt") + log.info(f"Saving shard {shard_path}") + torch.save(shard, shard_path) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Shard NVIDIA Cosmos Predict1 autoregressive models") + parser.add_argument( + "--checkpoint_path", + "-c", + type=str, + required=True, + default="checkpoints/Cosmos-Predict1-4B/model.pt", + help="Path to the checkpoint to shard", + ) + parser.add_argument("--tensor_parallel_size", "-t", type=int, 
required=True, help="Number of tensor parallel ranks") + parser.add_argument("--target_backend", "-b", type=str, required=False, default="pytorch", help="Target backend") + parser.add_argument("--model_size", "-s", type=str, required=True, help="Model size") + parser.add_argument("--model_family", "-f", type=str, required=False, default="cosmos", help="Model family") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + shard_checkpoint( + args.checkpoint_path, args.tensor_parallel_size, args.model_size, args.model_family, args.target_backend + ) diff --git a/scripts/test_environment.py b/scripts/test_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..606447884e3701b8992a6444f9c42b7aef9c400e --- /dev/null +++ b/scripts/test_environment.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os +import sys + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--training", + action="store_true", + help="Whether to check training-specific dependencies", + ) + return parser.parse_args() + + +def check_packages(package_list): + global all_success + for package in package_list: + try: + _ = importlib.import_module(package) + except Exception as e: + print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m") + all_success = False + else: + print(f"\033[92m[SUCCESS]\033[0m {package} found") + + +args = parse_args() + +if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): + detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m") + sys.exit(1) + +if "CONDA_PREFIX" not in os.environ: + print("\033[93m[WARNING]\033[0m Cosmos should be run under a conda environment.") + +print("Attempting to import critical packages...") + +packages = [ + "torch", + "torchvision", + "diffusers", + "transformers", + "megatron.core", + "transformer_engine", +] +packages_training = [ + "apex.multi_tensor_apply", +] +all_success = True + +check_packages(packages) +if args.training: + check_packages(packages_training) + +if all_success: + print("-----------------------------------------------------------") + print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!")
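As written, scripts/test_environment.py prints a success banner only when every import succeeds; it does not signal failure through its exit code. If a non-zero exit status is wanted (for example when running the check in CI), a minimal, hypothetical addition at the end of the script could look like the sketch below; it only uses `all_success` and `sys`, which the script already defines and imports:

    if not all_success:
        # At least one package failed to import; make the failure visible to callers.
        sys.exit(1)

This would mirror the sys.exit(1) the script already uses for the Python-version check, and the exit-code convention used by scripts/ip_header.py.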