diff --git a/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml b/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89ccb8b9fb8b804d782d41101fc00a75d32c505f --- /dev/null +++ b/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,51 @@ +name: "\U0001F41B Bug Report" +description: Submit a bug report to help us improve CogVideoX-Factory / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX-Factory 开源框架 +body: + - type: textarea + id: system-info + attributes: + label: System Info / 系統信息 + description: Your operating environment / 您的运行环境信息 + placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... + validations: + required: true + + - type: checkboxes + id: information-scripts-examples + attributes: + label: Information / 问题信息 + description: 'The problem arises when using: / 问题出现在' + options: + - label: "The official example scripts / 官方的示例脚本" + - label: "My own modified scripts / 我自己修改的脚本和任务" + + - type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction / 复现过程 + description: | + Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. + If you have code snippets, error messages, stack traces, please provide them here as well. + Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. + + 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 + 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 + 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 + placeholder: | + Steps to reproduce the behavior/复现Bug的步骤: + + 1. + 2. + 3. + + - type: textarea + id: expected-behavior + validations: + required: true + attributes: + label: Expected behavior / 期待表现 + description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml b/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac2f0fcf6162c1d0bfb6c79234df03e744dcb311 --- /dev/null +++ b/docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml @@ -0,0 +1,34 @@ +name: "\U0001F680 Feature request" +description: Submit a request for a new CogVideoX-Factory feature / 提交一个新的 CogVideoX-Factory 开源项目的功能建议 +labels: [ "feature" ] +body: + - type: textarea + id: feature-request + validations: + required: true + attributes: + label: Feature request / 功能建议 + description: | + A brief description of the functional proposal. Links to corresponding papers and code are desirable. + 对功能建议的简述。最好提供对应的论文和代码链接。 + + - type: textarea + id: motivation + validations: + required: true + attributes: + label: Motivation / 动机 + description: | + Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. + 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 + + - type: textarea + id: contribution + validations: + required: true + attributes: + label: Your contribution / 您的贡献 + description: | + + Your PR link or any other link you can help with. + 您的PR链接或者其他您能提供帮助的链接。 \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml b/docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..121ddba3c715c045f3dcdf7ee5915e91387823ca --- /dev/null +++ b/docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml @@ -0,0 +1,30 @@ +name: Fast tests for PRs + +on: + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + check_code_quality: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff==0.9.10 + - name: Check quality + run: make quality + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Quality check failed. Please install ruff: `pip install ruff` and then run `make style && make quality` from the root of the repository." >> $GITHUB_STEP_SUMMARY diff --git a/docs/finetrainers-src-codebase/.gitignore b/docs/finetrainers-src-codebase/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ea71da238f39bad1776569930a10df4f32be13e2 --- /dev/null +++ b/docs/finetrainers-src-codebase/.gitignore @@ -0,0 +1,179 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# JetBrains +.idea + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# manually added +wandb/ +*.txt +dump* +outputs* +*.slurm +.vscode/ +*dummy* +*curated* +validation_dataset/ +wan-framepack/ + +!requirements.txt diff --git a/docs/finetrainers-src-codebase/CONTRIBUTING.md b/docs/finetrainers-src-codebase/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c7636cb31b891955931ee257cad14e87d42be95b --- /dev/null +++ b/docs/finetrainers-src-codebase/CONTRIBUTING.md @@ -0,0 +1,41 @@ +# How to contribute to Finetrainers + +Finetrainers is an early-stage library for training diffusion models. Everyone is welcome to contribute - models, algorithms, refactors, docs, etc. - but due to the early stage of the project, we recommend bigger contributions be discussed in an issue before submitting a PR. Eventually, we will have a better process for this! + +## How to contribute + +### Adding a new model + +If you would like to add a new model, please follow these steps: + +- Create a new file in the `finetrainers/models` directory with the model name (if it's new), or use the same directory if it's a variant of an existing model. +- Implement the model specification in the file. For more details on what a model specification should look like, see the [ModelSpecification](TODO(aryan): add link) documentation. +- Update the supported configs in `finetrainers/config.py` to include the new model and the training types supported. +- Add a dummy model specification in the `tests/models` directory. +- Make sure to test training with the following settings: + - Single GPU + - 2x GPU with `--dp_degree 2 --dp_shards 1` + - 2x GPU with `--dp_degree 1 --dp_shards 2` + + For `SFTTrainer` additions, please make sure to train with atleast 1000 steps (atleast 2000 data points) to ensure the model training is working as expected. +- Open a PR with your changes. Please make sure to share your wandb logs for the above training settings in the PR description. This will help us verify the training is working as expected. + +### Adding a new algorithm + +Currently, we are not accepting algorithm contributions. We will update this section once we are better ready 🤗 + +### Refactors + +The library is in a very early stage. There are many instances of dead code, poorly written abstractions, and other issues. If you would like to refactor/clean-up a part of the codebase, please open an issue to discuss the changes before submitting a PR. + +### Dataset improvements + +Any changes to dataset/dataloader implementations can be submitted directly. The improvements and reasons for the changes should be conveyed appropriately for us to move quickly 🤗 + +### Documentation + +Due to the early stage of the project, the documentation is not as comprehensive as we would like. Any improvements/refactors are welcome directly! + +## Asking for help + +If you have any questions, feel free to open an issue and we will be sure to help you out asap! Please make sure to describe your issues in either English (preferable) or Chinese. Any other language will make it hard for us to help you, so we will most likely close such issues without explanation/answer. diff --git a/docs/finetrainers-src-codebase/LICENSE b/docs/finetrainers-src-codebase/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/docs/finetrainers-src-codebase/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/docs/finetrainers-src-codebase/Makefile b/docs/finetrainers-src-codebase/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..04d98b6a5a47771f2ce26080c82befb0d335f732 --- /dev/null +++ b/docs/finetrainers-src-codebase/Makefile @@ -0,0 +1,11 @@ +.PHONY: quality style + +check_dirs := finetrainers tests examples train.py setup.py + +quality: + ruff check $(check_dirs) --exclude examples/_legacy + ruff format --check $(check_dirs) --exclude examples/_legacy + +style: + ruff check $(check_dirs) --fix --exclude examples/_legacy + ruff format $(check_dirs) --exclude examples/_legacy diff --git a/docs/finetrainers/documentation_README.md b/docs/finetrainers-src-codebase/README.md similarity index 95% rename from docs/finetrainers/documentation_README.md rename to docs/finetrainers-src-codebase/README.md index 44dc094b4773dd90427a72abb395e9ad1de239e1..030fc58fcbc46be0796a5ee099614a971a026628 100644 --- a/docs/finetrainers/documentation_README.md +++ b/docs/finetrainers-src-codebase/README.md @@ -30,10 +30,10 @@ Checkout to the latest stable release tag: ```bash git fetch --all --tags -git checkout tags/v0.1.0 +git checkout tags/v0.2.0 ``` -Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.1.0) for the latest stable release. +Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.2.0-release) for the latest stable release. #### Using the main branch @@ -54,9 +54,10 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam ## Features -- DDP, FSDP-2 & HSDP support for all models +- DDP, FSDP-2 & HSDP, CP support - LoRA and full-rank finetuning; Conditional Control training - Memory-efficient single-GPU training +- Multiple attention backends supported - `flash`, `flex`, `sage`, `xformers` (see [attention](./docs/models/attention.md) docs) - Auto-detection of commonly used dataset formats - Combined image/video datasets, multiple chainable local/remote datasets, multi-resolution bucketing & more - Memory-efficient precomputation support with/without on-the-fly precomputation for large scale datasets @@ -65,6 +66,8 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam ## News +- 🔥 **2025-04-25**: Support for different attention providers added! +- 🔥 **2025-04-21**: Wan I2V supported added! - 🔥 **2025-04-12**: Channel-concatenated control conditioning support added for CogView4 and Wan! - 🔥 **2025-04-08**: `torch.compile` support added! - 🔥 **2025-04-06**: Flux support added! diff --git a/docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml b/docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a7660e0dc640b3cd8381bc793ed78b16c79d9c7 --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +dynamo_config: + dynamo_backend: INDUCTOR + dynamo_mode: max-autotune + dynamo_use_dynamic: true + dynamo_use_fullgraph: false +enable_cpu_affinity: false +gpu_ids: '3' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml b/docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62db0b4214e1faaac734f76086d6c6f7b6d3810b --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..348c1cae86a65ab605628fb39d8bc97269a11205 --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: '3' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..830b6e0daa8b12c74494c20f142da2b4a78d055e --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: 0,1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e15d4c6cd56145c4653e97b8cbdd823b154b6207 --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: 0,1,2,3 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee7f50c287c77246fd0d2042893378abd12a4943 --- /dev/null +++ b/docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/assets/contribute.md b/docs/finetrainers-src-codebase/assets/contribute.md new file mode 100644 index 0000000000000000000000000000000000000000..7dbd1c928fa051933e00a366c462276a23732953 --- /dev/null +++ b/docs/finetrainers-src-codebase/assets/contribute.md @@ -0,0 +1,16 @@ +# Contributions Welcome + +This project is in a very early stage, and we welcome contributions from everyone. We hope to receive contributions and support in the following areas: + +1. Support for more models. In addition to CogVideoX models, we also highly encourage contributions supporting other models. +2. Support for richer datasets. In our example, we used a Disney video generation dataset, but we hope to support more datasets as the current one is too limited for deeper fine-tuning exploration. +3. Anything in `TODO` we mention in our README.md + +## How to Submit + +We welcome you to create a new PR and describe the corresponding contribution. We will review it as soon as possible. + +## Naming Conventions + +- Please use English for naming, avoid using pinyin or other languages. All comments should be in English. +- Strictly follow PEP8 conventions, and use underscores to separate words. Please avoid using names like a, b, c. \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/assets/contribute_zh.md b/docs/finetrainers-src-codebase/assets/contribute_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..72308768d2526d635fd3e8c99a6e25b9f19cda9e --- /dev/null +++ b/docs/finetrainers-src-codebase/assets/contribute_zh.md @@ -0,0 +1,16 @@ +# 欢迎你们的贡献 + +本项目属于非常初级的阶段,欢迎大家进行贡献。我们希望在以下方面得到贡献和支持: + +1. 支持更多的模型,除了 CogVideoX 模型之外的模型,我们也非常支持。 +2. 更丰富的数据集支持。在我们的例子中,我们使用了一个 Disney 视频生成数据集,但是我们希望能够支持更多的数据集,这个数据集太少了,并不足以进行更深的微调探索。 +3. 任何我们在README中`TODO`提到的内容。 + +## 提交方式 + +我们欢迎您直接创建一个新的PR,并说明对应的贡献,我们将第一时间查看。 + +## 命名规范 + +- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。 +- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。 \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/assets/dataset_zh.md b/docs/finetrainers-src-codebase/assets/dataset_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..d43ab766b447788c9b9aba61328640ed705b59b9 --- /dev/null +++ b/docs/finetrainers-src-codebase/assets/dataset_zh.md @@ -0,0 +1,72 @@ +## 数据集格式 + +### 提示词数据集要求 + +创建 `prompt.txt` 文件,文件应包含逐行分隔的提示。请注意,提示必须是英文,并且建议使用 [提示润色脚本](https://github.com/THUDM/CogVideo/blob/main/inference/convert_demo.py) 进行润色。或者可以使用 [CogVideo-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) 进行数据标注: + +``` +A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction. +A black and white animated sequence on a ship’s deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language... +... +``` + +### 视频数据集要求 + +该框架支持的分辨率和帧数需要满足以下条件: + +- **支持的分辨率(宽 * 高)**: + - 任意分辨率且必须能被32整除。例如,`720 * 480`, `1920 * 1020` 等分辨率。 + +- **支持的帧数(Frames)**: + - 必须是 `4 * k` 或 `4 * k + 1`(例如:16, 32, 49, 81) + +所有的视频建议放在一个文件夹中。 + + +接着,创建 `videos.txt` 文件。 `videos.txt` 文件应包含逐行分隔的视频文件路径。请注意,路径必须相对于 `--data_root` 目录。格式如下: + +``` +videos/00000.mp4 +videos/00001.mp4 +... +``` + +对于有兴趣了解更多细节的开发者,您可以查看相关的 `BucketSampler` 代码。 + +### 数据集结构 + +您的数据集结构应如下所示,通过运行`tree`命令,你能看到: + +``` +dataset +├── prompt.txt +├── videos.txt +├── videos + ├── videos/00000.mp4 + ├── videos/00001.mp4 + ├── ... +``` + +### 使用数据集 + +当使用此格式时,`--caption_column` 应为 `prompt.txt`,`--video_column` 应为 `videos.txt`。如果您的数据存储在 CSV +文件中,也可以指定 `--dataset_file` 为 CSV 文件的路径,`--caption_column` 和 `--video_column` 为 CSV +文件中的实际列名。请参考 [test_dataset](../tests/test_dataset.py) 文件中的一些简单示例。 + +例如,使用 [这个](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) Disney 数据集进行微调。下载可通过🤗 +Hugging Face CLI 完成: + +``` +huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney +``` + +该数据集已按照预期格式准备好,可直接使用。但是,直接使用视频数据集可能会导致较小 VRAM 的 GPU 出现 +OOM(内存不足),因为它需要加载 [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae) +(将视频编码为潜在空间)和大型 [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/) + +文本编码器。为了降低内存需求,您可以使用 `training/prepare_dataset.py` 脚本预先计算潜在变量和嵌入。 + +填写或修改 `prepare_dataset.sh` 中的参数并执行它以获得预先计算的潜在变量和嵌入(请确保指定 `--save_latents_and_embeddings` +以保存预计算的工件)。如果准备图像到视频的训练,请确保传递 `--save_image_latents`,它对沙子进行编码,将图像潜在值与视频一起保存。 +在训练期间使用这些工件时,确保指定 `--load_tensors` 标志,否则将直接使用视频并需要加载文本编码器和 +VAE。该脚本还支持 PyTorch DDP,以便可以使用多个 GPU 并行编码大型数据集(修改 `NUM_GPUS` 参数)。 diff --git a/docs/finetrainers-src-codebase/assets/sft_2b.png b/docs/finetrainers-src-codebase/assets/sft_2b.png new file mode 100644 index 0000000000000000000000000000000000000000..9340ef1f34f0e829d13d1e9833e167517db24af9 Binary files /dev/null and b/docs/finetrainers-src-codebase/assets/sft_2b.png differ diff --git a/docs/finetrainers-src-codebase/assets/sft_5b.png b/docs/finetrainers-src-codebase/assets/sft_5b.png new file mode 100644 index 0000000000000000000000000000000000000000..04509f38fab3e432250b06646abad47711972740 Binary files /dev/null and b/docs/finetrainers-src-codebase/assets/sft_5b.png differ diff --git a/docs/finetrainers-src-codebase/assets/tests/metadata.csv b/docs/finetrainers-src-codebase/assets/tests/metadata.csv new file mode 100644 index 0000000000000000000000000000000000000000..ac6f2df7482276aa1481f71ee4b2f206de2cb9e2 --- /dev/null +++ b/docs/finetrainers-src-codebase/assets/tests/metadata.csv @@ -0,0 +1,2 @@ +video,caption +"videos/hiker.mp4","""A hiker standing at the top of a mountain, triumphantly, high quality""" \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md b/docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md new file mode 100644 index 0000000000000000000000000000000000000000..c010437a163072073fcebd79b354b0d162d98f34 --- /dev/null +++ b/docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md @@ -0,0 +1,20 @@ +# Notes for Future Me + +>![NOTE] +> This doc page is intended for developers and contributors. + +FSDP dump: +- https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes +- https://github.com/pytorch/pytorch/issues/114299 +- Using FSDP1 requires that all FSDP flat parameters are of the same dtype. For LoRA training, we default lora parameters to fp32 and transformer parameters to dtype chosen by user. There seems to be no easy workaround than performing lora training in same dtype. +- https://github.com/pytorch/pytorch/issues/100945 +- https://github.com/pytorch/torchtune/blob/9b3836028fd0b48f593ea43474b86880c49a4d74/recipes/lora_finetune_distributed.py +- https://github.com/KellerJordan/modded-nanogpt/pull/68 +- https://github.com/pytorch/pytorch/pull/125394: monkey-patch method for FSDP pre/post-hooks to be triggered for method other than `forward` +- https://github.com/pytorch/pytorch/pull/127786: +- https://github.com/pytorch/pytorch/pull/130949: +- Sanity saver: create optimizers after parallelizing/activation-checkpointing models + +DTensor: +- https://github.com/pytorch/pytorch/issues/88838 +- https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/parallel/test_parallelize_api.py diff --git a/docs/finetrainers/documentation_args.md b/docs/finetrainers-src-codebase/docs/args.md similarity index 87% rename from docs/finetrainers/documentation_args.md rename to docs/finetrainers-src-codebase/docs/args.md index c168db381512b95550e6548100126f34dbb33309..0b6a0d98fd43f82fc9b405b762cfc3c76a290d66 100644 --- a/docs/finetrainers/documentation_args.md +++ b/docs/finetrainers-src-codebase/docs/args.md @@ -75,7 +75,10 @@ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embe naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers by default, and recommend adding more layers to the default list based on the model architecture. compile_modules (`List[str]`, defaults to `[]`): - Modules that should be regionally compiled with `torch.compile`. Choose one or more from ['transformer']. + Modules that should be regionally compiled with `torch.compile`. +compile_scopes (`str`, defaults to `None`): + The scope of compilation for each `--compile_modules`. Choose between ['regional', 'full']. Must have the same length as + `--compile_modules`. If `None`, will default to `regional` for all modules. DATASET ARGUMENTS ----------------- @@ -109,6 +112,9 @@ dataset_config (`str`): dataset_shuffle_buffer_size (`int`, defaults to `1`): The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default value of `1` means that the dataset will not be shuffled. +enable_precomputation (`bool`, defaults to `False`): + Whether or not to precompute the embeddings for the dataset. This is useful for faster training. If set to `True`, + the embeddings will be precomputed and saved to disk and loaded as required. precomputation_items (`int`, defaults to `512`): Number of data samples to precompute at once for memory-efficient training. The higher this value, the more disk memory will be used to save the precomputed samples (conditions and latents). @@ -118,8 +124,16 @@ precomputation_dir (`str`, defaults to `None`): precomputation_once (`bool`, defaults to `False`): Precompute embeddings from all datasets at once before training. This is useful to save time during training with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during - training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size - of your dataset(s). + training when required (that is, computing embeddings of more data samples once `precomputation_items` of them + have been exhausted across all distributed ranks). Make sure to set `precomputation_items` to a reasonable value + in line with the size of your dataset(s). +precomputation_reuse (`bool`, defaults to `False`): + Reuse precomputed embeddings from previous training runs. This is useful to save time during training + with medium/large datasets. By default, old precomputed embeddings that exist in the specified precomputation + directory, or default precomputation dir `{output_dir}/precomputed` will be deleted if this is not set to `True`. + This flag is ignored if `enable_precomputation` is `False`. The topology of the distributed training run must be + the same as the one used to precompute the embeddings for this to work correctly (this limitation will be + addressed in the future). DATALOADER_ARGUMENTS -------------------- @@ -248,8 +262,6 @@ logging_dir (`str`, defaults to `logs`): The directory where the logs will be stored. logging_steps (`int`, defaults to `1`): Training logs will be tracked every `logging_steps` steps. -allow_tf32 (`bool`, defaults to `False`): - Whether or not to allow the use of TF32 matmul on compatible hardware. nccl_timeout (`int`, defaults to `1800`): Timeout for the NCCL communication. report_to (`str`, defaults to `wandb`): @@ -260,6 +272,33 @@ verbose (`int`, defaults to `1`): - 1: Diffusers/Transformers info logging on local main process only - 2: Diffusers/Transformers debug logging on local main process only - 3: Diffusers/Transformers debug logging on all processes + +TORCH CONFIG ARGUMENTS +---------------------- +allow_tf32 (`bool`, defaults to `False`): + Whether or not to allow the use of TF32 matmul on compatible hardware. +float32_matmul_precision (`str`, defaults to `highest`): + The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium']. +``` + +### Attention Provider + +These arguments are relevant to setting attention provider for different modeling components. The attention providers may be set differently for training and validation/inference. + +``` +attn_provider_training (`str`, defaults to "native"): + The attention provider to use for training. Choose between + [ + 'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash', + '_native_math' + ] +attn_provider_inference (`str`, defaults to "native"): + The attention provider to use for validation. Choose between + [ + 'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash', + '_native_math', 'sage', 'sage_varlen', '_sage_qk_int8_pv_fp8_cuda', '_sage_qk_int8_pv_fp8_cuda_sm90', + '_sage_qk_int8_pv_fp16_cuda', '_sage_qk_int8_pv_fp16_triton', 'xformers' + ] ``` ## SFT training diff --git a/docs/finetrainers/documentation_dataset_README.md b/docs/finetrainers-src-codebase/docs/dataset/README.md similarity index 85% rename from docs/finetrainers/documentation_dataset_README.md rename to docs/finetrainers-src-codebase/docs/dataset/README.md index ee5a112f17989a72aab9497c62fea7efde2e78ab..bc882b783a63ba9fff4065af6e8b8872dd65ee22 100644 --- a/docs/finetrainers/documentation_dataset_README.md +++ b/docs/finetrainers-src-codebase/docs/dataset/README.md @@ -57,6 +57,8 @@ dataset #### CSV/JSON/JSONL format +- Supported names are: `metadata.json`, `metadata.jsonl`, `metadata.csv` + > [!NOTE] > Relevant classes to look for implementation: > - ImageFolderDataset @@ -75,6 +77,8 @@ Any dataset loadable via the [🤗 HF datasets] directly should work (not widely Any dataset loadable via the [🤗 HF datasets] directly should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats. + + ## Validation Dataset Format Arguments related to validation are: @@ -148,18 +152,21 @@ For memory efficient training, it is important to precompute conditional and lat The following is a high-level overview of how datasets are loaded and preprocessed: -- Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed). Users can chain together multiple datasets too! For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and chain the data together. Finetrainers makes this easier by allowing you to specify multiple different, or same, datasets with different resolutions. +- Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed). Multiple datasets can be chained together. For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and create a new copy of the dataset containing multiresolution data. Finetrainers makes this easier by allowing you to specify multiple different, or same, datasets with different resolutions. +- When chaining multiple different datasets, make sure they are roughly the same size to avoid having smaller datasets repeatedly being used in the training loop. This is because the datasets are loaded in a round-robin fashion. For example, if you have 2 datasets of size 1000 and 2000, the first dataset will be fully seen twice before the second dataset is fully seen once by the model. - The dataset is split across data replicas (GPUs groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset. -- If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way yet]. +- If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way]. - The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1. -- The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--precompute_batch_size`. This is done to avoid exhausting disk space. The smaller this value, the more number of times conditioning models will be loaded upon precomputation exhaustion. The larger this value, the more disk space will be used. +- The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--precomputation_items` to avoid exhausting disk space. The smaller this value, the more number of times conditioning models will be loaded upon precomputation exhaustion. The larger this value, the more disk space will be used. - When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually] ## Understanding how datasets are precomputed -There are 3 arguments related to precomputation: +There are 4 arguments related to precomputation: +- `--enable_precomputation`: If set, precomputation will be enabled. The parameters that follow are only relevant if this flag is set. If this flag is not set, all models will be loaded in memory and training will take place without first precomputing embeddings. - `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion. - `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in the `--output_dir/precomputed`. - `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard coded logic (can be removed manually by users for now). +- `--precomputation_reuse`: If you're working with medium/large-size datasets and want to precompute all embeddings and re-use them across different training runs, make sure to set this flag. Batching is not yet supported for precomputation. This will be added in the future. diff --git a/docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md b/docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md new file mode 100644 index 0000000000000000000000000000000000000000..388e09c3c5285f05d224a27f0f3dff854a6a3217 --- /dev/null +++ b/docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md @@ -0,0 +1,44 @@ +# Distributed dataset debugging + +>![NOTE] +> This doc page is intended for developers and contributors. + +If the number of samples in the dataset is lower than the number of processes per node, the training will hand indefinitely. I haven't been able to pin down on how this could be fixed due to limited time, but basically: +- Start training with `--dp_degree 2` and `torchrun --standalone --nnodes=1 --nproc_per_node=2`. This launches training with DDP across 2 ranks. +- The dataset has `< dp_degree` samples +- When `datasets.distributed.split_dataset_by_node` is called, the data is distributed correctly to one rank, but the other rank hangs indefinitely. Due to this edge case, fast tests seem to fail. +- For now, we should just use `>= dp_degree` samples in the test dataset. However, should be fixed in the future. + +Minimal reproducer: + +```python +import torch +import torch.distributed as dist +from datasets import Dataset +from datasets.distributed import split_dataset_by_node +from torch.utils.data import DataLoader + +ds = Dataset.from_dict({"x": [1]}).to_iterable_dataset() + +dist.init_process_group() +rank, world_size = dist.get_rank(), dist.get_world_size() +ds = split_dataset_by_node(ds, rank=rank,world_size=world_size) +dl = DataLoader(ds) + +exhausted = torch.zeros(world_size, dtype=torch.bool) + +def loop(): + while True: + print(rank, "hello", flush=True) + yield from dl + yield "end" + +for x in loop(): + if x == "end": + exhausted[rank] = True + continue + dist.all_reduce(exhausted) + if torch.all(exhausted): + break + print(f"{rank} {x}", flush=True) +``` diff --git a/docs/finetrainers/documentation_environment.md b/docs/finetrainers-src-codebase/docs/environment.md similarity index 64% rename from docs/finetrainers/documentation_environment.md rename to docs/finetrainers-src-codebase/docs/environment.md index 9255cb6d8406af58d14338bd4beaecf8019aac28..0ae2c071fa8de0a1f407e8749cc75e81e3b2d753 100644 --- a/docs/finetrainers/documentation_environment.md +++ b/docs/finetrainers-src-codebase/docs/environment.md @@ -26,3 +26,14 @@ NVIDIA A100-SXM4-80GB, 81920 MiB ``` Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices. + +> [!IMPORTANT] +> +> For context parallelism, PyTorch 2.6+ is required. + +## Configuration + +The following environment variables may be configured to change the default behaviour of finetrainers: + +`FINETRAINERS_ATTN_PROVIDER`: Sets the default attention provider for training/validation. Defaults to `native`, as in native PyTorch SDPA. See [attention docs](./models/attention.md) for more information. +`FINETRAINERS_ATTN_CHECKS`: Whether or not to run basic sanity checks when using different attention providers. This is useful for debugging but you should leave it disabled for longer training runs. Defaults to `"0"`. Can be set to a truthy env value. diff --git a/docs/finetrainers/documentation_models_README.md b/docs/finetrainers-src-codebase/docs/models/README.md similarity index 100% rename from docs/finetrainers/documentation_models_README.md rename to docs/finetrainers-src-codebase/docs/models/README.md diff --git a/docs/finetrainers-src-codebase/docs/models/attention.md b/docs/finetrainers-src-codebase/docs/models/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..3879dc34a963e15c71a44179fffb91a04aadaea5 --- /dev/null +++ b/docs/finetrainers-src-codebase/docs/models/attention.md @@ -0,0 +1,263 @@ +# Attention backends + +Finetrainers supports multiple attention backends to support different hardware and tradeoff between speed and memory usage. The following attention implementations are supported: +- Training: + - If model uses attention masks: `flash_varlen`, `flex`, `native` + - If model does not use attention masks: `flash`, `flex`, `native`, `xformers` +- Inference: + - If model uses attention masks: `flash_varlen`, `flex`, `native`, `sage_varlen` + - If model does not use attention masks: `flash`, `flash_varlen`, `flex`, `native`, `sage`, `sage_varlen`, `xformers` + +Additionally, some specialized methods are available for debugging-specific purposes: `_native_cudnn`, `_native_efficient`, `_native_flash`, `_native_math`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`. With time, more attention-specific optimizations and custom implementations will be supported. Contributions are welcome! + +Unfortunately, due to limited time for testing, only specific versions of packages that provide these implementations are supported. Other versions may work. The supported versions will be gradually made lower for more flexibility, but for now, please use the following versions: +- `flash-attn>=2.6.3` +- `sageattention>=2.1.1` +- `xformers>=0.0.29.post3` + +This guide will help you quickly install flash-attn, sageattention, and xformers to make your models run faster and use less memory for training/inference. We'll cover installation on Linux (Ubuntu 22.04) and Windows (using WSL). + +Before you start, make sure to use a clean python virtual environment to not mess up your system seriously, or to avoid conflicting dependencies leading to failed installations which might leave the environment in hard-to-recover state. + +### Flash attention + +Providers covered: `flash`, `flash_varlen` + +The installation steps have only been tested with Ubuntu 22.04; CUDA version higher than 12.2 and 12.6. +- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`. +- You might need the following packages: `pip install packaging ninja` +- Linux: Run: `pip install flash-attn --no-build-isolation`. Verify the version with `pip show flash-attn` +- WSL: Same instruction as above should work. Native Windows might require building from source - check community guiders and follow the instruction [here](https://github.com/Dao-AILab/flash-attention). + +### Sage attention + +Providers covered: `sage`, `sage_varlen`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton` + +FP8 implementations will require CUDA compute capability of 90 or higher (H100, RTX 5090, etc.). Some may work on compute capability 89 as well (RTX 4090, for example). For FP16 implementations, compute capability of atleast 80 is required (A100, RTX 3090, etc.). For other GPUs, FP16 implementations may or may not work (this is untested by me). + +- Check your compute capability with the following command: + ```bash + python -c "import torch; print(torch.cuda.get_device_capability())" + ``` +- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`. +- You might need the following packages: `pip install triton`. For Windows, check out the [triton-windows](https://github.com/woct0rdho/triton-windows) project. +- Linux/WSL: Run: `pip install git+https://github.com/thu-ml/SageAttention`. Verify the version with `pip show sageattention`. +- Make sure to look at the official installation guide in [SageAttention](https://github.com/thu-ml/SageAttention) too! + +### xformers + +Providers covered: `xformers` + +- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`. +- Linux/WSL: Run: `pip install -U xformers --index-url https://download.pytorch.org/whl/cu126` (assuming CUDA 12.6). Verify the version with `pip show xformers`. +- Make sure to look at the official installation guide in [xformers](https://github.com/facebookresearch/xformers) too! + +---------- + +All other providers are either native PyTorch implementations or require a specific PyTorch version (for example, Flex Attention requires torch version of atleast 2.5.0). + +---------- + +## Usage + +There are two ways to use the attention dispatcher mechanism: +- Replace `scaled_dot_product_attention` globally: + ```python + import torch.nn.functional as F + from finetrainers.models.attention_dispatch import attention_dispatch + + F.scaled_dot_product_attention = attention_dispatch + ``` +- Replace all occurrences of `scaled_dot_product_attention` in your code with `attention_dispatch`. + +```python +# Use dispatcher directly +from finetrainers.models.attention_dispatch import attention_provider, AttentionProvider + +with attention_provider(AttentionProvider.FLASH_VARLEN): + model(...) + +# or, +with attention_provider("sage_varlen"): + model(...) +``` + +## Context Parallel + +References and reading material: +- https://docs.pytorch.org/tutorials/prototype/context_parallel.html +- https://insujang.github.io/2024-09-20/introducing-context-parallelism/ +- https://www.youtube.com/watch?v=ws7angQYIxI +- https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ +- https://arxiv.org/abs/2309.14509 + +There are three steps to enabling context parallelism with any model: +- Defining the context parallel plan: This is a dictionary that mentions what tensors to split and gather across CP region at different layers in the model +- Applying the CP plan with `apply_context_parallel` function: This registers the necessary hooks to split and gather tensors at the right places in the model without having to manually modify the model code. +- Running model under the `attention_provider` context manager + +For a quick example, refer to the [inference example](#inference) below. + +The CP plan is a dictionary that maps the name of the module to a list of `CPInput` or `CPOutput` objects. The keys in the dictionary are the names of the internal modules in the model, and the values are dictionaries that map a parameter identifier (either as an argument index or keyword argument as used in the forward method) to a `CPInput` or `CPOutput` object. The `CPInput` object specifies the input tensor to be split, and the `CPOutput` object specifies the output tensor to be gathered. + +```python +class ParamId: + name: Optional[str] = None + index: Optional[int] = None + +class CPInput: + split_dim: int + expected_dims: Optional[int] = None + split_output: bool = False + +class CPOutput: + gather_dim: int + expected_dims: Optional[int] = None +``` + +- The `split_dim` and `gather_dim` parameters specify the dimension along which to split or gather the tensor. When using CP with native scaled dot product attention from pytorch, the tensor shape is `[B, N, S, D]`, so the `split_dim` and `gather_dim` parameters should be set to `2` as it is the sequence dimension. +- The `expected_dims` parameter is an optional parameter that is used for sanity checking if the tensor contains the expected number of dimensions. +- By default, `CPInput`'s are split in a pre-forward hook and `CPOutput`'s are gathered in a post-forward hook. If you want to split the output of a module, you can set the `split_output` parameter to `True`. This will split the output tensor in the post-forward hook instead of the pre-forward hook. + +- Attention providers supported for training with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash` +- Attention providers supported for inference with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash` + +### Training + +To enable training with context parallelism, you need to make sure a suitable CP plan is registered for the model you are using and launch training with `--cp_degree 2`. For models supported in finetrainers, this is internally done in the [transformer metadata](https://github.com/a-r-r-o-w/finetrainers/tree/main/finetrainers/models/_metadata/transformer.py) file. For custom models, make sure to pass the `plan` argument to the `apply_context_parallel` function. + +Currently supported models include: CogVideoX, CogView4, Flux, Wan 2.1. Support for more models and attention providers is in progress. + +### Inference + +The following example shows how to run context parallel inference. For more examples and ready-to-use inference scripts, check out the [examples/inference](https://github.com/a-r-r-o-w/finetrainers/tree/main/examples/inference/) folder. + +
+ Example + +```python +import torch +import torch.distributed as dist +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.utils import export_to_video + +from finetrainers._metadata import ParamId, CPInput, CPOutput +from finetrainers.parallel.ptd import apply_context_parallel +from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch + +torch.nn.functional.scaled_dot_product_attention = attention_dispatch + + +def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module: + r"""Apply torch.compile to a model or its submodules if not already compiled.""" + if getattr(model, "_torch_compiled", False): + return model # Already compiled + + if compile_scope == "full": + model = torch.compile(model) + setattr(model, "_torch_compiled", True) + elif compile_scope == "regional": + if isinstance(model, torch.nn.ModuleList): + for name, module in model.named_children(): + if not getattr(module, "_torch_compiled", False): + compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False) + setattr(compiled_module, "_torch_compiled", True) + model.register_module(name, compiled_module) + else: + for name, module in model.named_children(): + apply_compile(module, compile_scope) + else: + raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.") + + return model + + +torch.manual_seed(0) +dist.init_process_group("nccl") +rank, world_size = dist.get_rank(), dist.get_world_size() +torch.cuda.set_device(rank) +cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"]) + +cp_plan = { + "rope": { + ParamId(index=0): CPInput(2, 4, split_output=True), + }, + "blocks.*": { + ParamId("encoder_hidden_states", 1): CPInput(1, 3), + }, + "blocks.0": { + ParamId("hidden_states", 0): CPInput(1, 3), + }, + "proj_out": [CPOutput(1, 3)], +} + +try: + model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + + apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan) + + apply_compile(pipe.transformer, compile_scope="regional") + + prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + with torch.no_grad(): + prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( + prompt=prompt, negative_prompt=negative_prompt, device="cuda", + ) + + attention_backend = "_native_flash" + generator = torch.Generator().manual_seed(0) + + # Warmup for compilation + with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"): + latents = pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + height=480, + width=832, + num_frames=81, + num_inference_steps=2, + guidance_scale=5.0, + output_type="latent", + generator=generator, + ).frames[0] + + # Inference + with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"): + latents = pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + height=480, + width=832, + num_frames=81, + guidance_scale=5.0, + num_inference_steps=30, + output_type="latent", + generator=generator, + ).frames[0] + + with torch.no_grad(): + latents = latents.to(pipe.vae.dtype) + latents_mean = ( + torch.tensor(pipe.vae.config.latents_mean) + .view(1, pipe.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = pipe.vae.decode(latents, return_dict=False)[0] + video = pipe.video_processor.postprocess_video(video, output_type="pil")[0] + + if rank == 0: + export_to_video(video, "output.mp4", fps=16) +finally: + dist.destroy_process_group() +``` + +
diff --git a/docs/finetrainers/documentation_models_cogvideox.md b/docs/finetrainers-src-codebase/docs/models/cogvideox.md similarity index 82% rename from docs/finetrainers/documentation_models_cogvideox.md rename to docs/finetrainers-src-codebase/docs/models/cogvideox.md index 7574bd2b2f3c2136a8ee9ed992feda2376e22fa0..15fa0c08053ce8f9bc46739193ea655529cd78e0 100644 --- a/docs/finetrainers/documentation_models_cogvideox.md +++ b/docs/finetrainers-src-codebase/docs/models/cogvideox.md @@ -20,9 +20,9 @@ On Windows, you will have to modify the script to a compatible format to run it. CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working: -* [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) -* [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B) -* [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B) +- [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) +- [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B) +- [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B) ## Inference @@ -45,6 +45,6 @@ export_to_video(video, "output.mp4") You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: -* [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) -* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) -* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file +- [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) +- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/docs/models/cogview4.md b/docs/finetrainers-src-codebase/docs/models/cogview4.md new file mode 100644 index 0000000000000000000000000000000000000000..1d63ca9407daf1cf87771bcf2c3f3719e7c42dd7 --- /dev/null +++ b/docs/finetrainers-src-codebase/docs/models/cogview4.md @@ -0,0 +1,94 @@ +# CogView4 + +## Training + +For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`. + +Examples available: +- [Raider White Tarot cards style](../../examples/training/sft/cogview4/raider_white_tarot/) +- [Omni Edit Control LoRA](../../examples/training/control/cogview4/omni_edit/) +- [Canny Control LoRA](../../examples/training/control/cogview4/canny/) + +To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): + +```bash +chmod +x ./examples/training/sft/cogview4/raider_white_tarot/train.sh +./examples/training/sft/cogview4/raider_white_tarot/train.sh +``` + +On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows] + +## Supported checkpoints + +The following checkpoints were tested with `finetrainers` and are known to be working: + +- [THUDM/CogView4-6B](https://huggingface.co/THUDM/CogView4-6B) + +## Inference + +Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: + +```diff +import torch +from diffusers import CogView4Pipeline +from diffusers.utils import export_to_video + +pipe = CogView4Pipeline.from_pretrained( + "THUDM/CogView4-6B", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogview4-lora") ++ pipe.set_adapters(["cogview4-lora"], [0.9]) + +video = pipe("").frames[0] +export_to_video(video, "output.mp4") +``` + +To use trained Control LoRAs, the following can be used for inference (ideally, you should raise a support request in Diffusers): + +
+ Control Lora inference + +```python +import torch +from diffusers import CogView4Pipeline +from diffusers.utils import load_image +from finetrainers.models.utils import _expand_linear_with_zeroed_weights +from finetrainers.patches import load_lora_weights +from finetrainers.patches.dependencies.diffusers.control import control_channel_concat + +dtype = torch.bfloat16 +device = torch.device("cuda") +generator = torch.Generator().manual_seed(0) + +pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype) + +in_channels = pipe.transformer.config.in_channels +patch_channels = pipe.transformer.patch_embed.proj.in_features +pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels) + +load_lora_weights(pipe, "/raid/aryan/cogview4-control-lora", "cogview4-lora") +pipe.to(device) + +prompt = "Make the image look like it's from an ancient Egyptian mural." +control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png") +height, width = 1024, 1024 + +with torch.no_grad(): + latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator) + control_image = pipe.image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(device=device, dtype=dtype) + control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator) + control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor + +with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]): + image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0] + +image.save("output.png") +``` +
+ +You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: + +- [CogView4 in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4) +- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) diff --git a/docs/finetrainers-src-codebase/docs/models/flux.md b/docs/finetrainers-src-codebase/docs/models/flux.md new file mode 100644 index 0000000000000000000000000000000000000000..6afc21cb9280555a38c0d06e79f89198c815a48d --- /dev/null +++ b/docs/finetrainers-src-codebase/docs/models/flux.md @@ -0,0 +1,53 @@ +# Flux + +## Training + +For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`. + +Examples available: +- [Raider White Tarot cards style](../../examples/training/sft/flux_dev/raider_white_tarot/) + +To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL): + +```bash +chmod +x ./examples/training/sft/flux_dev/raider_white_tarot/train.sh +./examples/training/sft/flux_dev/raider_white_tarot/train.sh +``` + +On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows] + +> [!NOTE] +> Currently, only FLUX.1-dev is supported. It is a guidance-distilled model which directly predicts the outputs of its teacher model when the teacher is run with CFG. To match the output distribution of the distilled model with that of the teacher model, a guidance scale of 1.0 is hardcoded into the codebase. However, other values may work too but it is experimental. +> FLUX.1-schnell is not supported for training yet. It is a timestep-distilled model. Matching its output distribution for training is significantly more difficult. + +## Supported checkpoints + +The following checkpoints were tested with `finetrainers` and are known to be working: + +- [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) +- [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) + +## Inference + +Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: + +```diff +import torch +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="flux-lora") ++ pipe.set_adapters(["flux-lora"], [0.9]) + +# Make sure to set guidance_scale to 0.0 when inferencing with FLUX.1-schnell or derivative models +image = pipe("").images[0] +image.save("output.png") +``` + +You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: + +- [Flux in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) +- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) diff --git a/docs/finetrainers/documentation_models_hunyuan_video.md b/docs/finetrainers-src-codebase/docs/models/hunyuan_video.md similarity index 91% rename from docs/finetrainers/documentation_models_hunyuan_video.md rename to docs/finetrainers-src-codebase/docs/models/hunyuan_video.md index 095477df6566812a9169f878ad6fb070484eb4fc..51cde5e98c745b9c42f2deee67502a127806b5ff 100644 --- a/docs/finetrainers/documentation_models_hunyuan_video.md +++ b/docs/finetrainers-src-codebase/docs/models/hunyuan_video.md @@ -50,6 +50,6 @@ export_to_video(output, "output.mp4", fps=15) You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: -* [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) -* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) -* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file +- [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) +- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file diff --git a/docs/finetrainers/documentation_models_ltx_video.md b/docs/finetrainers-src-codebase/docs/models/ltx_video.md similarity index 88% rename from docs/finetrainers/documentation_models_ltx_video.md rename to docs/finetrainers-src-codebase/docs/models/ltx_video.md index f75132fe4539ad1464f9a74649815a41da2db2e4..7d104aed35fe0237cc89be940ac211a53d386ae1 100644 --- a/docs/finetrainers/documentation_models_ltx_video.md +++ b/docs/finetrainers-src-codebase/docs/models/ltx_video.md @@ -37,6 +37,6 @@ export_to_video(video, "output.mp4", fps=8) You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`: -* [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video) -* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) -* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file +- [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video) +- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file diff --git a/docs/finetrainers/documentation_models_optimization.md b/docs/finetrainers-src-codebase/docs/models/optimization.md similarity index 100% rename from docs/finetrainers/documentation_models_optimization.md rename to docs/finetrainers-src-codebase/docs/models/optimization.md diff --git a/docs/finetrainers/documentation_models_wan.md b/docs/finetrainers-src-codebase/docs/models/wan.md similarity index 87% rename from docs/finetrainers/documentation_models_wan.md rename to docs/finetrainers-src-codebase/docs/models/wan.md index 5d4ab31b867b161f052a4e50e2a92cf782ce91bc..8d3b160df6d48eb255977f42b7600d3f994e5fcd 100644 --- a/docs/finetrainers/documentation_models_wan.md +++ b/docs/finetrainers-src-codebase/docs/models/wan.md @@ -18,6 +18,16 @@ chmod +x ./examples/training/sft/wan/crush_smol_lora/train.sh On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows] +## Supported checkpoints + +Wan has multiple checkpoints as one can find [here](https://huggingface.co/Wan-AI). The following checkpoints were tested with `finetrainers` and are known to be working: + +- [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) +- [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) +- [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers) +- [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers) +- [Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers) + ## Inference Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: @@ -102,4 +112,4 @@ You can refer to the following guides to know more about the model pipeline and - [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan) - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference) -- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) \ No newline at end of file +- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras) diff --git a/docs/finetrainers/documentation_optimizers.md b/docs/finetrainers-src-codebase/docs/optimizer.md similarity index 100% rename from docs/finetrainers/documentation_optimizers.md rename to docs/finetrainers-src-codebase/docs/optimizer.md diff --git a/docs/finetrainers/documentation_parallel_processing_README.md b/docs/finetrainers-src-codebase/docs/parallel/README.md similarity index 87% rename from docs/finetrainers/documentation_parallel_processing_README.md rename to docs/finetrainers-src-codebase/docs/parallel/README.md index 02a897d19c68a24cdc3ee4db7325e09da4985d37..9fab017d9f7dcc54633408be791c7a09f4aaca92 100644 --- a/docs/finetrainers/documentation_parallel_processing_README.md +++ b/docs/finetrainers-src-codebase/docs/parallel/README.md @@ -14,11 +14,12 @@ As an experiment for comparing performance of different training backends, Finet ## Support matrix -There are various algorithms for parallel training. Currently, we only support: +Currently supported parallelizations include: - [DDP](https://pytorch.org/docs/stable/notes/ddp.html) - [FSDP2](https://pytorch.org/docs/stable/fsdp.html) - [HSDP](https://pytorch.org/docs/stable/fsdp.html) -- [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) +- [CP](https://docs.pytorch.org/tutorials/prototype/context_parallel.html) + ## Training @@ -28,7 +29,7 @@ The following parameters are relevant for launching training: - `pp_degree`: The degree of pipeline parallelism. Currently unsupported. - `dp_degree`: The degree of data parallelis/replicas. Defaults to `1`. - `dp_shards`: The number of shards for data parallelism. Defaults to `1`. -- `cp_degree`: The degree of context parallelism. Currently unsupported. +- `cp_degree`: The degree of context parallelism. - `tp_degree`: The degree of tensor parallelism. For launching training with the Pytorch DTensor backend, use the following: @@ -57,3 +58,7 @@ accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0 # Multi-node - Nx8 GPUs available # TODO(aryan): Add slurm script ``` + +## Inference + +For inference-only purposes, the example implementation can be found in the [examples/inference/](../../examples/inference/) directory. diff --git a/docs/finetrainers/documentation_trainers_control_trainer.md b/docs/finetrainers-src-codebase/docs/trainer/control_trainer.md similarity index 100% rename from docs/finetrainers/documentation_trainers_control_trainer.md rename to docs/finetrainers-src-codebase/docs/trainer/control_trainer.md diff --git a/docs/finetrainers/documentation_trainers_sft_trainer.md b/docs/finetrainers-src-codebase/docs/trainer/sft_trainer.md similarity index 100% rename from docs/finetrainers/documentation_trainers_sft_trainer.md rename to docs/finetrainers-src-codebase/docs/trainer/sft_trainer.md diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/README.md b/docs/finetrainers-src-codebase/examples/_legacy/training/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfa1672e684bba5a8cfd289f700e24d4c5880012 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/README.md @@ -0,0 +1,459 @@ +# CogVideoX Factory 🧪 + +[中文阅读](./README_zh.md) + +Fine-tune Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼 + + + + + +
+ +**Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)! + +## Quickstart + +Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`. + +Then download a dataset: + +```bash +# install `huggingface_hub` +huggingface-cli download \ + --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \ + --local-dir video-dataset-disney +``` + +Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice): + +```bash +# For LoRA finetuning of the text-to-video CogVideoX models +./train_text_to_video_lora.sh + +# For full finetuning of the text-to-video CogVideoX models +./train_text_to_video_sft.sh + +# For LoRA finetuning of the image-to-video CogVideoX models +./train_image_to_video_lora.sh +``` + +Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference: + +```diff +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora") ++ pipe.set_adapters(["cogvideox-lora"], [1.0]) + +video = pipe("").frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details): + +```python +from diffusers import CogVideoXImageToVideoPipeline + +pipe = CogVideoXImageToVideoPipeline.from_pretrained( + "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16 +).to("cuda") + +# ... + +del pipe.transformer.patch_embed.pos_embedding +pipe.transformer.patch_embed.use_learned_positional_embeddings = False +pipe.transformer.config.use_learned_positional_embeddings = False +``` + +You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py). + +Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible. + +## Prepare Dataset and Training + +Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example. + +- Configure environment variables as per your choice: + + ```bash + export TORCH_LOGS="+dynamo,recompiles,graph_breaks" + export TORCHDYNAMO_VERBOSE=1 + export WANDB_MODE="offline" + export NCCL_P2P_DISABLE=1 + export TORCH_NCCL_ENABLE_MONITORING=0 + ``` + +- Configure which GPUs to use for training: `GPU_IDS="0,1"` + +- Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example: + + ```bash + LEARNING_RATES=("1e-4" "1e-3") + LR_SCHEDULES=("cosine_with_restarts") + OPTIMIZERS=("adamw" "adam") + MAX_TRAIN_STEPS=("3000") + ``` + +- Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`. + +- Specify the absolute paths and columns/files for captions and videos. + + ```bash + DATA_ROOT="/path/to/my/datasets/video-dataset-disney" + CAPTION_COLUMN="prompt.txt" + VIDEO_COLUMN="videos.txt" + ``` + +- Launch experiments sweeping different hyperparameters: + ``` + for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done + done + ``` + + To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`. + +Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below. + +## Memory requirements + + + + + + + + + + + + + + + + + + + + + + + + + +
CogVideoX LoRA Finetuning
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
CogVideoX Full Finetuning
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
+ +Supported and verified memory optimizations for training include: + +- `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation is not supported yet either. +- Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: to test and make [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work +- DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations. + +> [!IMPORTANT] +> The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs. +> +> If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`. + +### LoRA finetuning + +> [!NOTE] +> The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly. +> +> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> or provide a URL to a valid and accessible image. + +
+ AdamW + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 | + +
+ +
+ AdamW (8-bit bitsandbytes) + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 | +| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 | +| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +With `train_batch_size = 1`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 | +| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 | + +With `train_batch_size = 4`: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100. + +With `train_batch_size = 1`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 | +| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 | + +With `train_batch_size = 4`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 | +| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 | + +
+ +### Full finetuning + +> [!NOTE] +> The memory requirements for image-to-video full finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly. +> +> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> or provide a URL to a valid and accessible image. + +> [!NOTE] +> Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified. + +
+ AdamW + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +
+ +
+ AdamW (8-bit bitsandbytes) + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 | +| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 | +| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +With `train_batch_size = 1`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 | +| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 | + +With `train_batch_size = 4`: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 | +| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100. + +With `train_batch_size = 1`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 | +| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 | + +With `train_batch_size = 4`: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 | +| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 | + +
+ +> [!NOTE] +> - `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose not to perform validation/testing as part of the training script. +> +> - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing. + + + + + + + + +
Slaying OOMs with PyTorch
+ +## TODOs + +- [x] Make scripts compatible with DDP +- [ ] Make scripts compatible with FSDP +- [x] Make scripts compatible with DeepSpeed +- [ ] vLLM-powered captioning script +- [x] Multi-resolution/frame support in `prepare_dataset.py` +- [ ] Analyzing traces for potential speedups and removing as many syncs as possible +- [x] Test scripts with memory-efficient optimizer from bitsandbytes +- [x] Test scripts with CPUOffloadOptimizer, etc. +- [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao)) +- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out) +- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100) + +> [!IMPORTANT] +> Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training. diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md b/docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..f18a62218a1b3f3d3b388cf1dfcb29c273b74326 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md @@ -0,0 +1,455 @@ +# CogVideoX Factory 🧪 + +[Read in English](./README.md) + +在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼 + + + + + +
+ +## 快速开始 + +克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。 + +接着下载数据集: + +``` +# 安装 `huggingface_hub` +huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney +``` + +然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项): + +``` +# 对 CogVideoX 模型进行文本到视频的 LoRA 微调 +./train_text_to_video_lora.sh + +# 对 CogVideoX 模型进行文本到视频的完整微调 +./train_text_to_video_sft.sh + +# 对 CogVideoX 模型进行图像到视频的 LoRA 微调 +./train_image_to_video_lora.sh +``` + +假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理: + +``` +import torch +from diffusers import CogVideoXPipeline +from diffusers import export_to_video + +pipe = CogVideoXPipeline.from_pretrained( + "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16 +).to("cuda") ++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"]) ++ pipe.set_adapters(["cogvideox-lora"], [1.0]) + +video = pipe("").frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。 + +**注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装 +diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。 + +以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。 + +## 训练 + +在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。 + +- 根据你的需求配置环境变量: + + ``` + export TORCH_LOGS="+dynamo,recompiles,graph_breaks" + export TORCHDYNAMO_VERBOSE=1 + export WANDB_MODE="offline" + export NCCL_P2P_DISABLE=1 + export TORCH_NCCL_ENABLE_MONITORING=0 + ``` + +- 配置用于训练的 GPU:`GPU_IDS="0,1"` + +- 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例: + + ``` + LEARNING_RATES=("1e-4" "1e-3") + LR_SCHEDULES=("cosine_with_restarts") + OPTIMIZERS=("adamw" "adam") + MAX_TRAIN_STEPS=("3000") + ``` + +- 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"` + 。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed + 等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。 + +- 指定字幕和视频的绝对路径以及列/文件。 + + ``` + DATA_ROOT="/path/to/my/datasets/video-dataset-disney" + CAPTION_COLUMN="prompt.txt" + VIDEO_COLUMN="videos.txt" + ``` + +- 运行实验,遍历不同的超参数: + ``` + for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done + done + ``` + +要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。 + +注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。 + +## 内存需求 + + + + + + + + + + + + + + + + + + + + + + + + + +
CogVideoX LoRA 微调
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
CogVideoX 全量微调
THUDM/CogVideoX-2bTHUDM/CogVideoX-5b
+ +支持和验证的训练内存优化包括: + +- `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao) + 。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) + 阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU + 优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile` + 。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。 +- 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers) + 的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。 +- DeepSpeed Zero2:由于我们依赖 `accelerate` + ,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed + Zero2 优化训练。 + +> [!重要提示] +> 内存需求是运行 `training/prepare_dataset.py` +> +后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。 +> +> 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。 + +### LoRA微调 + +> [!重要提示] +> 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。 +> +> 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> 或提供一个有效且可访问的图像URL。 + +
+ AdamW + +**注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 | + +
+ +
+ AdamW (8-bit bitsandbytes) + +**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 | +| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 | +| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 | +| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 | + +
+ +
+ AdamW + CPUOffloadOptimizer (with gradient offloading) + +**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。 + +当 `train_batch_size = 1` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 | +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 | +| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 | +| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 | +| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 | + +当 `train_batch_size = 4` 时: + +| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 | +| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 | +| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 | +| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 | +| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 | +| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 | + +
+ +
+ DeepSpeed (AdamW + CPU/Parameter offloading) + +**注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。 + +当 `train_batch_size = 1` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 | +| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 | + +当 `train_batch_size = 4` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 | +| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 | + +
+ +### Full finetuning + +> [!注意] +> 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。 +> +> 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧: +> `ffmpeg -i input.mp4 -frames:v 1 frame.png`, +> 或提供一个有效且可访问的图像 URL。 + +> [!注意] +> 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。 + +
+ AdamW + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 | +| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM | + +
+ +
+ AdamW (8-bit 量化) + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 | +| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 | +| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 | + +
+ +
+ AdamW + CPUOffloadOptimizer(带有梯度卸载) + +当 `train_batch_size = 1` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 | +| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 | + +当 `train_batch_size = 4` 时: + +| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 | +| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 | + +
+ +
+ DeepSpeed(AdamW + CPU/参数卸载) + +**注意:** 结果是在启用 `gradient_checkpointing`(梯度检查点)功能,并在 2 台 A100 显卡上运行时报告的。 + +当 `train_batch_size = 1` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 | +| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 | + +当 `train_batch_size = 4` 时: + +| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing | +|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:| +| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 | +| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 | + +
+ +> [!注意] +> - `memory_after_validation`(验证后内存) 表示训练所需的峰值内存。这是因为除了存储训练过程中需要的激活、参数和梯度之外,还需要加载 + VAE 和文本编码器到内存中,并且执行推理操作也会消耗一定内存。为了减少训练所需的总内存,您可以选择在训练脚本中不执行验证/测试。 +> +> - 如果选择不进行验证/测试,`memory_before_validation`(验证前内存) 才是训练所需内存的真实指示器。 + + + + + + + + +
Slaying OOMs with PyTorch
+ +## 待办事项 + +- [x] 使脚本兼容 DDP +- [ ] 使脚本兼容 FSDP +- [x] 使脚本兼容 DeepSpeed +- [ ] 基于 vLLM 的字幕脚本 +- [x] 在 `prepare_dataset.py` 中支持多分辨率/帧数 +- [ ] 分析性能瓶颈并尽可能减少同步操作 +- [ ] 支持 QLoRA(优先),以及其他高使用率的 LoRA 方法 +- [x] 使用 bitsandbytes 的节省内存优化器测试脚本 +- [x] 使用 CPUOffloadOptimizer 等测试脚本 +- [ ] 使用 torchao 量化和低位内存优化器测试脚本(目前在 AdamW(8/4-bit torchao)上报错) +- [ ] 使用 AdamW(8-bit bitsandbytes)+ CPUOffloadOptimizer(带有梯度卸载)的测试脚本(目前报错) +- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (与作者合作支持反向传播,并针对 A100 进行优化) + +> [!重要] +> 由于我们的目标是使脚本尽可能节省内存,因此我们不保证支持多 GPU 训练。 \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fed7da6d85df5384bd7e29a7786e15616967ca --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py @@ -0,0 +1,484 @@ +import argparse + + +def _get_model_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + +def _get_dataset_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--dataset_file", + type=str, + default=None, + help=("Path to a CSV file if loading prompts/video paths using this format."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=int, + default=[49], + help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53" + ) + parser.add_argument( + "--load_tensors", + action="store_true", + help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.", + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _get_validation_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + default=False, + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + + +def _get_training_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. " + "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this " + "argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-sft", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", + type=int, + default=49, + help="All input videos will be truncated to these many frames.", + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability when finetuning image-to-video.", + ) + parser.add_argument( + "--ignore_learned_positional_embeddings", + action="store_true", + default=False, + help=( + "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting " + "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it " + "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why." + ), + ) + + +def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy", "came"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit", + action="store_true", + help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.", + ) + parser.add_argument( + "--use_4bit", + action="store_true", + help="Whether or not to use 4-bit optimizers from `torchao`.", + ) + parser.add_argument( + "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers." + ) + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta2", + type=float, + default=0.95, + help="The beta2 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument( + "--prodigy_decouple", + action="store_true", + help="Use AdamW style decoupled weight decay.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for optimizer.", + ) + parser.add_argument( + "--epsilon", + type=float, + default=1e-8, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--prodigy_use_bias_correction", + action="store_true", + help="Turn on Adam's bias correction.", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + parser.add_argument( + "--use_cpu_offload_optimizer", + action="store_true", + help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.", + ) + parser.add_argument( + "--offload_gradients", + action="store_true", + help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.", + ) + + +def _get_configuration_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--nccl_timeout", + type=int, + default=600, + help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + _get_model_args(parser) + _get_dataset_args(parser) + _get_training_args(parser) + _get_validation_args(parser) + _get_optimizer_args(parser) + _get_configuration_args(parser) + + return parser.parse_args() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c6be5aa14a67381cd5bd9116866eec801464d4 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py @@ -0,0 +1,1016 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import random +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX LoRA Finetune + + + +## Model description + +This is a lora finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image + +pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. +pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +image = load_image("/path/to/image.png") +video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "image-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXImageToVideoPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +def run_validation( + args: Dict[str, Any], + accelerator: Accelerator, + transformer, + scheduler, + model_config: Dict[str, Any], + weight_dtype: torch.dtype, +) -> None: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + # These changes will also be required when trying to run inference with the trained lora + if args.ignore_learned_positional_embeddings: + del transformer.patch_embed.pos_embedding + transformer.patch_embed.use_learned_positional_embeddings = False + transformer.config.use_learned_positional_embeddings = False + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + CogVideoXImageToVideoPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + transformer_ = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + "image_to_video": True, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + images = batch["images"].to(accelerator.device, non_blocking=True) + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype + ) + image_noise_sigma = torch.exp(image_noise_sigma) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_images).latent_dist + + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + image_latent_dist = DiagonalGaussianDistribution(images) + latent_dist = DiagonalGaussianDistribution(videos) + + image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR + image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = latent_dist.sample() * VAE_SCALING_FACTOR + video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + batch_size, num_frames, num_channels, height, width = video_latents.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=accelerator.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + + ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None, + ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0) + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = video_latents + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # Checkpointing + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Validation + should_run_validation = args.validation_prompt is not None and ( + args.validation_steps is not None and global_step % args.validation_steps == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + should_run_validation = args.validation_prompt is not None and ( + args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXImageToVideoPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..c40b6803e24e1272aef72fc4c494c1356093dea6 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py @@ -0,0 +1,947 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import random +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType, init_empty_weights +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX Full Finetune + + + +## Model description + +This is a full finetune of the CogVideoX model `{base_model}`. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "image-to-video", + "diffusers-training", + "diffusers", + "cogvideox", + "cogvideox-diffusers", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXImageToVideoPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +def run_validation( + args: Dict[str, Any], + accelerator: Accelerator, + transformer, + scheduler, + model_config: Dict[str, Any], + weight_dtype: torch.dtype, +) -> None: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + if args.ignore_learned_positional_embeddings: + del transformer.patch_embed.pos_embedding + transformer.patch_embed.use_learned_positional_embeddings = False + transformer.config.use_learned_positional_embeddings = False + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + transformer.requires_grad_(True) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + model.save_pretrained( + os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB" + ) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + transformer_ = None + init_under_meta = False + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + with init_empty_weights(): + transformer_ = CogVideoXTransformer3DModel.from_config( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + init_under_meta = True + + load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer")) + transformer_.register_to_config(**load_model.config) + transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta) + del load_model + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + cast_training_params([transformer], dtype=torch.float32) + + transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + "image_to_video": True, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-sft" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + images = batch["images"].to(accelerator.device, non_blocking=True) + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype + ) + image_noise_sigma = torch.exp(image_noise_sigma) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_images).latent_dist + + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + image_latent_dist = DiagonalGaussianDistribution(images) + latent_dist = DiagonalGaussianDistribution(videos) + + image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR + image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = latent_dist.sample() * VAE_SCALING_FACTOR + video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + batch_size, num_frames, num_channels, height, width = video_latents.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=accelerator.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None, + ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0) + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = video_latents + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # Checkpointing + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Validation + should_run_validation = args.validation_prompt is not None and ( + args.validation_steps is not None and global_step % args.validation_steps == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + should_run_validation = args.validation_prompt is not None and ( + args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0 + ) + if should_run_validation: + run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype) + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + + transformer.save_pretrained( + os.path.join(args.output_dir, "transformer"), + safe_serialization=True, + max_shard_size="5GB", + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + validation_images = args.validation_images.split(args.validation_prompt_separator) + for validation_image, validation_prompt in zip(validation_images, validation_prompts): + pipeline_args = { + "image": load_image(validation_image), + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f2c5d5adad7976e2ed257bfc3761b115a10fef --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py @@ -0,0 +1,955 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) # isort:skip + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX LoRA Finetune + + + +## Model description + +This is a lora finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. +pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model = unwrap_model(accelerator, model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + CogVideoXPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + transformer_ = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + latent_dist = DiagonalGaussianDistribution(videos) + + videos = latent_dist.sample() * VAE_SCALING_FACTOR + videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + model_input = videos + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=model_input.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..f0afb64bc1aec06fe5c9ece8c1d59587e0d9b597 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py @@ -0,0 +1,917 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +from accelerate import Accelerator, DistributedType, init_empty_weights +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + + +from args import get_args # isort:skip +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import ( + get_gradient_norm, + get_optimizer, + prepare_rotary_positional_embeddings, + print_memory, + reset_memory, + unwrap_model, +) # isort:skip + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# CogVideoX Full Finetune + + + +## Model description + +This is a full finetune of the CogVideoX model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda") + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +For more details, checkout the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX. + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "cogvideox", + "cogvideox-diffusers", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: CogVideoXPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + return { + "videos": videos, + "prompts": prompts, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + text_encoder.requires_grad_(False) + vae.requires_grad_(False) + transformer.requires_grad_(True) + + VAE_SCALING_FACTOR = vae.config.scaling_factor + VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1) + RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL + RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + model: CogVideoXTransformer3DModel + model = unwrap_model(accelerator, model) + model.save_pretrained( + os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB" + ) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + transformer_ = None + init_under_meta = False + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))): + transformer_ = unwrap_model(accelerator, model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}") + else: + with init_empty_weights(): + transformer_ = CogVideoXTransformer3DModel.from_config( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + init_under_meta = True + + load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer")) + transformer_.register_to_config(**load_model.config) + transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta) + del load_model + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-sft" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + del vae, text_encoder + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + logs = {} + + with accelerator.accumulate(models_to_accumulate): + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + + # Encode videos + if not args.load_tensors: + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + latent_dist = DiagonalGaussianDistribution(videos) + + videos = latent_dist.sample() * VAE_SCALING_FACTOR + videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + model_input = videos + + # Encode prompts + if not args.load_tensors: + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + scheduler.config.num_train_timesteps, + (batch_size,), + dtype=torch.int64, + device=model_input.device, + ) + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SCALE_FACTOR_SPATIAL, + width=width * VAE_SCALE_FACTOR_SPATIAL, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL, + patch_size=model_config.patch_size, + patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + base_height=RoPE_BASE_HEIGHT, + base_width=RoPE_BASE_WIDTH, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + weights = 1 / (1 - alphas_cumprod[timesteps]) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean( + (weights * (model_pred - target) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: + gradient_norm_before_clip = get_gradient_norm(transformer.parameters()) + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer.parameters()) + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs.update( + { + "loss": loss.detach().item(), + "lr": last_lr, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(accelerator, transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + "max_sequence_length": model_config.max_text_seq_length, + } + + log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=False, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(accelerator, transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + + transformer.save_pretrained( + os.path.join(args.output_dir, "transformer"), + safe_serialization=True, + max_shard_size="5GB", + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec47b0b33e089cad3935d0dd7951137f27db9452 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py @@ -0,0 +1,428 @@ +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torchvision.transforms as TT +from accelerate.logging import get_logger +from torch.utils.data import Dataset, Sampler +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize + + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] + + +class VideoDataset(Dataset): + def __init__( + self, + data_root: str, + dataset_file: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + max_num_frames: int = 49, + id_token: Optional[str] = None, + height_buckets: List[int] = None, + width_buckets: List[int] = None, + frame_buckets: List[int] = None, + load_tensors: bool = False, + random_flip: Optional[float] = None, + image_to_video: bool = False, + ) -> None: + super().__init__() + + self.data_root = Path(data_root) + self.dataset_file = dataset_file + self.caption_column = caption_column + self.video_column = video_column + self.max_num_frames = max_num_frames + self.id_token = f"{id_token.strip()} " if id_token else "" + self.height_buckets = height_buckets or HEIGHT_BUCKETS + self.width_buckets = width_buckets or WIDTH_BUCKETS + self.frame_buckets = frame_buckets or FRAME_BUCKETS + self.load_tensors = load_tensors + self.random_flip = random_flip + self.image_to_video = image_to_video + + self.resolutions = [ + (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets + ] + + # Two methods of loading data are supported. + # - Using a CSV: caption_column and video_column must be some column in the CSV. One could + # make use of other columns too, such as a motion score or aesthetic score, by modifying the + # logic in CSV processing. + # - Using two files containing line-separate captions and relative paths to videos. + # For a more detailed explanation about preparing dataset format, checkout the README. + if dataset_file is None: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_local_path() + else: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_csv() + + if len(self.video_paths) != len(self.prompts): + raise ValueError( + f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.video_transforms = transforms.Compose( + [ + transforms.RandomHorizontalFlip(random_flip) + if random_flip + else transforms.Lambda(self.identity_transform), + transforms.Lambda(self.scale_transform), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + @staticmethod + def identity_transform(x): + return x + + @staticmethod + def scale_transform(x): + return x / 255.0 + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + if self.load_tensors: + image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index]) + + # This is hardcoded for now. + # The VAE's temporal compression ratio is 4. + # The VAE's spatial compression ratio is 8. + latent_num_frames = video_latents.size(1) + if latent_num_frames % 2 == 0: + num_frames = latent_num_frames * 4 + else: + num_frames = (latent_num_frames - 1) * 4 + 1 + + height = video_latents.size(2) * 8 + width = video_latents.size(3) * 8 + + return { + "prompt": prompt_embeds, + "image": image_latents, + "video": video_latents, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + }, + } + else: + image, video, _ = self._preprocess_video(self.video_paths[index]) + + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: + if not self.data_root.exists(): + raise ValueError("Root folder for videos does not exist") + + prompt_path = self.data_root.joinpath(self.caption_column) + video_path = self.data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + if not self.load_tensors and any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: + df = pd.read_csv(self.dataset_file) + prompts = df[self.caption_column].tolist() + video_paths = df[self.video_column].tolist() + video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + r""" + Loads a single video, or latent and prompt embedding, based on initialization parameters. + + If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here, + F, C, H and W are the frames, channels, height and width of the input video. + + If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. + F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length + and embedding dimension of prompt embeddings. + """ + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[: self.max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + filename_without_ext = path.name.split(".")[0] + pt_filename = f"{filename_without_ext}.pt" + + # The current path is something like: /a/b/c/d/videos/00001.mp4 + # We need to reach: /a/b/c/d/video_latents/00001.pt + image_latents_path = path.parent.parent.joinpath("image_latents") + video_latents_path = path.parent.parent.joinpath("video_latents") + embeds_path = path.parent.parent.joinpath("prompt_embeds") + + if ( + not video_latents_path.exists() + or not embeds_path.exists() + or (self.image_to_video and not image_latents_path.exists()) + ): + raise ValueError( + f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." + ) + + if self.image_to_video: + image_latent_filepath = image_latents_path.joinpath(pt_filename) + video_latent_filepath = video_latents_path.joinpath(pt_filename) + embeds_filepath = embeds_path.joinpath(pt_filename) + + if not video_latent_filepath.is_file() or not embeds_filepath.is_file(): + if self.image_to_video: + image_latent_filepath = image_latent_filepath.as_posix() + video_latent_filepath = video_latent_filepath.as_posix() + embeds_filepath = embeds_filepath.as_posix() + raise ValueError( + f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." + ) + + images = ( + torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None + ) + latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) + embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) + + return images, latents, embeds + + +class VideoDatasetWithResizing(VideoDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): + def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.video_reshape_mode = video_reshape_mode + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class BucketSampler(Sampler): + r""" + PyTorch Sampler that groups 3D data by height, width and frames. + + Args: + data_source (`VideoDataset`): + A PyTorch dataset object that is an instance of `VideoDataset`. + batch_size (`int`, defaults to `8`): + The batch size to use for training. + shuffle (`bool`, defaults to `True`): + Whether or not to shuffle the data in each batch before dispatching to dataloader. + drop_last (`bool`, defaults to `False`): + Whether or not to drop incomplete buckets of data after completely iterating over all data + in the dataset. If set to True, only batches that have `batch_size` number of entries will + be yielded. If set to False, it is guaranteed that all data in the dataset will be processed + and batches that do not have `batch_size` number of entries will also be yielded. + """ + + def __init__( + self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + ) -> None: + self.data_source = data_source + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + self.buckets = {resolution: [] for resolution in data_source.resolutions} + + self._raised_warning_for_drop_last = False + + def __len__(self): + if self.drop_last and not self._raised_warning_for_drop_last: + self._raised_warning_for_drop_last = True + logger.warning( + "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." + ) + return (len(self.data_source) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + for index, data in enumerate(self.data_source): + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] + + if self.drop_last: + return + + for fhw, bucket in list(self.buckets.items()): + if len(bucket) == 0: + continue + if self.shuffle: + random.shuffle(bucket) + yield bucket + del self.buckets[fhw] + self.buckets[fhw] = [] \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12b29fa3de938b8890fb8e12511900e74666e016 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 + +import argparse +import functools +import json +import os +import pathlib +import queue +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from diffusers import AutoencoderKLCogVideoX +from diffusers.training_utils import set_seed +from diffusers.utils import export_to_video, get_logger +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm +from transformers import T5EncoderModel, T5Tokenizer + + +import decord # isort:skip + +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip + + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def check_height(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_width(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_frames(x: Any) -> int: + x = int(x) + if x % 4 != 0 and x % 4 != 1: + raise argparse.ArgumentTypeError( + f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria." + ) + return x + + +def get_args() -> Dict[str, Any]: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + type=str, + default="THUDM/CogVideoX-2b", + help="Hugging Face model ID to use for tokenizer, text encoder and VAE.", + ) + parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.") + parser.add_argument( + "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data." + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=check_height, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=check_width, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=check_frames, + default=[49], + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument( + "--save_image_latents", + action="store_true", + help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", + ) + parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.") + parser.add_argument( + "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings." + ) + parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.") + parser.add_argument( + "--save_latents_and_embeddings", + action="store_true", + help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.", + ) + parser.add_argument( + "--use_slicing", + action="store_true", + help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument( + "--use_tiling", + action="store_true", + help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.") + parser.add_argument( + "--num_decode_threads", + type=int, + default=0, + help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="fp32", + help="Data type to use when generating latents and prompt embeddings.", + ) + parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") + parser.add_argument( + "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts." + ) + return parser.parse_args() + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompts: List[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool = False, +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +to_pil_image = transforms.ToPILImage(mode="RGB") + + +def save_image(image: torch.Tensor, path: pathlib.Path) -> None: + image = image.to(dtype=torch.float32).clamp(-1, 1) + image = to_pil_image(image.float()) + image.save(path) + + +def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None: + video = video.to(dtype=torch.float32).clamp(-1, 1) + video = [to_pil_image(frame) for frame in video] + export_to_video(video, path, fps=fps) + + +def save_prompt(prompt: str, path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(prompt) + + +def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(json.dumps(metadata)) + + +@torch.no_grad() +def serialize_artifacts( + batch_size: int, + fps: int, + images_dir: Optional[pathlib.Path] = None, + image_latents_dir: Optional[pathlib.Path] = None, + videos_dir: Optional[pathlib.Path] = None, + video_latents_dir: Optional[pathlib.Path] = None, + prompts_dir: Optional[pathlib.Path] = None, + prompt_embeds_dir: Optional[pathlib.Path] = None, + images: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + videos: Optional[torch.Tensor] = None, + video_latents: Optional[torch.Tensor] = None, + prompts: Optional[List[str]] = None, + prompt_embeds: Optional[torch.Tensor] = None, +) -> None: + num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) + metadata = [{"num_frames": num_frames, "height": height, "width": width}] + + data_folder_mapper_list = [ + (images, images_dir, lambda img, path: save_image(img[0], path), "png"), + (image_latents, image_latents_dir, torch.save, "pt"), + (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"), + (video_latents, video_latents_dir, torch.save, "pt"), + (prompts, prompts_dir, save_prompt, "txt"), + (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), + (metadata, videos_dir, save_metadata, "txt"), + ] + filenames = [uuid.uuid4() for _ in range(batch_size)] + + for data, folder, save_fn, extension in data_folder_mapper_list: + if data is None: + continue + for slice, filename in zip(data, filenames): + if isinstance(slice, torch.Tensor): + slice = slice.clone().to("cpu") + path = folder.joinpath(f"{filename}.{extension}") + save_fn(slice, path) + + +def save_intermediates(output_queue: queue.Queue) -> None: + while True: + try: + item = output_queue.get(timeout=30) + if item is None: + break + serialize_artifacts(**item) + + except queue.Empty: + continue + + +@torch.no_grad() +def main(): + args = get_args() + set_seed(args.seed) + + output_dir = pathlib.Path(args.output_dir) + tmp_dir = output_dir.joinpath("tmp") + + output_dir.mkdir(parents=True, exist_ok=True) + tmp_dir.mkdir(parents=True, exist_ok=True) + + # Create task queue for non-blocking serializing of artifacts + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers) + save_future = save_thread.submit(save_intermediates, output_queue) + + # Initialize distributed processing + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + # Single GPU + local_rank = 0 + world_size = 1 + rank = 0 + torch.cuda.set_device(rank) + + # Create folders where intermediate tensors from each rank will be saved + images_dir = tmp_dir.joinpath(f"images/{rank}") + image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}") + videos_dir = tmp_dir.joinpath(f"videos/{rank}") + video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") + prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") + prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") + + images_dir.mkdir(parents=True, exist_ok=True) + image_latents_dir.mkdir(parents=True, exist_ok=True) + videos_dir.mkdir(parents=True, exist_ok=True) + video_latents_dir.mkdir(parents=True, exist_ok=True) + prompts_dir.mkdir(parents=True, exist_ok=True) + prompt_embeds_dir.mkdir(parents=True, exist_ok=True) + + weight_dtype = DTYPE_MAPPING[args.dtype] + target_fps = args.target_fps + + # 1. Dataset + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": False, + "random_flip": args.random_flip, + "image_to_video": args.save_image_latents, + } + if args.video_reshape_mode is None: + dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + original_dataset_size = len(dataset) + + # Split data among GPUs + if world_size > 1: + samples_per_gpu = original_dataset_size // world_size + start_index = rank * samples_per_gpu + end_index = start_index + samples_per_gpu + if rank == world_size - 1: + end_index = original_dataset_size # Make sure the last GPU gets the remaining data + + # Slice the data + dataset.prompts = dataset.prompts[start_index:end_index] + dataset.video_paths = dataset.video_paths[start_index:end_index] + else: + pass + + rank_dataset_size = len(dataset) + + # 2. Dataloader + def collate_fn(data): + prompts = [x["prompt"] for x in data[0]] + + images = None + if args.save_image_latents: + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + dataloader = DataLoader( + dataset, + batch_size=1, + sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # 3. Prepare models + device = f"cuda:{rank}" + + if args.save_latents_and_embeddings: + tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.to(device) + + vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype) + vae = vae.to(device) + + if args.use_slicing: + vae.enable_slicing() + if args.use_tiling: + vae.enable_tiling() + + # 4. Compute latents and embeddings and save + if rank == 0: + iterator = tqdm( + dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size + ) + else: + iterator = dataloader + + for step, batch in enumerate(iterator): + try: + images = None + image_latents = None + video_latents = None + prompt_embeds = None + + if args.save_image_latents: + images = batch["images"].to(device, non_blocking=True) + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + videos = batch["videos"].to(device, non_blocking=True) + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + prompts = batch["prompts"] + + # Encode videos & images + if args.save_latents_and_embeddings: + if args.use_slicing: + if args.save_image_latents: + encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)] + image_latents = torch.cat(encoded_slices) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)] + video_latents = torch.cat(encoded_slices) + + else: + if args.save_image_latents: + image_latents = vae._encode(images) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = vae._encode(videos) + + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + # Encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + args.max_sequence_length, + device, + weight_dtype, + requires_grad=False, + ) + + if images is not None: + images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 + + videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2 + + output_queue.put( + { + "batch_size": len(prompts), + "fps": target_fps, + "images_dir": images_dir, + "image_latents_dir": image_latents_dir, + "videos_dir": videos_dir, + "video_latents_dir": video_latents_dir, + "prompts_dir": prompts_dir, + "prompt_embeds_dir": prompt_embeds_dir, + "images": images, + "image_latents": image_latents, + "videos": videos, + "video_latents": video_latents, + "prompts": prompts, + "prompt_embeds": prompt_embeds, + } + ) + + except Exception: + print("-------------------------") + print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}") + traceback.print_exc() + print("-------------------------") + + # 5. Complete distributed processing + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + output_queue.put(None) + save_thread.shutdown(wait=True) + save_future.result() + + # 6. Combine results from each rank + if rank == 0: + print( + f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`" + ) + + # Move files from each rank to common directory + for subfolder, extension in [ + ("images", "png"), + ("image_latents", "pt"), + ("videos", "mp4"), + ("video_latents", "pt"), + ("prompts", "txt"), + ("prompt_embeds", "pt"), + ("videos", "txt"), + ]: + tmp_subfolder = tmp_dir.joinpath(subfolder) + combined_subfolder = output_dir.joinpath(subfolder) + combined_subfolder.mkdir(parents=True, exist_ok=True) + pattern = f"*.{extension}" + + for file in tmp_subfolder.rglob(pattern): + file.replace(combined_subfolder / file.name) + + # Remove temporary directories + def rmdir_recursive(dir: pathlib.Path) -> None: + for child in dir.iterdir(): + if child.is_file(): + child.unlink() + else: + rmdir_recursive(child) + dir.rmdir() + + rmdir_recursive(tmp_dir) + + # Combine prompts and videos into individual text files and single jsonl + prompts_folder = output_dir.joinpath("prompts") + prompts = [] + stems = [] + + for filename in prompts_folder.rglob("*.txt"): + with open(filename, "r") as file: + prompts.append(file.read().strip()) + stems.append(filename.stem) + + prompts_txt = output_dir.joinpath("prompts.txt") + videos_txt = output_dir.joinpath("videos.txt") + data_jsonl = output_dir.joinpath("data.jsonl") + + with open(prompts_txt, "w") as file: + for prompt in prompts: + file.write(f"{prompt}\n") + + with open(videos_txt, "w") as file: + for stem in stems: + file.write(f"videos/{stem}.mp4\n") + + with open(data_jsonl, "w") as file: + for prompt, stem in zip(prompts, stems): + video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt") + with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file: + metadata = json.loads(metadata_file.read()) + + data = { + "prompt": prompt, + "prompt_embed": f"prompt_embeds/{stem}.pt", + "image": f"images/{stem}.png", + "image_latent": f"image_latents/{stem}.pt", + "video": f"videos/{stem}.mp4", + "video_latent": f"video_latents/{stem}.pt", + "metadata": metadata, + } + file.write(json.dumps(data) + "\n") + + print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`") + + +if __name__ == "__main__": + main() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09f9e8cd8466c60a9c266879223f8fe4f304a524 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py @@ -0,0 +1 @@ +from .text_encoder import compute_prompt_embeddings diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9237875dd621baca5979a5f685288486ef1532d5 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py @@ -0,0 +1,99 @@ +from typing import List, Optional, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str, + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool = False, +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/utils.py b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d03b14217403908233405a2005acf2f8703431c3 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/utils.py @@ -0,0 +1,260 @@ +import gc +import inspect +from typing import Optional, Tuple, Union + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.utils.torch_utils import is_compiled_module + + +logger = get_logger(__name__) + + +def get_optimizer( + params_to_optimize, + optimizer_name: str = "adam", + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.98, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + prodigy_decouple: bool = False, + prodigy_use_bias_correction: bool = False, + prodigy_safeguard_warmup: bool = False, + use_8bit: bool = False, + use_4bit: bool = False, + use_torchao: bool = False, + use_deepspeed: bool = False, + use_cpu_offload_optimizer: bool = False, + offload_gradients: bool = False, +) -> torch.optim.Optimizer: + optimizer_name = optimizer_name.lower() + + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=learning_rate, + betas=(beta1, beta2), + eps=epsilon, + weight_decay=weight_decay, + ) + + if use_8bit and use_4bit: + raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") + + if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: + try: + import torchao + + torchao.__version__ + except ImportError: + raise ImportError( + "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." + ) + + if not use_torchao and use_4bit: + raise ValueError("4-bit Optimizers are only supported with torchao.") + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy", "came"] + if optimizer_name not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." + ) + optimizer_name = "adamw" + + if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: + raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + + if use_8bit: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if optimizer_name == "adamw": + if use_torchao: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + + optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + else: + optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "adam": + if use_torchao: + from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + + optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam + else: + optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + init_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "beta3": beta3, + "eps": epsilon, + "weight_decay": weight_decay, + "decouple": prodigy_decouple, + "use_bias_correction": prodigy_use_bias_correction, + "safeguard_warmup": prodigy_safeguard_warmup, + } + + elif optimizer_name == "came": + try: + import came_pytorch + except ImportError: + raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") + + optimizer_class = came_pytorch.CAME + + init_kwargs = { + "lr": learning_rate, + "eps": (1e-30, 1e-16), + "betas": (beta1, beta2, beta3), + "weight_decay": weight_decay, + } + + if use_cpu_offload_optimizer: + from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + + if "fused" in inspect.signature(optimizer_class.__init__).parameters: + init_kwargs.update({"fused": True}) + + optimizer = CPUOffloadOptimizer( + params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs + ) + else: + optimizer = optimizer_class(params_to_optimize, **init_kwargs) + + return optimizer + + +def get_gradient_norm(parameters): + norm = 0 + for param in parameters: + if param.grad is None: + continue + local_norm = param.grad.detach().data.norm(2) + norm += local_norm.item() ** 2 + norm = norm**0.5 + return norm + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + patch_size_t: int = None, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + if patch_size_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def reset_memory(device: Union[str, torch.device]) -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.reset_accumulated_memory_stats(device) + + +def print_memory(device: Union[str, torch.device]) -> None: + memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 + max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 + print(f"{memory_allocated=:.3f} GB") + print(f"{max_memory_allocated=:.3f} GB") + print(f"{max_memory_reserved=:.3f} GB") + + +def unwrap_model(accelerator: Accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/README.md b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..13d0c391a58597aaf34a6b5579515420c74893c0 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/README.md @@ -0,0 +1,111 @@ +# Simple Mochi-1 finetuner + + + + + + + + + + +
Dataset Sample Test Sample
+ +Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨 + +We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation. + +**Updates** + +December 1 2024: Support for checkpoint saving and loading. + +## Getting started + +Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`. + +Download a demo dataset: + +```bash +huggingface-cli download \ + --repo-type dataset sayakpaul/video-dataset-disney-organized \ + --local-dir video-dataset-disney-organized +``` + +The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions): + +```bash +video_1.mp4 +video_1.txt -- One-paragraph description of video_1 +video_2.mp4 +video_2.txt -- One-paragraph description of video_2 +... +``` + +Then run (be sure to check the paths accordingly): + +```bash +bash prepare_dataset.sh +``` + +We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video. + +> [!IMPORTANT] +> It's important to have a resolution of at least 480x848 to satisy Mochi-1's requirements. + +Now, we're ready to fine-tune. To launch, run: + +```bash +bash train.sh +``` + +You can disable intermediate validation by: + +```diff +- --validation_prompt "..." \ +- --validation_prompt_separator ::: \ +- --num_validation_videos 1 \ +- --validation_epochs 1 \ +``` + +We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM. + +To use the LoRA checkpoint: + +```py +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("path-to-lora") +pipe.enable_model_cpu_offload() + +pipeline_args = { + "prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions", + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": 480, + "width": 848, + "max_sequence_length": 256, + "output_type": "np", +} + +with torch.autocast("cuda", torch.bfloat16) + video = pipe(**pipeline_args).frames[0] +export_to_video(video) +``` + +## Known limitations + +(Contributions are welcome 🤗) + +Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below: + +* No support for distributed training. +* `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support. +* No support for 8bit optimizers (but should be relatively easy to add). + +**Misc**: + +* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). +* `embed.py` script is non-batched. diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/args.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/args.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb420e73540dee050a65a0bb62c5c8c77e8d9b0 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/args.py @@ -0,0 +1,268 @@ +""" +Default values taken from +https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml +when applicable. +""" + +import argparse + + +def _get_model_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--cast_dit", + action="store_true", + help="If we should cast DiT params to a lower precision.", + ) + parser.add_argument( + "--compile_dit", + action="store_true", + help="If we should compile the DiT.", + ) + + +def _get_dataset_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--caption_dropout", + type=float, + default=None, + help=("Probability to drop out captions randomly."), + ) + + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _get_validation_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + default=False, + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="FPS to use when serializing the output videos.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + ) + parser.add_argument( + "--width", + type=int, + default=848, + ) + + +def _get_training_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=16, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--target_modules", + nargs="+", + type=str, + default=["to_k", "to_q", "to_v", "to_out.0"], + help="Target modules to train LoRA for.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="mochi-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=200, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=None, + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + ) + + +def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.01, + help="Weight decay to use for optimizer.", + ) + + +def _get_configuration_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.") + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.") + + _get_model_args(parser) + _get_dataset_args(parser) + _get_training_args(parser) + _get_validation_args(parser) + _get_optimizer_args(parser) + _get_configuration_args(parser) + + return parser.parse_args() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/dataset_simple.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/dataset_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc6153be09a6d4b18b0edb482897e53cab7411d --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/dataset_simple.py @@ -0,0 +1,50 @@ +""" +Taken from +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py +""" + +from pathlib import Path + +import click +import torch +from torch.utils.data import DataLoader, Dataset + + +def load_to_cpu(x): + return torch.load(x, map_location=torch.device("cpu"), weights_only=True) + + +class LatentEmbedDataset(Dataset): + def __init__(self, file_paths, repeat=1): + self.items = [ + (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt")) + for p in file_paths + if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file() + ] + self.items = self.items * repeat + print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.") + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + latent_path, embed_path = self.items[idx] + return load_to_cpu(latent_path), load_to_cpu(embed_path) + + +@click.command() +@click.argument("directory", type=click.Path(exists=True, file_okay=False)) +def process_videos(directory): + dir_path = Path(directory) + mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")] + assert mp4_files, f"No mp4 files found" + + dataset = LatentEmbedDataset(mp4_files) + dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + + for latents, embeds in dataloader: + print([(k, v.shape) for k, v in latents.items()]) + + +if __name__ == "__main__": + process_videos() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/embed.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ec35ebb061618e133f093f459df525a3cf4567b3 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/embed.py @@ -0,0 +1,111 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py +""" + +import click +import torch +import torchvision +from pathlib import Path +from diffusers import AutoencoderKLMochi, MochiPipeline +from transformers import T5EncoderModel, T5Tokenizer +from tqdm.auto import tqdm + + +def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str): + T, H, W = [int(s) for s in shape.split("x")] + assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6" + video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs") + fps = metadata["video_fps"] + video = video.permute(3, 0, 1, 2) + og_shape = video.shape + assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}" + assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}" + assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}" + if video.shape[1] > T: + video = video[:, :T] + print(f"Trimmed video from {og_shape[1]} to first {T} frames") + video = video.unsqueeze(0) + video = video.float() / 127.5 - 1.0 + video = video.to(model.device) + + assert video.ndim == 5 + + with torch.inference_mode(): + with torch.autocast("cuda", dtype=torch.bfloat16): + ldist = model._encode(video) + + torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt")) + + +@click.command() +@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path)) +@click.option( + "--model_id", + type=str, + help="Repo id. Should be genmo/mochi-1-preview", + default="genmo/mochi-1-preview", +) +@click.option("--shape", default="163x480x848", help="Shape of the video to encode") +@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.") +def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None: + """Process all videos and captions in a directory using a single GPU.""" + # comment out when running on unsupported hardware + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Get all video paths + video_paths = list(output_dir.glob("**/*.mp4")) + if not video_paths: + print(f"No MP4 files found in {output_dir}") + return + + text_paths = list(output_dir.glob("**/*.txt")) + if not text_paths: + print(f"No text files found in {output_dir}") + return + + # load the models + vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda") + text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder") + tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer") + pipeline = MochiPipeline.from_pretrained( + model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None + ).to("cuda") + + for idx, video_path in tqdm(enumerate(sorted(video_paths))): + print(f"Processing {video_path}") + try: + if video_path.with_suffix(".latent.pt").exists() and not overwrite: + print(f"Skipping {video_path}") + continue + + # encode videos. + encode_videos(vae, vid_path=video_path, shape=shape) + + # embed captions. + prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt") + embed_path = prompt_path.with_suffix(".embed.pt") + + if embed_path.exists() and not overwrite: + print(f"Skipping {prompt_path} - embeddings already exist") + continue + + with open(prompt_path) as f: + text = f.read().strip() + with torch.inference_mode(): + conditioning = pipeline.encode_prompt(prompt=[text]) + + conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]} + torch.save(conditioning, embed_path) + + except Exception as e: + import traceback + + traceback.print_exc() + print(f"Error processing {video_path}: {str(e)}") + + +if __name__ == "__main__": + batch_process() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/prepare_dataset.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/prepare_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..c424b4e5913cbfc64172951573ec1d7eb578b5d5 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/prepare_dataset.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +GPU_ID=0 +VIDEO_DIR=video-dataset-disney-organized +OUTPUT_DIR=videos_prepared +NUM_FRAMES=37 +RESOLUTION=480x848 + +# Extract width and height from RESOLUTION +WIDTH=$(echo $RESOLUTION | cut -dx -f1) +HEIGHT=$(echo $RESOLUTION | cut -dx -f2) + +python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample + +CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT} diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/requirements.txt b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a03ceeb0ee286d80009b9e2dc39d801d18603e4c --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/requirements.txt @@ -0,0 +1,8 @@ +peft +transformers +wandb +torch +torchvision +av==11.0.0 +moviepy==1.0.3 +click \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/text_to_video_lora.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/text_to_video_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..af1ce6268dbc8cc88d2b02d1b292c285ceb43b96 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/text_to_video_lora.py @@ -0,0 +1,592 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +from glob import glob +import math +import os +import torch.nn.functional as F +import numpy as np +from pathlib import Path +from typing import Any, Dict, Tuple, List + +import torch +import wandb +from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + + +from args import get_args # isort:skip +from dataset_simple import LatentEmbedDataset + +import sys +from utils import print_memory, reset_memory # isort:skip + + +# Taken from +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139 +def get_cosine_annealing_lr_scheduler( + optimizer: torch.optim.Optimizer, + warmup_steps: int, + total_steps: int, +): + def lr_lambda(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + else: + return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps))) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=30, +): + widget_dict = [] + if videos is not None and len(videos) > 0: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"final_video_{i}.mp4"}, + } + ) + + model_description = f""" +# Mochi-1 Preview LoRA Finetune + + + +## Model description + +This is a lora finetune of the Mochi-1 preview model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("CHANGE_ME") +pipe.enable_model_cpu_offload() + +with torch.autocast("cuda", torch.bfloat16): + video = pipe( + prompt="CHANGE_ME", + guidance_scale=6.0, + num_inference_steps=64, + height=480, + width=848, + max_sequence_length=256, + output_type="np" + ).frames[0] +export_to_video(video) +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="apache-2.0", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "mochi-1-preview", + "mochi-1-preview-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe: MochiPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + wandb_run: str = None, + is_final_validation: bool = False, +): + print( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + phase_name = "test" if is_final_validation else "validation" + + if not args.enable_model_cpu_offload: + pipe = pipe.to("cuda") + + # run inference + generator = torch.manual_seed(args.seed) if args.seed else None + + videos = [] + with torch.autocast("cuda", torch.bfloat16, cache_enabled=False): + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=30) + video_filenames.append(filename) + + if wandb_run: + wandb.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +# Adapted from the original code: +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578 +def cast_dit(model, dtype): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert any( + n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"] + ), f"Unexpected linear layer: {name}" + module.to(dtype=dtype) + elif isinstance(module, torch.nn.Conv2d): + module.to(dtype=dtype) + return model + + +def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path): + lora_state_dict = get_peft_model_state_dict(model) + torch.save( + { + "state_dict": lora_state_dict, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "global_step": global_step, + }, + checkpoint_path, + ) + + +class CollateFunction: + def __init__(self, caption_dropout: float = None) -> None: + self.caption_dropout = caption_dropout + + def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]: + ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0) + z = DiagonalGaussianDistribution(ldists).sample() + assert torch.isfinite(z).all() + + # Sample noise which we will add to the samples. + eps = torch.randn_like(z) + sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32) + + prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0) + prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0) + if self.caption_dropout and random.random() < self.caption_dropout: + prompt_embeds.zero_() + prompt_attention_mask = prompt_attention_mask.long() + prompt_attention_mask.zero_() + prompt_attention_mask = prompt_attention_mask.bool() + + return dict( + z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask + ) + + +def main(args): + if not torch.cuda.is_available(): + raise ValueError("Not supported without CUDA.") + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + # Handle the repository creation + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + transformer = MochiTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + + transformer.requires_grad_(False) + transformer.to("cuda") + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.cast_dit: + transformer = cast_dit(transformer, torch.bfloat16) + if args.compile_dit: + transformer.compile() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights="gaussian", + target_modules=args.target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.train_batch_size + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + # Prepare optimizer + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters) + optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay) + + # Dataset and DataLoader + train_vids = list(sorted(glob(f"{args.data_root}/*.mp4"))) + train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")] + print(f"Found {len(train_vids)} training videos in {args.data_root}") + assert len(train_vids) > 0, f"No training data found in {args.data_root}" + + collate_fn = CollateFunction(caption_dropout=args.caption_dropout) + train_dataset = LatentEmbedDataset(train_vids, repeat=1) + train_dataloader = DataLoader( + train_dataset, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # LR scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = len(train_dataloader) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_cosine_annealing_lr_scheduler( + optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = len(train_dataloader) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + wandb_run = None + if args.report_to == "wandb": + tracker_name = args.tracker_name or "mochi-1-lora" + wandb_run = wandb.init(project=tracker_name, config=vars(args)) + + # Resume from checkpoint if specified + if args.resume_from_checkpoint: + checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu", weights_only=True) + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + if "optimizer" in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + if "lr_scheduler" in checkpoint: + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + set_peft_model_state_dict(transformer, checkpoint["state_dict"]) + + print(f"Resuming from checkpoint: {args.resume_from_checkpoint}") + print(f"Resuming from global step: {global_step}") + else: + global_step = 0 + + print("===== Memory before training =====") + reset_memory("cuda") + print_memory("cuda") + + # Train! + total_batch_size = args.train_batch_size + print("***** Running training *****") + print(f" Num trainable parameters = {num_trainable_parameters}") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num epochs = {args.num_train_epochs}") + print(f" Instantaneous batch size per device = {args.train_batch_size}") + print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + print(f" Total optimization steps = {args.max_train_steps}") + + first_epoch = 0 + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + ) + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + with torch.no_grad(): + z = batch["z"].to("cuda") + eps = batch["eps"].to("cuda") + sigma = batch["sigma"].to("cuda") + prompt_embeds = batch["prompt_embeds"].to("cuda") + prompt_attention_mask = batch["prompt_attention_mask"].to("cuda") + + sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1] + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps + ut = z - eps + + # (1 - sigma) because of + # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656 + # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation. + timesteps = (1 - sigma) * scheduler.config.num_train_timesteps + + with torch.autocast("cuda", torch.bfloat16): + model_pred = transformer( + hidden_states=z_sigma, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + assert model_pred.shape == z.shape + loss = F.mse_loss(model_pred.float(), ut.float()) + loss.backward() + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + progress_bar.update(1) + global_step += 1 + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs = {"loss": loss.detach().item(), "lr": last_lr} + progress_bar.set_postfix(**logs) + if wandb_run: + wandb_run.log(logs, step=global_step) + + if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0: + print(f"Saving checkpoint at step {global_step}") + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt") + save_checkpoint( + transformer, + optimizer, + lr_scheduler, + global_step, + checkpoint_path, + ) + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + print("===== Memory before validation =====") + print_memory("cuda") + + transformer.eval() + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + log_validation( + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + wandb_run=wandb_run, + ) + + print("===== Memory after validation =====") + print_memory("cuda") + reset_memory("cuda") + + del pipe.text_encoder + del pipe.vae + del pipe + gc.collect() + torch.cuda.empty_cache() + + transformer.train() + + transformer.eval() + transformer_lora_layers = get_peft_model_state_dict(transformer) + MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) + + # Cleanup trained models to save memory + del transformer + + gc.collect() + torch.cuda.empty_cache() + + # Final test inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + print("===== Memory before testing =====") + print_memory("cuda") + reset_memory("cuda") + + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") + pipe.set_adapters(["mochi-lora"], [lora_scaling]) + + # Run inference + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + + video = log_validation( + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + wandb_run=wandb_run, + is_final_validation=True, + ) + validation_outputs.extend(video) + + print("===== Memory after testing =====") + print_memory("cuda") + reset_memory("cuda") + torch.cuda.synchronize("cuda") + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["*.bin"], + ) + print(f"Params pushed to {repo_id}.") + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/train.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..2c378e2e1b7c8bc262ce74dbea01e8b6ed4994e4 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/train.sh @@ -0,0 +1,37 @@ +#!/bin/bash +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +DATA_ROOT="videos_prepared" +MODEL="genmo/mochi-1-preview" +OUTPUT_PATH="mochi-lora" + +cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \ + --pretrained_model_name_or_path $MODEL \ + --cast_dit \ + --data_root $DATA_ROOT \ + --seed 42 \ + --output_dir $OUTPUT_PATH \ + --train_batch_size 1 \ + --dataloader_num_workers 4 \ + --pin_memory \ + --caption_dropout 0.1 \ + --max_train_steps 2000 \ + --gradient_checkpointing \ + --enable_slicing \ + --enable_tiling \ + --enable_model_cpu_offload \ + --optimizer adamw \ + --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --allow_tf32 \ + --report_to wandb \ + --push_to_hub" + +echo "Running command: $cmd" +eval $cmd +echo -ne "-------------------- Finished executing script --------------------\n\n" \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/trim_and_crop_videos.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/trim_and_crop_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6f411d9e5fc6133232405f38b0ec5d7a627765 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/trim_and_crop_videos.py @@ -0,0 +1,126 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py +""" + +from pathlib import Path +import shutil + +import click +from moviepy.editor import VideoFileClip +from tqdm import tqdm + + +@click.command() +@click.argument("folder", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_folder", type=click.Path(dir_okay=True)) +@click.option("--num_frames", "-f", type=float, default=30, help="Number of frames") +@click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution") +@click.option("--force_upsample", is_flag=True, help="Force upsample.") +def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample): + """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution""" + input_path = Path(folder) + output_path = Path(output_folder) + output_path.mkdir(parents=True, exist_ok=True) + + # Parse target resolution + target_height, target_width = map(int, resolution.split("x")) + + # Calculate duration + duration = (num_frames / 30) + 0.09 + + # Find all MP4 and MOV files + video_files = ( + list(input_path.rglob("*.mp4")) + + list(input_path.rglob("*.MOV")) + + list(input_path.rglob("*.mov")) + + list(input_path.rglob("*.MP4")) + ) + + for file_path in tqdm(video_files): + try: + relative_path = file_path.relative_to(input_path) + output_file = output_path / relative_path.with_suffix(".mp4") + output_file.parent.mkdir(parents=True, exist_ok=True) + + click.echo(f"Processing: {file_path}") + video = VideoFileClip(str(file_path)) + + # Skip if video is too short + if video.duration < duration: + click.echo(f"Skipping {file_path} as it is too short") + continue + + # Skip if target resolution is larger than input + if target_width > video.w or target_height > video.h: + if force_upsample: + click.echo( + f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video." + ) + video = video.resize(width=target_width, height=target_height) + else: + click.echo( + f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}" + ) + continue + + # First truncate duration + truncated = video.subclip(0, duration) + + # Calculate crop dimensions to maintain aspect ratio + target_ratio = target_width / target_height + current_ratio = truncated.w / truncated.h + + if current_ratio > target_ratio: + # Video is wider than target ratio - crop width + new_width = int(truncated.h * target_ratio) + x1 = (truncated.w - new_width) // 2 + final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height)) + else: + # Video is taller than target ratio - crop height + new_height = int(truncated.w / target_ratio) + y1 = (truncated.h - new_height) // 2 + final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height)) + + # Set output parameters for consistent MP4 encoding + output_params = { + "codec": "libx264", + "audio": False, # Disable audio + "preset": "medium", # Balance between speed and quality + "bitrate": "5000k", # Adjust as needed + } + + # Set FPS to 30 + final = final.set_fps(30) + + # Check for a corresponding .txt file + txt_file_path = file_path.with_suffix(".txt") + if txt_file_path.exists(): + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(txt_file_path, output_txt_file) + click.echo(f"Copied {txt_file_path} to {output_txt_file}") + else: + # Print warning in bold yellow with a warning emoji + click.echo( + f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m" + ) + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + output_txt_file.touch() + + # Write the output file + final.write_videofile(str(output_file), **output_params) + + # Clean up + video.close() + truncated.close() + final.close() + + except Exception as e: + click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True) + raise + + +if __name__ == "__main__": + truncate_videos() diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/utils.py b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76fe35c2036da2483aa431d354226ccb1c16b9bc --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/mochi-1/utils.py @@ -0,0 +1,22 @@ +import gc +import inspect +from typing import Optional, Tuple, Union + +import torch + +logger = get_logger(__name__) + +def reset_memory(device: Union[str, torch.device]) -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.reset_accumulated_memory_stats(device) + + +def print_memory(device: Union[str, torch.device]) -> None: + memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 + max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 + print(f"{memory_allocated=:.3f} GB") + print(f"{max_memory_allocated=:.3f} GB") + print(f"{max_memory_reserved=:.3f} GB") diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/prepare_dataset.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/prepare_dataset.sh new file mode 100755 index 0000000000000000000000000000000000000000..304786d309834e54f39d290af6eba770b30cdc03 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/prepare_dataset.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +MODEL_ID="THUDM/CogVideoX-2b" + +NUM_GPUS=8 + +# For more details on the expected data format, please refer to the README. +DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located. +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset" +HEIGHT_BUCKETS="480 720" +WIDTH_BUCKETS="720 960" +FRAME_BUCKETS="49" +MAX_NUM_FRAMES="49" +MAX_SEQUENCE_LENGTH=226 +TARGET_FPS=8 +BATCH_SIZE=1 +DTYPE=fp32 + +# To create a folder-style dataset structure without pre-encoding videos and captions +# For Image-to-Video finetuning, make sure to pass `--save_image_latents` +CMD_WITHOUT_PRE_ENCODING="\ + torchrun --nproc_per_node=$NUM_GPUS \ + training/prepare_dataset.py \ + --model_id $MODEL_ID \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --output_dir $OUTPUT_DIR \ + --height_buckets $HEIGHT_BUCKETS \ + --width_buckets $WIDTH_BUCKETS \ + --frame_buckets $FRAME_BUCKETS \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_sequence_length $MAX_SEQUENCE_LENGTH \ + --target_fps $TARGET_FPS \ + --batch_size $BATCH_SIZE \ + --dtype $DTYPE +" + +CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings" + +# Select which you'd like to run +CMD=$CMD_WITH_PRE_ENCODING + +echo "===== Running \`$CMD\` =====" +eval $CMD +echo -ne "===== Finished running script =====\n" diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_lora.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_lora.sh new file mode 100755 index 0000000000000000000000000000000000000000..8ff0111a88018fba1fea322e3c91feb4fc23516b --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_lora.sh @@ -0,0 +1,82 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4" "1e-3") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw" "adam") +MAX_TRAIN_STEPS=("3000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/disney-dataset" +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" + +# Launch experiments with different hyperparameters +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_image_to_video_lora.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b-I2V \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_images \"/path/to/image1.png:::/path/to/image2.png\" + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --noised_image_dropout 0.05 \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_sft.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_sft.sh new file mode 100755 index 0000000000000000000000000000000000000000..7cdbf338908dea16e0099bad3b7c124feba50678 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/train_image_to_video_sft.sh @@ -0,0 +1,87 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +# export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export TOKENIZERS_PARALLELISM=true +export OMP_NUM_THREADS=16 + +GPU_IDS="0,1" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw") +MAX_TRAIN_STEPS=("20000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/deepspeed.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/video-dataset-disney" +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +MODEL_PATH="THUDM/CogVideoX1.5-5B-I2V" + +# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process. +# Launch experiments with different hyperparameters + +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="./cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE \ + --gpu_ids $GPU_IDS \ + training/cogvideox/cogvideox_image_to_video_sft.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 77 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_images \"/path/to/image1.png:::/path/to/image2.png\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --seed 42 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 77 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 2000 \ + --gradient_accumulation_steps 4 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 800 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --noised_image_dropout 0.05 \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_lora.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_lora.sh new file mode 100755 index 0000000000000000000000000000000000000000..e7239f56242108023280ed9533e731297edf216d --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_lora.sh @@ -0,0 +1,86 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4" "1e-3") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw" "adam") +MAX_TRAIN_STEPS=("3000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset +DATA_ROOT="/path/to/my/datasets/disney-dataset" + +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +MODEL_PATH="THUDM/CogVideoX-5b" + +# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process. +# Launch experiments with different hyperparameters + +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="./cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 10 \ + --seed 42 \ + --rank 128 \ + --lora_alpha 128 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 1000 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 400 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --enable_model_cpu_offload \ + --load_tensors \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_sft.sh b/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_sft.sh new file mode 100755 index 0000000000000000000000000000000000000000..b4de76caa4fa035959451f440e5757e33a88c9f6 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/_legacy/training/train_text_to_video_sft.sh @@ -0,0 +1,77 @@ +export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="0" + +# Training Configurations +# Experiment with as many hyperparameters as you want! +LEARNING_RATES=("1e-4") +LR_SCHEDULES=("cosine_with_restarts") +OPTIMIZERS=("adamw") +MAX_TRAIN_STEPS=("20000") + +# Single GPU uncompiled training +ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + +# Absolute path to where the data is located. Make sure to have read the README for how to prepare data. +# This example assumes you downloaded an already prepared dataset from HF CLI as follows: +# huggingface-cli download --repo-type dataset Wild-Heart/Tom-and-Jerry-VideoGeneration-Dataset --local-dir /path/to/my/datasets/tom-and-jerry-dataset +DATA_ROOT="/path/to/my/datasets/tom-and-jerry-dataset" +CAPTION_COLUMN="captions.txt" +VIDEO_COLUMN="videos.txt" + +# Launch experiments with different hyperparameters +for learning_rate in "${LEARNING_RATES[@]}"; do + for lr_schedule in "${LR_SCHEDULES[@]}"; do + for optimizer in "${OPTIMIZERS[@]}"; do + for steps in "${MAX_TRAIN_STEPS[@]}"; do + output_dir="/path/to/my/models/cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/" + + cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_sft.py \ + --pretrained_model_name_or_path THUDM/CogVideoX-5b \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --height_buckets 480 \ + --width_buckets 720 \ + --frame_buckets 49 \ + --dataloader_num_workers 8 \ + --pin_memory \ + --validation_prompt \"Tom, the mischievous gray cat, is sprawled out on a vibrant red pillow, his body relaxed and his eyes half-closed, as if he's just woken up or is about to doze off. His white paws are stretched out in front of him, and his tail is casually draped over the edge of the pillow. The setting appears to be a cozy corner of a room, with a warm yellow wall in the background and a hint of a wooden floor. The scene captures a rare moment of tranquility for Tom, contrasting with his usual energetic and playful demeanor:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --seed 42 \ + --mixed_precision bf16 \ + --output_dir $output_dir \ + --max_num_frames 49 \ + --train_batch_size 1 \ + --max_train_steps $steps \ + --checkpointing_steps 2000 \ + --gradient_accumulation_steps 4 \ + --gradient_checkpointing \ + --learning_rate $learning_rate \ + --lr_scheduler $lr_schedule \ + --lr_warmup_steps 800 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer $optimizer \ + --beta1 0.9 \ + --beta2 0.95 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --nccl_timeout 1800" + + echo "Running command: $cmd" + eval $cmd + echo -ne "-------------------- Finished executing script --------------------\n\n" + done + done + done +done diff --git a/docs/finetrainers-src-codebase/examples/formats/hunyuan_video/convert_to_original_format.py b/docs/finetrainers-src-codebase/examples/formats/hunyuan_video/convert_to_original_format.py new file mode 100644 index 0000000000000000000000000000000000000000..c776fe16c6944c7971766eca5f6bbf5b7e3b1e12 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/formats/hunyuan_video/convert_to_original_format.py @@ -0,0 +1,163 @@ +import argparse +import os + +import torch +from safetensors.torch import load_file, save_file + + +def convert_lora_sd(diffusers_lora_sd): + double_block_patterns = { + "attn.to_out.0": "img_attn_proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "attn.to_add_out": "txt_attn_proj", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + } + + prefix = "diffusion_model." + + double_block_pattern = "transformer.transformer_blocks" + single_block_pattern = "transformer.single_transformer_blocks" + + converted_lora_sd = {} + for key in diffusers_lora_sd.keys(): + # double_blocks + if key.startswith(double_block_pattern): + # img_attn + if key.endswith("to_q.lora_A.weight"): + # lora_A + to_q_A = diffusers_lora_sd[key] + to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")] + to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")] + + to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0) + qkv_A_key = key.replace(double_block_pattern, prefix + "double_blocks").replace( + "attn.to_q", "img_attn_qkv" + ) + converted_lora_sd[qkv_A_key] = to_qkv_A + + # lora_B + to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")] + to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")] + to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")] + + to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B) + qkv_B_key = qkv_A_key.replace("lora_A", "lora_B") + converted_lora_sd[qkv_B_key] = to_qkv_B + + # txt_attn + elif key.endswith("add_q_proj.lora_A.weight"): + # lora_A + to_q_A = diffusers_lora_sd[key] + to_k_A = diffusers_lora_sd[key.replace("add_q_proj", "add_k_proj")] + to_v_A = diffusers_lora_sd[key.replace("add_q_proj", "add_v_proj")] + + to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0) + qkv_A_key = key.replace(double_block_pattern, prefix + "double_blocks").replace( + "attn.add_q_proj", "txt_attn_qkv" + ) + converted_lora_sd[qkv_A_key] = to_qkv_A + + # lora_B + to_q_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_q_proj.lora_B")] + to_k_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_k_proj.lora_B")] + to_v_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_v_proj.lora_B")] + + to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B) + qkv_B_key = qkv_A_key.replace("lora_A", "lora_B") + converted_lora_sd[qkv_B_key] = to_qkv_B + + # just rename + for k, v in double_block_patterns.items(): + if k in key: + new_key = key.replace(k, v).replace(double_block_pattern, prefix + "double_blocks") + converted_lora_sd[new_key] = diffusers_lora_sd[key] + + # single_blocks + elif key.startswith(single_block_pattern): + if key.endswith("to_q.lora_A.weight"): + # lora_A + to_q_A = diffusers_lora_sd[key] + to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")] + to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")] + proj_mlp_A_key = key.replace("attn.to_q", "proj_mlp") + if proj_mlp_A_key in diffusers_lora_sd: + proj_mlp_A = diffusers_lora_sd[proj_mlp_A_key] + else: + proj_mlp_A = torch.zeros((to_q_A.shape[0], to_q_A.shape[1])) + linear1_A = torch.cat([to_q_A, to_k_A, to_v_A, proj_mlp_A], dim=0) + linear1_A_key = key.replace(single_block_pattern, prefix + "single_blocks").replace( + "attn.to_q", "linear1" + ) + converted_lora_sd[linear1_A_key] = linear1_A + + # lora_B + to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")] + to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")] + to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")] + proj_mlp_B_key = key.replace("to_q.lora_A", "attn.to_q.lora_B") + if proj_mlp_B_key in diffusers_lora_sd: + proj_mlp_B = diffusers_lora_sd[proj_mlp_B_key] + else: + proj_mlp_B = torch.zeros((to_q_B.shape[0] * 4, to_q_B.shape[1])) + linear1_B = torch.block_diag(to_q_B, to_k_B, to_v_B, proj_mlp_B) + linear1_B_key = linear1_A_key.replace("lora_A", "lora_B") + converted_lora_sd[linear1_B_key] = linear1_B + + elif "proj_out" in key: + new_key = key.replace("proj_out", "linear2").replace(single_block_pattern, prefix + "single_blocks") + converted_lora_sd[new_key] = diffusers_lora_sd[key] + + else: + print(f"unknown or not implemented: {key}") + + return converted_lora_sd + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_lora", type=str, required=True, help="Path to LoRA .safetensors") + parser.add_argument("--alpha", type=float, default=None, help="Optional alpha value, defaults to rank") + parser.add_argument( + "--dtype", type=str, default=None, help="Optional dtype (bfloat16, float16, float32), defaults to input dtype" + ) + parser.add_argument("--debug", action="store_true", help="Print converted keys instead of saving") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + converted_lora_sd = convert_lora_sd(load_file(args.input_lora)) + + if args.alpha is not None: + for key in list(converted_lora_sd.keys()): + if "lora_A" in key: + alpha_name = key.replace(".lora_A.weight", ".alpha") + converted_lora_sd[alpha_name] = torch.tensor([args.alpha], dtype=converted_lora_sd[key].dtype) + + dtype = None + if args.dtype == "bfloat16": + dtype = torch.bfloat16 + elif args.dtype == "float16": + dtype = torch.float16 + elif args.dtype == "float32": + dtype = torch.float32 + + if dtype is not None: + dtype_min = torch.finfo(dtype).min + dtype_max = torch.finfo(dtype).max + for key in converted_lora_sd.keys(): + if converted_lora_sd[key].min() < dtype_min or converted_lora_sd[key].max() > dtype_max: + print(f"warning: {key} has values outside of {dtype} {dtype_min} {dtype_max} range") + converted_lora_sd[key] = converted_lora_sd[key].to(dtype) + + if args.debug: + for key in sorted(converted_lora_sd.keys()): + print(key, converted_lora_sd[key].shape, converted_lora_sd[key].dtype) + exit() + + output_path = os.path.splitext(args.input_lora)[0] + "_converted.safetensors" + save_file(converted_lora_sd, output_path) + print(f"saved to {output_path}") diff --git a/docs/finetrainers-src-codebase/examples/inference/cogvideox/cogvideox_text_to_video.sh b/docs/finetrainers-src-codebase/examples/inference/cogvideox/cogvideox_text_to_video.sh new file mode 100755 index 0000000000000000000000000000000000000000..5e6b20fcbc4698e0a41f70d5349f90f467507b8d --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/cogvideox/cogvideox_text_to_video.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +# export WANDB_MODE="offline" +export WANDB_MODE="disabled" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Download the validation dataset +if [ ! -d "examples/inference/datasets/openvid-1k-split-validation" ]; then + echo "Downloading validation dataset..." + huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/inference/datasets/openvid-1k-split-validation +else + echo "Validation dataset already exists. Skipping download." +fi + +BACKEND="ptd" + +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +DATASET_FILE="examples/inference/cogvideox/dummy_text_to_video.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" +CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" +# FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +# FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +# HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $CP_2 +) + +# Model arguments +model_cmd=( + --model_name cogvideox + --pretrained_model_name_or_path "THUDM/CogVideoX-5B" + --enable_slicing + --enable_tiling +) + +# Inference arguments +inference_cmd=( + --inference_type text_to_video + --dataset_file "$DATASET_FILE" +) + +# Attention provider arguments +attn_provider_cmd=( + --attn_provider sage +) + +# Torch config arguments +torch_config_cmd=( + --allow_tf32 + --float32_matmul_precision high +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --seed 31337 + --tracker_name "finetrainers-inference" + --output_dir "/raid/aryan/cogvideox-inference" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the inference script +export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + examples/inference/inference.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${inference_cmd[@]}" \ + "${attn_provider_cmd[@]}" \ + "${torch_config_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/inference/cogview4/cogview4_text_to_image.sh b/docs/finetrainers-src-codebase/examples/inference/cogview4/cogview4_text_to_image.sh new file mode 100755 index 0000000000000000000000000000000000000000..0ef0b9da2dd57d31a37d21dae9c94c93f20ce04e --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/cogview4/cogview4_text_to_image.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +# export WANDB_MODE="offline" +export WANDB_MODE="disabled" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +BACKEND="ptd" + +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +DATASET_FILE="examples/inference/cogview4/dummy_text_to_image.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" +CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" +# FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +# FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +# HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $CP_2 +) + +# Model arguments +model_cmd=( + --model_name "cogview4" + --pretrained_model_name_or_path "THUDM/CogView4-6B" + --enable_slicing + --enable_tiling +) + +# Inference arguments +inference_cmd=( + --inference_type text_to_image + --dataset_file "$DATASET_FILE" +) + +# Attention provider arguments +attn_provider_cmd=( + --attn_provider flash_varlen +) + +# Torch config arguments +torch_config_cmd=( + --allow_tf32 + --float32_matmul_precision high +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --seed 31337 + --tracker_name "finetrainers-inference" + --output_dir "/raid/aryan/cogview4-inference" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the inference script +export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + examples/inference/inference.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${inference_cmd[@]}" \ + "${attn_provider_cmd[@]}" \ + "${torch_config_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/inference/datasets/.gitignore b/docs/finetrainers-src-codebase/examples/inference/datasets/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3db22f8dec2f9a00d999aae1ce3c065c4f052c09 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/datasets/.gitignore @@ -0,0 +1 @@ +openvid-1k-split-validation diff --git a/docs/finetrainers-src-codebase/examples/inference/flux/flux_text_to_image.sh b/docs/finetrainers-src-codebase/examples/inference/flux/flux_text_to_image.sh new file mode 100755 index 0000000000000000000000000000000000000000..ff71bc30fd5c014c79dfd4269e3cda0ecd20ca65 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/flux/flux_text_to_image.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +# export WANDB_MODE="offline" +export WANDB_MODE="disabled" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +BACKEND="ptd" + +NUM_GPUS=4 +CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Check the JSON files for the expected JSON format +DATASET_FILE="examples/inference/flux/dummy_text_to_image.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" +CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" +# FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +# FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +# HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $CP_4 +) + +# Model arguments +model_cmd=( + --model_name "flux" + --pretrained_model_name_or_path "black-forest-labs/FLUX.1-dev" + --cache_dir /raid/.cache/huggingface + --enable_slicing + --enable_tiling +) + +# Inference arguments +inference_cmd=( + --inference_type text_to_image + --dataset_file "$DATASET_FILE" +) + +# Attention provider arguments +attn_provider_cmd=( + --attn_provider flash_varlen +) + +# Torch config arguments +torch_config_cmd=( + --allow_tf32 + --float32_matmul_precision high +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --seed 31337 + --tracker_name "finetrainers-inference" + --output_dir "/raid/aryan/flux-inference" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the inference script +export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + examples/inference/inference.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${inference_cmd[@]}" \ + "${attn_provider_cmd[@]}" \ + "${torch_config_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/inference/inference.py b/docs/finetrainers-src-codebase/examples/inference/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a66fa4cae270c70c6f0bd09fb3e6981e82f7e6 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/inference.py @@ -0,0 +1,854 @@ +import argparse +import json +import os +import time +import traceback +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import datasets.distributed +import torch +import wandb +from diffusers.hooks import HookRegistry, ModelHook +from diffusers.utils import export_to_video + +from finetrainers import data, get_logger, logging, parallel, patches, utils +from finetrainers.args import AttentionProviderInference +from finetrainers.config import ModelType +from finetrainers.models import ModelSpecification, attention_provider +from finetrainers.models.cogvideox import CogVideoXModelSpecification +from finetrainers.models.cogview4 import CogView4ModelSpecification +from finetrainers.models.flux import FluxModelSpecification +from finetrainers.models.wan import WanModelSpecification +from finetrainers.parallel import ParallelBackendEnum +from finetrainers.state import ParallelBackendType +from finetrainers.utils import ArgsConfigMixin + + +logger = get_logger() + + +def main(): + try: + import multiprocessing + + multiprocessing.set_start_method("fork") + except Exception as e: + logger.error( + f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. ' + f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n" + f"Error: {e}" + ) + + try: + args = BaseArgs() + args.parse_args() + + model_specification_cls = get_model_specifiction_cls(args.model_name, args.inference_type) + model_specification = model_specification_cls( + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + tokenizer_id=args.tokenizer_id, + tokenizer_2_id=args.tokenizer_2_id, + tokenizer_3_id=args.tokenizer_3_id, + text_encoder_id=args.text_encoder_id, + text_encoder_2_id=args.text_encoder_2_id, + text_encoder_3_id=args.text_encoder_3_id, + transformer_id=args.transformer_id, + vae_id=args.vae_id, + text_encoder_dtype=args.text_encoder_dtype, + text_encoder_2_dtype=args.text_encoder_2_dtype, + text_encoder_3_dtype=args.text_encoder_3_dtype, + transformer_dtype=args.transformer_dtype, + vae_dtype=args.vae_dtype, + revision=args.revision, + cache_dir=args.cache_dir, + ) + + inferencer = Inference(args, model_specification) + inferencer.run() + + except KeyboardInterrupt: + logger.info("Received keyboard interrupt. Exiting...") + except Exception as e: + logger.error(f"An error occurred during training: {e}") + logger.error(traceback.format_exc()) + + +class InferenceType(str, Enum): + TEXT_TO_IMAGE = "text_to_image" + TEXT_TO_VIDEO = "text_to_video" + IMAGE_TO_VIDEO = "image_to_video" + + +# We do a union because every ArgsConfigMixin registered to BaseArgs can be looked up using the `__getattribute__` override +BaseArgsType = Union[ + "BaseArgs", "ParallelArgs", "ModelArgs", "InferenceArgs", "AttentionProviderArgs", "TorchConfigArgs" +] + +_DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e5m2": torch.float8_e5m2, +} + + +SUPPORTED_MODEL_CONFIGS = { + ModelType.COGVIDEOX: { + InferenceType.TEXT_TO_VIDEO: CogVideoXModelSpecification, + }, + ModelType.COGVIEW4: { + InferenceType.TEXT_TO_IMAGE: CogView4ModelSpecification, + }, + ModelType.FLUX: { + InferenceType.TEXT_TO_IMAGE: FluxModelSpecification, + }, + ModelType.WAN: { + InferenceType.TEXT_TO_VIDEO: WanModelSpecification, + InferenceType.IMAGE_TO_VIDEO: WanModelSpecification, + }, +} + + +def get_model_specifiction_cls(model_name: str, inference_type: InferenceType) -> ModelSpecification: + """ + Get the model specification class for the given model name and inference type. + """ + if model_name not in SUPPORTED_MODEL_CONFIGS: + raise ValueError( + f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" + ) + if inference_type not in SUPPORTED_MODEL_CONFIGS[model_name]: + raise ValueError( + f"Inference type {inference_type} not supported for model {model_name}. Supported inference types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" + ) + return SUPPORTED_MODEL_CONFIGS[model_name][inference_type] + + +@dataclass +class State: + # Parallel state + parallel_backend: ParallelBackendType = None + + # Training state + generator: torch.Generator = None + + +class Inference: + def __init__(self, args: BaseArgsType, model_specification: ModelSpecification): + self.args = args + self.model_specification = model_specification + self.state = State() + + self.pipeline = None + self.dataset = None + self.dataloader = None + + self._init_distributed() + self._init_config_options() + + patches.perform_patches_for_inference(args, self.state.parallel_backend) + + def run(self) -> None: + try: + self._prepare_pipeline() + self._prepare_distributed() + self._prepare_dataset() + self._inference() + except Exception as e: + logger.error(f"Error during inference: {e}") + self.state.parallel_backend.destroy() + raise e + + def _prepare_pipeline(self) -> None: + logger.info("Initializing pipeline") + + transformer = self.model_specification.load_diffusion_models()["transformer"] + self.pipeline = self.model_specification.load_pipeline( + transformer=transformer, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=False, # TODO(aryan): handle model/sequential/group offloading + training=False, + ) + + def _prepare_distributed(self) -> None: + parallel_backend = self.state.parallel_backend + cp_mesh = parallel_backend.get_mesh("cp") if parallel_backend.context_parallel_enabled else None + + if parallel_backend.context_parallel_enabled: + cp_mesh = parallel_backend.get_mesh()["cp"] + parallel_backend.apply_context_parallel(self.pipeline.transformer, cp_mesh) + + registry = HookRegistry.check_if_exists_or_initialize(self.pipeline.transformer) + hook = AttentionProviderHook( + self.args.attn_provider, cp_mesh, self.args.cp_rotate_method, self.args.cp_reduce_precision + ) + registry.register_hook(hook, "attn_provider") + + self._maybe_torch_compile() + + self._init_logging() + self._init_trackers() + self._init_directories() + + def _prepare_dataset(self) -> None: + logger.info("Preparing dataset for inference") + parallel_backend = self.state.parallel_backend + + dp_mesh = None + if parallel_backend.data_replication_enabled: + dp_mesh = parallel_backend.get_mesh("dp_replicate") + if dp_mesh is not None: + local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + else: + local_rank, dp_world_size = 0, 1 + + dataset = data.ValidationDataset(self.args.dataset_file) + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size) + dataloader = data.DPDataLoader( + local_rank, + dataset, + batch_size=1, + num_workers=0, # TODO(aryan): handle dataloader_num_workers + collate_fn=lambda items: items, + ) + + self.dataset = dataset + self.dataloader = dataloader + + def _inference(self) -> None: + parallel_backend = self.state.parallel_backend + seed = self.args.seed if self.args.seed is not None else 0 + generator = torch.Generator(device=parallel_backend.device).manual_seed(seed) + + if parallel_backend._dp_degree > 1: + dp_mesh = parallel_backend.get_mesh("dp") + dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + else: + dp_mesh = None + dp_local_rank, dp_world_size = parallel_backend.local_rank, 1 + + self.pipeline.to(parallel_backend.device) + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before inference start: {json.dumps(memory_statistics, indent=4)}") + + data_iterator = iter(self.dataloader) + main_process_prompts_to_filenames = {} # Used to save model card + all_processes_artifacts = [] # Used to gather artifacts from all processes + + while True: + inference_data = next(data_iterator, None) + if inference_data is None: + break + + inference_data = inference_data[0] + with torch.inference_mode(): + inference_artifacts = self.model_specification.validation( + pipeline=self.pipeline, generator=generator, **inference_data + ) + + if dp_local_rank != 0: + continue + + PROMPT = inference_data["prompt"] + IMAGE = inference_data.get("image", None) + VIDEO = inference_data.get("video", None) + EXPORT_FPS = inference_data.get("export_fps", 30) + + # 2.1. If there are any initial images or videos, they will be logged to keep track of them as + # conditioning for generation. + prompt_filename = utils.string_to_filename(PROMPT)[:25] + artifacts = { + "input_image": data.ImageArtifact(value=IMAGE), + "input_video": data.VideoArtifact(value=VIDEO), + } + + # 2.2. Track the artifacts generated from inference + for i, inference_artifact in enumerate(inference_artifacts): + if inference_artifact.value is None: + continue + artifacts.update({f"artifact_{i}": inference_artifact}) + + # 2.3. Save the artifacts to the output directory and create appropriate logging objects + # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited. + for index, (key, artifact) in enumerate(list(artifacts.items())): + assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact)) + if artifact.value is None: + continue + + time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension + filename = f"inference-{rank}-{index}-{prompt_filename}-{time_}.{ext}" + output_filename = os.path.join(self.args.output_dir, filename) + + if parallel_backend.is_main_process and ext in ["mp4", "jpg", "jpeg", "png"]: + main_process_prompts_to_filenames[PROMPT] = filename + + if isinstance(artifact, data.ImageArtifact): + artifact.value.save(output_filename) + all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT)) + elif isinstance(artifact, data.VideoArtifact): + export_to_video(artifact.value, output_filename, fps=EXPORT_FPS) + all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT)) + + # 3. Cleanup & log artifacts + parallel_backend.wait_for_everyone() + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after inference end: {json.dumps(memory_statistics, indent=4)}") + + # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts. + all_artifacts = [None] * dp_world_size + if dp_world_size > 1: + torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts) + else: + all_artifacts = [all_processes_artifacts] + all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts] + + if parallel_backend.is_main_process: + tracker_key = "inference" + artifact_log_dict = {} + + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + if len(image_artifacts) > 0: + artifact_log_dict["images"] = image_artifacts + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + if len(video_artifacts) > 0: + artifact_log_dict["videos"] = video_artifacts + parallel_backend.log({tracker_key: artifact_log_dict}, step=0) + + parallel_backend.wait_for_everyone() + + def _init_distributed(self) -> None: + world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) + + # TODO(aryan): handle other backends + backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend) + self.state.parallel_backend = backend_cls( + world_size=world_size, + pp_degree=self.args.pp_degree, + dp_degree=self.args.dp_degree, + dp_shards=self.args.dp_shards, + cp_degree=self.args.cp_degree, + tp_degree=self.args.tp_degree, + backend="nccl", + timeout=self.args.init_timeout, + logging_dir=self.args.logging_dir, + output_dir=self.args.output_dir, + ) + + if self.args.seed is not None: + self.state.parallel_backend.enable_determinism(self.args.seed) + + def _init_logging(self) -> None: + logging._set_parallel_backend(self.state.parallel_backend) + logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process) + logger.info("Initialized Finetrainers") + + def _init_trackers(self) -> None: + # TODO(aryan): handle multiple trackers + trackers = [self.args.report_to] + experiment_name = self.args.tracker_name or "finetrainers-inference" + self.state.parallel_backend.initialize_trackers( + trackers, experiment_name=experiment_name, config=self.args.to_dict(), log_dir=self.args.logging_dir + ) + + def _init_directories(self) -> None: + if self.state.parallel_backend.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + + def _init_config_options(self) -> None: + # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision(self.args.float32_matmul_precision) + + def _maybe_torch_compile(self): + for model_name, compile_scope in zip(self.args.compile_modules, self.args.compile_scopes): + model = getattr(self.pipeline, model_name, None) + if model is not None: + logger.info(f"Applying torch.compile to '{model_name}' with scope '{compile_scope}'.") + compiled_model = utils.apply_compile(model, compile_scope) + setattr(self.pipeline, model_name, compiled_model) + + +class AttentionProviderHook(ModelHook): + def __init__( + self, + provider: str, + mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + rotate_method: str = "allgather", + reduce_precision: bool = False, + ): + super().__init__() + self.provider = provider + self.mesh = mesh + self.rotate_method = rotate_method + self.convert_to_fp32 = not reduce_precision + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + with attention_provider( + self.provider, mesh=self.mesh, convert_to_fp32=self.convert_to_fp32, rotate_method=self.rotate_method + ): + return self.fn_ref.original_forward(*args, **kwargs) + + +class ParallelArgs(ArgsConfigMixin): + """ + Args: + parallel_backend (`str`, defaults to "accelerate"): + The parallel backend to use for inference. Choose between ['accelerate', 'ptd']. + pp_degree (`int`, defaults to 1): + The degree of pipeline parallelism. + dp_degree (`int`, defaults to 1): + The degree of data parallelism (number of model replicas). + dp_shards (`int`, defaults to 1): + The number of data parallel shards (number of model partitions). + cp_degree (`int`, defaults to 1): + The degree of context parallelism. + """ + + parallel_backend: ParallelBackendEnum = ParallelBackendEnum.ACCELERATE + pp_degree: int = 1 + dp_degree: int = 1 + dp_shards: int = 1 + cp_degree: int = 1 + tp_degree: int = 1 + cp_rotate_method: str = "allgather" + cp_reduce_precision: bool = False + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("--parallel_backend", type=str, default="accelerate", choices=["accelerate", "ptd"]) + parser.add_argument("--pp_degree", type=int, default=1) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--dp_shards", type=int, default=1) + parser.add_argument("--cp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--cp_rotate_method", type=str, default="allgather", choices=["allgather", "alltoall"]) + parser.add_argument("--cp_reduce_precision", action="store_true") + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.parallel_backend = argparse_args.parallel_backend + mapped_args.pp_degree = argparse_args.pp_degree + mapped_args.dp_degree = argparse_args.dp_degree + mapped_args.dp_shards = argparse_args.dp_shards + mapped_args.cp_degree = argparse_args.cp_degree + mapped_args.tp_degree = argparse_args.tp_degree + mapped_args.cp_rotate_method = argparse_args.cp_rotate_method + mapped_args.cp_reduce_precision = argparse_args.cp_reduce_precision + + def validate_args(self, args: "BaseArgs"): + if args.parallel_backend != "ptd": + raise ValueError("Only 'ptd' parallel backend is supported for now.") + if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.tp_degree]): + raise ValueError("Parallel degrees must be 1 except for `cp_degree` for now.") + + +class ModelArgs(ArgsConfigMixin): + """ + Args: + model_name (`str`): + Name of model to train. + pretrained_model_name_or_path (`str`): + Path to pretrained model or model identifier from https://huggingface.co/models. The model should be + loadable based on specified `model_name`. + revision (`str`, defaults to `None`): + If provided, the model will be loaded from a specific branch of the model repository. + variant (`str`, defaults to `None`): + Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk + storage requirements. + cache_dir (`str`, defaults to `None`): + The directory where the downloaded models and datasets will be stored, or loaded from. + tokenizer_id (`str`, defaults to `None`): + Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_2_id (`str`, defaults to `None`): + Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_3_id (`str`, defaults to `None`): + Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + text_encoder_id (`str`, defaults to `None`): + Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_2_id (`str`, defaults to `None`): + Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_3_id (`str`, defaults to `None`): + Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + transformer_id (`str`, defaults to `None`): + Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`. + vae_id (`str`, defaults to `None`): + Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`. + text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder when generating text embeddings. + text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 2 when generating text embeddings. + text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 3 when generating text embeddings. + transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the transformer model. + vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the VAE model. + layerwise_upcasting_modules (`List[str]`, defaults to `[]`): + Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer']. + layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`): + Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2']. + layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`): + Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision + naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers + by default, and recommend adding more layers to the default list based on the model architecture. + enable_slicing (`bool`, defaults to `False`): + Whether to enable VAE slicing. + enable_tiling (`bool`, defaults to `False`): + Whether to enable VAE tiling. + """ + + model_name: str = None + pretrained_model_name_or_path: str = None + revision: Optional[str] = None + variant: Optional[str] = None + cache_dir: Optional[str] = None + tokenizer_id: Optional[str] = None + tokenizer_2_id: Optional[str] = None + tokenizer_3_id: Optional[str] = None + text_encoder_id: Optional[str] = None + text_encoder_2_id: Optional[str] = None + text_encoder_3_id: Optional[str] = None + transformer_id: Optional[str] = None + vae_id: Optional[str] = None + text_encoder_dtype: torch.dtype = torch.bfloat16 + text_encoder_2_dtype: torch.dtype = torch.bfloat16 + text_encoder_3_dtype: torch.dtype = torch.bfloat16 + transformer_dtype: torch.dtype = torch.bfloat16 + vae_dtype: torch.dtype = torch.bfloat16 + layerwise_upcasting_modules: List[str] = [] + layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn + # fmt: off + layerwise_upcasting_skip_modules_pattern: List[str] = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "time_embed", "^proj_in$", "^proj_out$", "norm"] + # fmt: on + enable_slicing: bool = False + enable_tiling: bool = False + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()] + ) + parser.add_argument("--pretrained_model_name_or_path", type=str, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--variant", type=str, default=None) + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--tokenizer_id", type=str, default=None) + parser.add_argument("--tokenizer_2_id", type=str, default=None) + parser.add_argument("--tokenizer_3_id", type=str, default=None) + parser.add_argument("--text_encoder_id", type=str, default=None) + parser.add_argument("--text_encoder_2_id", type=str, default=None) + parser.add_argument("--text_encoder_3_id", type=str, default=None) + parser.add_argument("--transformer_id", type=str, default=None) + parser.add_argument("--vae_id", type=str, default=None) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16") + parser.add_argument("--transformer_dtype", type=str, default="bf16") + parser.add_argument("--vae_dtype", type=str, default="bf16") + parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"]) + parser.add_argument( + "--layerwise_upcasting_storage_dtype", + type=str, + default="float8_e4m3fn", + choices=["float8_e4m3fn", "float8_e5m2"], + ) + parser.add_argument( + "--layerwise_upcasting_skip_modules_pattern", + type=str, + default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], + nargs="+", + ) + parser.add_argument("--enable_slicing", action="store_true") + parser.add_argument("--enable_tiling", action="store_true") + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.model_name = argparse_args.model_name + mapped_args.pretrained_model_name_or_path = argparse_args.pretrained_model_name_or_path + mapped_args.revision = argparse_args.revision + mapped_args.variant = argparse_args.variant + mapped_args.cache_dir = argparse_args.cache_dir + mapped_args.tokenizer_id = argparse_args.tokenizer_id + mapped_args.tokenizer_2_id = argparse_args.tokenizer_2_id + mapped_args.tokenizer_3_id = argparse_args.tokenizer_3_id + mapped_args.text_encoder_id = argparse_args.text_encoder_id + mapped_args.text_encoder_2_id = argparse_args.text_encoder_2_id + mapped_args.text_encoder_3_id = argparse_args.text_encoder_3_id + mapped_args.transformer_id = argparse_args.transformer_id + mapped_args.vae_id = argparse_args.vae_id + mapped_args.text_encoder_dtype = _DTYPE_MAP[argparse_args.text_encoder_dtype] + mapped_args.text_encoder_2_dtype = _DTYPE_MAP[argparse_args.text_encoder_2_dtype] + mapped_args.text_encoder_3_dtype = _DTYPE_MAP[argparse_args.text_encoder_3_dtype] + mapped_args.transformer_dtype = _DTYPE_MAP[argparse_args.transformer_dtype] + mapped_args.vae_dtype = _DTYPE_MAP[argparse_args.vae_dtype] + mapped_args.layerwise_upcasting_modules = argparse_args.layerwise_upcasting_modules + mapped_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[argparse_args.layerwise_upcasting_storage_dtype] + mapped_args.layerwise_upcasting_skip_modules_pattern = argparse_args.layerwise_upcasting_skip_modules_pattern + mapped_args.enable_slicing = argparse_args.enable_slicing + mapped_args.enable_tiling = argparse_args.enable_tiling + + def validate_args(self, args: "BaseArgs"): + pass + + +class InferenceArgs(ArgsConfigMixin): + """ + Args: + inference_type (`str`): + The type of inference to run. Choose between ['text_to_video']. + dataset_file (`str`, defaults to `None`): + Path to a CSV/JSON/PARQUET/ARROW file containing information for inference. The file must contain atleast the + "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path" + will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path" + will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary. + The dataset file may contain other attributes such as: + - "height" and "width" and "num_frames": Resolution + - "num_inference_steps": Number of inference steps + - "guidance_scale": Classifier-free Guidance Scale + - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be + invoked with the sample dictionary to validate the sample.) + """ + + inference_type: InferenceType = InferenceType.TEXT_TO_VIDEO + dataset_file: str = None + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--inference_type", + type=str, + default=InferenceType.TEXT_TO_VIDEO.value, + choices=[x.value for x in InferenceType.__members__.values()], + ) + parser.add_argument("--dataset_file", type=str, required=True) + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.inference_type = InferenceType(argparse_args.inference_type) + mapped_args.dataset_file = argparse_args.dataset_file + + def validate_args(self, args: "BaseArgs"): + pass + + +class AttentionProviderArgs(ArgsConfigMixin): + """ + Args: + attn_provider (`str`, defaults to "native"): + The attention provider to use for inference. Choose between ['flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash', '_native_math', 'sage', 'sage_varlen', '_sage_qk_int8_pv_fp8_cuda', '_sage_qk_int8_pv_fp8_cuda_sm90', '_sage_qk_int8_pv_fp16_cuda', '_sage_qk_int8_pv_fp16_triton', 'xformers']. + """ + + attn_provider: AttentionProviderInference = "native" + # attn_provider_specialized_modules: List[str] = [] + + def add_args(self, parser: argparse.ArgumentParser) -> None: + # fmt: off + parser.add_argument("--attn_provider", type=str, default="native", choices=["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "sage", "sage_varlen", "_sage_qk_int8_pv_fp8_cuda", "_sage_qk_int8_pv_fp8_cuda_sm90", "_sage_qk_int8_pv_fp16_cuda", "_sage_qk_int8_pv_fp16_triton", "xformers"]) + # fmt: on + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.attn_provider = argparse_args.attn_provider + + def validate_args(self, args: "BaseArgs"): + pass + + +class TorchConfigArgs(ArgsConfigMixin): + """ + Args: + compile_modules (`List[str]`, defaults to `[]`): + Modules that should be regionally compiled with `torch.compile`. + compile_scopes (`str`, defaults to `None`): + The scope of compilation for each `--compile_modules`. Choose between ['regional', 'full']. Must have the same length as + `--compile_modules`. If `None`, will default to `regional` for all modules. + allow_tf32 (`bool`, defaults to `False`): + Whether or not to allow the use of TF32 matmul on compatible hardware. + float32_matmul_precision (`str`, defaults to `highest`): + The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium']. + """ + + compile_modules: List[str] = [] + compile_scopes: List[str] = None + allow_tf32: bool = False + float32_matmul_precision: str = "highest" + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("--compile_modules", type=str, default=[], nargs="+") + parser.add_argument("--compile_scopes", type=str, default=None, nargs="+") + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument( + "--float32_matmul_precision", + type=str, + default="highest", + choices=["highest", "high", "medium"], + help="The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].", + ) + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + compile_scopes = argparse_args.compile_scopes + if len(argparse_args.compile_modules) > 0: + if compile_scopes is None: + compile_scopes = "regional" + if isinstance(compile_scopes, list) and len(compile_scopes) == 1: + compile_scopes = compile_scopes[0] + if isinstance(compile_scopes, str): + compile_scopes = [compile_scopes] * len(argparse_args.compile_modules) + else: + compile_scopes = [] + + mapped_args.compile_modules = argparse_args.compile_modules + mapped_args.compile_scopes = compile_scopes + mapped_args.allow_tf32 = argparse_args.allow_tf32 + mapped_args.float32_matmul_precision = argparse_args.float32_matmul_precision + + def validate_args(self, args: "BaseArgs"): + if len(args.compile_modules) > 0: + assert len(args.compile_modules) == len(args.compile_scopes) and all( + x in ["regional", "full"] for x in args.compile_scopes + ), ( + "Compile modules and compile scopes must be of the same length and compile scopes must be either 'regional' or 'full'" + ) + + +class MiscellaneousArgs(ArgsConfigMixin): + """ + Args: + seed (`int`, defaults to `None`): + Random seed for reproducibility under same initialization conditions. + tracker_name (`str`, defaults to `finetrainers`): + Name of the tracker/project to use for logging inference metrics. + output_dir (`str`, defaults to `None`): + The directory where the model checkpoints and logs will be stored. + logging_dir (`str`, defaults to `logs`): + The directory where the logs will be stored. + logging_steps (`int`, defaults to `1`): + Inference logs will be tracked every `logging_steps` steps. + nccl_timeout (`int`, defaults to `1800`): + Timeout for the NCCL communication. + report_to (`str`, defaults to `wandb`): + The name of the logger to use for logging inference metrics. Choose between ['wandb']. + verbose (`int`, defaults to `1`): + Whether or not to print verbose logs. + - 0: Diffusers/Transformers warning logging on local main process only + - 1: Diffusers/Transformers info logging on local main process only + - 2: Diffusers/Transformers debug logging on local main process only + - 3: Diffusers/Transformers debug logging on all processes + """ + + seed: Optional[int] = None + tracker_name: str = "finetrainers-inference" + output_dir: str = None + logging_dir: Optional[str] = "logs" + init_timeout: int = 300 # 5 minutes + nccl_timeout: int = 600 # 10 minutes + report_to: str = "wandb" + verbose: int = 1 + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--tracker_name", type=str, default="finetrainers") + parser.add_argument("--output_dir", type=str, default="finetrainers-inference") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--init_timeout", type=int, default=300) + parser.add_argument("--nccl_timeout", type=int, default=600) + parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"]) + parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3]) + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.seed = argparse_args.seed + mapped_args.tracker_name = argparse_args.tracker_name + mapped_args.output_dir = argparse_args.output_dir + mapped_args.logging_dir = argparse_args.logging_dir + mapped_args.init_timeout = argparse_args.init_timeout + mapped_args.nccl_timeout = argparse_args.nccl_timeout + mapped_args.report_to = argparse_args.report_to + mapped_args.verbose = argparse_args.verbose + + def validate_args(self, args: "BaseArgs"): + pass + + +class BaseArgs: + """The arguments for the finetrainers inference script.""" + + parallel_args = ParallelArgs() + model_args = ModelArgs() + inference_args = InferenceArgs() + attention_provider_args = AttentionProviderArgs() + torch_config_args = TorchConfigArgs() + miscellaneous_args = MiscellaneousArgs() + + _registered_config_mixins: List[ArgsConfigMixin] = [] + _arg_group_map: Dict[str, ArgsConfigMixin] = {} + + def __init__(self): + self._arg_group_map: Dict[str, ArgsConfigMixin] = { + "parallel_args": self.parallel_args, + "model_args": self.model_args, + "inference_args": self.inference_args, + "attention_provider_args": self.attention_provider_args, + "torch_config_args": self.torch_config_args, + "miscellaneous_args": self.miscellaneous_args, + } + + for arg_config_mixin in self._arg_group_map.values(): + self.register_args(arg_config_mixin) + + def to_dict(self) -> Dict[str, Any]: + arguments_to_dict = {} + for config_mixin in self._registered_config_mixins: + arguments_to_dict[config_mixin.__class__.__name__] = config_mixin.to_dict() + + return arguments_to_dict + + def register_args(self, config: ArgsConfigMixin) -> None: + if not hasattr(self, "_extended_add_arguments"): + self._extended_add_arguments = [] + self._extended_add_arguments.append((config.add_args, config.validate_args, config.map_args)) + self._registered_config_mixins.append(config) + + def parse_args(self): + parser = argparse.ArgumentParser() + + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + add_fn, _, _ = extended_add_arg_fns + add_fn(parser) + + args, remaining_args = parser.parse_known_args() + logger.debug(f"Remaining unparsed arguments: {remaining_args}") + + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, _, map_fn = extended_add_arg_fns + map_fn(args, self) + + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, validate_fn, _ = extended_add_arg_fns + validate_fn(self) + + return self + + def __getattribute__(self, name: str): + try: + return object.__getattribute__(self, name) + except AttributeError: + for arg_group in self._arg_group_map.values(): + if hasattr(arg_group, name): + return getattr(arg_group, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__: + object.__setattr__(self, name, value) + return + for arg_group in self._arg_group_map.values(): + if hasattr(arg_group, name): + setattr(arg_group, name, value) + return + object.__setattr__(self, name, value) + + +if __name__ == "__main__": + main() diff --git a/docs/finetrainers-src-codebase/examples/inference/wan/wan_text_to_video.sh b/docs/finetrainers-src-codebase/examples/inference/wan/wan_text_to_video.sh new file mode 100755 index 0000000000000000000000000000000000000000..52e724af8680b90a1829a8b26e36efa1b7522c4e --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/inference/wan/wan_text_to_video.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +# export WANDB_MODE="offline" +export WANDB_MODE="disabled" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Download the validation dataset +if [ ! -d "examples/inference/datasets/openvid-1k-split-validation" ]; then + echo "Downloading validation dataset..." + huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/inference/datasets/openvid-1k-split-validation +else + echo "Validation dataset already exists. Skipping download." +fi + +BACKEND="ptd" + +NUM_GPUS=4 +CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Check the JSON files for the expected JSON format +DATASET_FILE="examples/inference/wan/dummy_text_to_video.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +CP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 2 --tp_degree 1" +CP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 4 --tp_degree 1" +# FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +# FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +# HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $CP_4 +) + +# Model arguments +model_cmd=( + --model_name "wan" + --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + --enable_slicing + --enable_tiling +) + +# Inference arguments +inference_cmd=( + --inference_type text_to_video + --dataset_file "$DATASET_FILE" +) + +# Attention provider arguments +attn_provider_cmd=( + --attn_provider sage +) + +# Torch config arguments +torch_config_cmd=( + --allow_tf32 + --float32_matmul_precision high +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --seed 31337 + --tracker_name "finetrainers-inference" + --output_dir "/raid/aryan/wan-inference" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the inference script +export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + +torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + examples/inference/inference.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${inference_cmd[@]}" \ + "${attn_provider_cmd[@]}" \ + "${torch_config_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/.gitignore b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..55cb9b50839c38382a7ff49be6a27faf23d6a626 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/.gitignore @@ -0,0 +1 @@ +!validation_dataset/**/* \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/README.md b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/README.md new file mode 100644 index 0000000000000000000000000000000000000000..791459925fe0e954162662e64e09a271ee28fff3 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/README.md @@ -0,0 +1,15 @@ +# CogView4 Canny Control training + +To launch training, you can run the following from the root directory of the repository. + +```bash +chmod +x ./examples/training/sft/cogview4/canny/train.sh +./examples/training/sft/cogview4/canny/train.sh +``` + +The script should automatically download the validation dataset, but in case that doesn't happen, please make sure that a folder named `validation_dataset` exists in `examples/training/sft/cogview4/omni_edit/` and contains the validation dataset. You can also configure `validation.json` in the same directory however you like for your own validation dataset. + +```bash +cd examples/training/sft/cogview4/canny/ +huggingface-cli download --repo-type dataset finetrainers/Canny-image-validation-dataset --local-dir validation_dataset +``` diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/train.sh b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..f24935c6e889cddacc74d28ec3f2cafd029d87b1 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/train.sh @@ -0,0 +1,173 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="INFO" + +# Download the validation dataset +if [ ! -d "examples/training/control/cogview4/canny/validation_dataset" ]; then + echo "Downloading validation dataset..." + huggingface-cli download --repo-type dataset finetrainers/Canny-image-validation-dataset --local-dir examples/training/control/cogview4/canny/validation_dataset +else + echo "Validation dataset already exists. Skipping download." +fi + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 1 GPU on 4-GPU node for training +NUM_GPUS=1 +CUDA_VISIBLE_DEVICES="3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/control/cogview4/canny/training.json" +VALIDATION_DATASET_FILE="examples/training/control/cogview4/canny/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_1 +) + +# Model arguments +model_cmd=( + --model_name "cogview4" + --pretrained_model_name_or_path "THUDM/CogView4-6B" + --compile_modules transformer +) + +# Control arguments +control_cmd=( + --control_type canny + --rank 128 + --lora_alpha 128 + --target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" +) + +# Dataset arguments +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 16 +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type control-lora + --seed 42 + --batch_size 1 + --train_steps 10000 + --gradient_accumulation_steps 1 + --gradient_checkpointing + # --checkpointing_steps 1000 + # --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 3e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 2000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-cogview4-control" + --output_dir "/raid/aryan/cogview4-control-lora-canny" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${control_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${control_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/training.json b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/training.json new file mode 100644 index 0000000000000000000000000000000000000000..cc3ebcae7b49e0e3fbdc169c556e5ac573273794 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/training.json @@ -0,0 +1,13 @@ +{ + "datasets": [ + { + "data_root": "recoilme/aesthetic_photos_xs", + "dataset_type": "image", + "image_resolution_buckets": [ + [1024, 1024] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/validation.json b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..190abe4d63561564dc9fa83e4efb758e2b11da1d --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/canny/validation.json @@ -0,0 +1,32 @@ +{ + "data": [ + { + "caption": "an orange flamingo stands in the shallow water, solo, standing, full_body, outdoors, blurry, no_humans, bird, leaf, plant, animal_focus, beak", + "image_path": "examples/training/control/cogview4/canny/validation_dataset/0.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "a woman holding a bouquet of flowers on the street, 1girl, solo, long_hair, shirt, black_hair, long_sleeves, holding, bare_shoulders, jewelry, upper_body, flower, earrings, outdoors, parted_lips, striped, off_shoulder, black_eyes, tree, looking_to_the_side, grass, pink_flower, striped_shirt, off-shoulder_shirt, holding_flower", + "image_path": "examples/training/control/cogview4/canny/validation_dataset/1.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "there is a boat on the river in the wilderness, outdoors, sky, day, water, blurry, tree, no_humans, leaf, plant, nature, scenery, reflection, mountain, lake", + "image_path": "examples/training/control/cogview4/canny/validation_dataset/2.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "a man in white lab coat wearing a pair of virtual glasses, shirt, black_hair, 1boy, holding, male_focus, multiple_boys, indoors, blurry, cup, depth_of_field, blurry_background, blue_shirt, mug, labcoat, coffee, coffee_mug, doctor", + "image_path": "examples/training/control/cogview4/canny/validation_dataset/3.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/.gitignore b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..55cb9b50839c38382a7ff49be6a27faf23d6a626 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/.gitignore @@ -0,0 +1 @@ +!validation_dataset/**/* \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/README.md b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1115ba540bece5dd96003fe5a3f8dd3c95f264f4 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/README.md @@ -0,0 +1,15 @@ +# CogView4 Edit Control training + +To launch training, you can run the following from the root directory of the repository. + +```bash +chmod +x ./examples/training/sft/cogview4/omni_edit/train.sh +./examples/training/sft/cogview4/omni_edit/train.sh +``` + +The script should automatically download the validation dataset, but in case that doesn't happen, please make sure that a folder named `validation_dataset` exists in `examples/training/sft/cogview4/omni_edit/` and contains the validation dataset. You can also configure `validation.json` in the same directory however you like for your own validation dataset. + +```bash +cd examples/training/sft/cogview4/omni_edit/ +huggingface-cli download --repo-type dataset finetrainers/OmniEdit-validation-dataset --local-dir validation_dataset +``` diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/train.sh b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..966065f129a33f4a45ae62615b01b1824913f40f --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/train.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="INFO" + +# Download the validation dataset +if [ ! -d "examples/training/control/cogview4/omni_edit/validation_dataset" ]; then + echo "Downloading validation dataset..." + huggingface-cli download --repo-type dataset finetrainers/OmniEdit-validation-dataset --local-dir examples/training/control/cogview4/omni_edit/validation_dataset +else + echo "Validation dataset already exists. Skipping download." +fi + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 8 GPUs on a single node for training +NUM_GPUS=8 +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/control/cogview4/omni_edit/training.json" +VALIDATION_DATASET_FILE="examples/training/control/cogview4/omni_edit/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_8 +) + +# Model arguments +model_cmd=( + --model_name "cogview4" + --pretrained_model_name_or_path "THUDM/CogView4-6B" +) + +# Control arguments +control_cmd=( + --control_type custom + --rank 128 + --lora_alpha 128 + --target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)" +) + +# Dataset arguments +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 16 +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type control-lora + --seed 42 + --batch_size 1 + --train_steps 10000 + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 5 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 3e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 2000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-cogview4-control" + --output_dir "/fsx/aryan/cogview4-control-lora" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${control_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:19242" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${control_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/training.json b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/training.json new file mode 100644 index 0000000000000000000000000000000000000000..b758ec8fff718b28bfde91ba1fdc06430103cf2a --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/training.json @@ -0,0 +1,50 @@ +{ + "datasets": [ + { + "data_root": "sayakpaul/OmniEdit-mini", + "dataset_type": "image", + "image_resolution_buckets": [ + [384, 384], + [480, 480], + [512, 512], + [640, 640], + [704, 704], + [768, 768], + [832, 832], + [896, 896], + [960, 960], + [1024, 1024], + [1152, 1152], + [1280, 1280], + [1344, 1344], + [480, 720], + [480, 768], + [512, 768], + [640, 768], + [768, 960], + [768, 1152], + [768, 1360], + [864, 1152], + [864, 1360], + [720, 480], + [768, 480], + [768, 512], + [768, 640], + [960, 768], + [1152, 768], + [1360, 768], + [1152, 864], + [1360, 864] + ], + "caption_options": { + "column_names": "edited_prompt_list" + }, + "rename_columns": { + "src_img": "control_image", + "edited_img": "image" + }, + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/validation.json b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..f62b5cba6de85e08f643d82c7d25f5aabfe153bf --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/control/cogview4/omni_edit/validation.json @@ -0,0 +1,116 @@ +{ + "data": [ + { + "caption": "Make the image look like it's from an ancient Egyptian mural.", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/0.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "Make it look like a cubist painting.", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/1.png", + "num_inference_steps": 30, + "height": 768, + "width": 1360 + }, + { + "caption": "turn the color of mushroom to gray", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/2.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "transform the setting to a misty atmosphere", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/3.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "Replace the vampire with cape with woman", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/4.png", + "num_inference_steps": 30, + "height": 768, + "width": 1344 + }, + { + "caption": "Convert this into a cartoon.", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/5.png", + "num_inference_steps": 30, + "height": 768, + "width": 1152 + }, + { + "caption": "Make it Pop Art.", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/6.png", + "num_inference_steps": 30, + "height": 864, + "width": 1152 + }, + { + "caption": "transform the setting to a stormy space", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/7.png", + "num_inference_steps": 30, + "height": 1024, + "width": 1024 + }, + { + "caption": "Change it to look like it's in the style of an impasto painting.", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/8.png", + "num_inference_steps": 30, + "height": 768, + "width": 1152 + }, + { + "caption": "Replace the MOVIES sign with RITZ sign", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/9.png", + "num_inference_steps": 30, + "height": 1152, + "width": 864 + }, + { + "caption": "transform the setting to a sunny day", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/10.png", + "num_inference_steps": 30, + "height": 864, + "width": 1152 + }, + { + "caption": "Replace the cobblestone street with road", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/11.png", + "num_inference_steps": 30, + "height": 768, + "width": 1344 + }, + { + "caption": "Replace the rolling hills with ocean", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/12.png", + "num_inference_steps": 30, + "height": 768, + "width": 1344 + }, + { + "caption": "change the setting to spring with blooming trees", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/13.png", + "num_inference_steps": 30, + "height": 768, + "width": 1152 + }, + { + "caption": "Make this photo look like a comic book", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/14.png", + "num_inference_steps": 30, + "height": 768, + "width": 1360 + }, + { + "caption": "Replace the ninja gloves with white gloves", + "control_image_path": "examples/training/control/cogview4/omni_edit/validation_dataset/15.png", + "num_inference_steps": 30, + "height": 768, + "width": 1344 + } + ] +} diff --git a/docs/finetrainers/examples_training_wan_image_conditioning__train.sh b/docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/train.sh old mode 100644 new mode 100755 similarity index 100% rename from docs/finetrainers/examples_training_wan_image_conditioning__train.sh rename to docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/train.sh diff --git a/docs/finetrainers/examples_training_wan_image_conditioning__ttraining.json b/docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/training.json similarity index 100% rename from docs/finetrainers/examples_training_wan_image_conditioning__ttraining.json rename to docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/training.json diff --git a/docs/finetrainers/examples_training_wan_image_conditioning__tvalidation.json b/docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/validation.json similarity index 100% rename from docs/finetrainers/examples_training_wan_image_conditioning__tvalidation.json rename to docs/finetrainers-src-codebase/examples/training/control/wan/image_condition/validation.json diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..1614557ab1fa19643bb7595a9ca5e8142960749b --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/train.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="0,1" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/cogvideox/crush_smol_lora/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/cogvideox/crush_smol_lora/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_2 +) + +# Model arguments +model_cmd=( + --model_name "cogvideox" + --pretrained_model_name_or_path "THUDM/CogVideoX1.5-5B" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~50 videos. Since we're using 2 GPUs, we precompute +# embeddings of 25 dataset items per GPU. Also, we're using a very small dataset for finetuning, so +# we are okay with precomputing embeddings once and re-using them without having to worry about disk +# space. Currently, however, every new training run performs precomputation even if it's not required +# (which is something we've to improve [TODO(aryan)]) +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 10 + --enable_precomputation + --precomputation_items 25 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 3000 + --rank 32 + --lora_alpha 32 + --target_modules "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 2 + # --resume_from_checkpoint 1000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 5e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-cogvideox" + --output_dir "/raid/aryan/cogvideox" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/training.json b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/training.json new file mode 100644 index 0000000000000000000000000000000000000000..c57e3a4f6139e1c5a2e91dc3e78875174c6eaeae --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/training.json @@ -0,0 +1,14 @@ +{ + "datasets": [ + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [81, 480, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..e76c9adba60145944c963b073ffca3d177c89a82 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogvideox/crush_smol_lora/validation.json @@ -0,0 +1,44 @@ +{ + "data": [ + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..e4f7702af383d1f5d75f0257c45dcc7fdb8b4dcb --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/train.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/cogview4/raider_white_tarot/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/cogview4/raider_white_tarot/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_2 +) + +# Model arguments +model_cmd=( + --model_name "cogview4" + --pretrained_model_name_or_path "THUDM/CogView4-6B" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~80 images. In `training.json`, we duplicate the same +# dataset 3 times for multi-resolution training. This gives us a total of about 240 images. Since +# we're using 2 GPUs for training, we can split the data into 120 images per GPU and precompute +# all embeddings at once, instead of doing it on-the-fly which would be slower (the ideal usecase +# of not using `--precomputation_once` is when you're training on large datasets) +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 32 + --enable_precomputation + --precomputation_items 120 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 5000 + --rank 32 + --lora_alpha 32 + --target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 3e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-cogview4" + --output_dir "/raid/aryan/cogview4" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/training.json b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/training.json new file mode 100644 index 0000000000000000000000000000000000000000..4401d4c8eb10411636d25b1c49a473544fc40b66 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/training.json @@ -0,0 +1,34 @@ +{ + "datasets": [ + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [1280, 720] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [512, 512] + ], + "reshape_mode": "center_crop", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [768, 768] + ], + "reshape_mode": "center_crop", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..719191e9a8ba96e5435d37f8baf41ec0bb5ad087 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/raider_white_tarot/validation.json @@ -0,0 +1,68 @@ +{ + "data": [ + { + "caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 512 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 512 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/README.md b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8c42371e094b80773069633747ccba1fd61baca4 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/README.md @@ -0,0 +1,5 @@ +# CogView4-6B The Simpsons dataset + +This example is only an experiment to verify if webdataset loading and streaming from the HF Hub works as expected. Do not expect meaningful results. + +The dataset used for testing is available at [`bigdata-pw/TheSimpsons`](https://huggingface.co/datasets/bigdata-pw/TheSimpsons). diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..c2d9cd2675cede4dc56040d705b211d91e67b836 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/train.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="INFO" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using all 8 GPUs on a 8-GPU node for training +NUM_GPUS=8 +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/cogview4/the_simpsons/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/cogview4/the_simpsons/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" +HSDP_4_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $HSDP_4_2 +) + +# Model arguments +model_cmd=( + --model_name "cogview4" + --pretrained_model_name_or_path "THUDM/CogView4-6B" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~80 images. In `training.json`, we duplicate the same +# dataset 3 times for multi-resolution training. This gives us a total of about 240 images. Since +# we're using 2 GPUs for training, we can split the data into 120 images per GPU and precompute +# all embeddings at once, instead of doing it on-the-fly which would be slower (the ideal usecase +# of not using `--precomputation_once` is when you're training on large datasets) +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 32 +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 5000 + --rank 128 + --lora_alpha 128 + --target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 1e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 2000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-cogview4" + --output_dir "/fsx/aryan/cogview4" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/training.json b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/training.json new file mode 100644 index 0000000000000000000000000000000000000000..e537ed687b2e8b956e8a6558513e985a8043eea5 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/training.json @@ -0,0 +1,24 @@ +{ + "datasets": [ + { + "data_root": "bigdata-pw/TheSimpsons", + "dataset_type": "image", + "id_token": "SMPSN", + "image_resolution_buckets": [ + [960, 528], + [720, 528], + [720, 480] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true, + "caption_options": { + "column_names": ["caption.txt", "detailed_caption.txt", "more_detailed_caption.txt"], + "weights": { + "caption.txt": 0.2, + "detailed_caption.txt": 0.6, + "more_detailed_caption.txt": 0.2 + } + } + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..1517f7bef13c42f7b4a88091834cafe372ad099c --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/cogview4/the_simpsons/validation.json @@ -0,0 +1,132 @@ +{ + "data": [ + { + "caption": "SMPSN Homer Simpson and Santa Claus walking down a street, with Homer wearing a blue jacket and a red and white Santa Claus costume. The background is a bright blue sky with white snowflakes falling from the sky.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN Homer Simpson and Santa Claus walking down a street, with Homer wearing a blue jacket and a red and white Santa Claus costume. The background is a bright blue sky with white snowflakes falling from the sky.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 720 + }, + { + "caption": "SMPSN Homer Simpson and Santa Claus walking down a street, with Homer wearing a blue jacket and a red and white Santa Claus costume. The background is a bright blue sky with white snowflakes falling from the sky.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 720 + }, + { + "caption": "SMPSN A man and a woman sitting on a bed in a bedroom.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN Marge Simpson from The Simpsons sitting on a couch in front of a wall. She is wearing a yellow dress and has a cheerful expression on her face.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN Homer Simpson becomes an astronaut and floats around in space, essentially becoming omnipresent, rule of worlds", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN Homer Simpson riding a horse, realistic", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN A brown bear standing in the woods with its mouth open, surrounded by trees and plants in a green and blue background", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 528, + "width": 960 + }, + { + "caption": "SMPSN A still from the animated TV show, The Simpsons. It shows three characters, Homer Simpson, Marge Simpson, and Lisa Simpson, standing in a grassy field with trees and a blue sky in the background. Homer is on the left side of the image, wearing a pink hat and a pink dress. Marge is in the center, with her arms around Marge's waist. She has a surprised expression on her face and is looking at Homer with a concerned expression. On the right side, there is a man in a blue suit and tie, who appears to be Bart Simpson. In the background, there are mushrooms and a red brick wall.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1024, + "width": 1024 + }, + { + "caption": "SMPSN A cat from the world of The Simpsons escapes the confines of conciousness, absorbing all knowledge in existence", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1024, + "width": 1024 + }, + { + "caption": "SMPSN two cartoon characters, Bart Simpson and Lisa Simpson, sitting on the grass and reading books. Behind them is a tree trunk and a fence, and in the background there are trees and a blue sky.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1024, + "width": 1024 + }, + { + "caption": "SMPSN A family picture of Homer Simpson's family. 1920s, baroque art style", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1024, + "width": 1024 + }, + { + "caption": "SMPSN The Simpson family as a mighty group of crimefighters", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + }, + { + "caption": "SMPSN The Simpson family, cyberpunk style, artistic depiction", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + }, + { + "caption": "SMPSN The Simpson family, re-imagined as a group of peculiar scientists", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 720, + "width": 1280 + }, + { + "caption": "SMPSN A supercar made just for the Simpsons drifing, cartoon artistic style, 3D animation", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 720, + "width": 1280 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..26acde2e52f1bab29c7f1e62b1ce4c7c03d94fb2 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/train.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 1 GPUs on a 4-GPU node for training +NUM_GPUS=1 +CUDA_VISIBLE_DEVICES="3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/flux_dev/raider_white_tarot/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/flux_dev/raider_white_tarot/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_1 +) + +# Model arguments +model_cmd=( + --model_name "flux" + --pretrained_model_name_or_path "black-forest-labs/FLUX.1-dev" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~80 images. In `training.json`, we duplicate the same +# dataset 3 times for multi-resolution training. This gives us a total of about 240 images. +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 32 + --enable_precomputation + --precomputation_items 240 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 1000 + --rank 32 + --lora_alpha 32 + --target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0|add_q_proj|add_k_proj|add_v_proj|to_add_out)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 251 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 1e-4 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 200 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 251 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-flux" + --output_dir "/raid/aryan/flux" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/training.json b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/training.json new file mode 100644 index 0000000000000000000000000000000000000000..4401d4c8eb10411636d25b1c49a473544fc40b66 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/training.json @@ -0,0 +1,34 @@ +{ + "datasets": [ + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [1280, 720] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [512, 512] + ], + "reshape_mode": "center_crop", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "multimodalart/1920-raider-waite-tarot-public-domain", + "dataset_type": "image", + "id_token": "TRTCRD", + "image_resolution_buckets": [ + [768, 768] + ], + "reshape_mode": "center_crop", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..719191e9a8ba96e5435d37f8baf41ec0bb5ad087 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/flux_dev/raider_white_tarot/validation.json @@ -0,0 +1,68 @@ +{ + "data": [ + { + "caption": "TRTCRD a trtcrd of a knight mounting a running horse wearing an armor and holding a staff, \"knight of wands\"", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a woman sitting on a throne, wearing a crown and holding a trophee, \"queen of cups\"", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 1280, + "width": 720 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 512 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 512 + }, + { + "caption": "TRTCRD a trtcrd of a knight holding the cup while mounts on a stationary horse", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + }, + { + "caption": "TRTCRD a trtcrd of a person in a red robe holding a scale and giving coins to two kneeling figures, surrounded by six pentacles", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 768, + "width": 768 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..79d469dac5f40274294f7aa2c1e7840b82d35545 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh @@ -0,0 +1,159 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=8 +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" +HSDP_4_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $HSDP_4_2 +) + +# Model arguments +model_cmd=( + --model_name "hunyuan_video" + --pretrained_model_name_or_path "hunyuanvideo-community/HunyuanVideo" +) + +# Dataset arguments +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 10 + --enable_precomputation + --precomputation_items 10 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 3000 + --rank 32 + --lora_alpha 32 + --target_modules "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|add_q_proj|add_k_proj|add_v_proj|to_add_out)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 500 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 3e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-hunyuanvideo" + --output_dir "/fsx/aryan/lora-training/hunyuanvideo" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json new file mode 100644 index 0000000000000000000000000000000000000000..3d211b0612e517ed49e92d945c06195ebebc0916 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json @@ -0,0 +1,24 @@ +{ + "datasets": [ + { + "data_root": "modal-labs/dissolve", + "dataset_type": "video", + "id_token": "MODAL_DISSOLVE", + "video_resolution_buckets": [ + [49, 480, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "modal-labs/dissolve", + "dataset_type": "video", + "id_token": "MODAL_DISSOLVE", + "video_resolution_buckets": [ + [81, 480, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..4e50723dd63a2d4f90e24aa9e407d258c4bf9b02 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json @@ -0,0 +1,76 @@ +{ + "data": [ + { + "caption": "MODAL_DISSOLVE A meticulously detailed, antique-style vase, featuring mottled beige and brown hues and two small handles, sits centrally on a dark brown circular pedestal. The vase, seemingly made of clay or porcelain, begins to dissolve from the bottom up. The disintegration process is rapid but not explosive, with a cloud of fine, light tan dust forming and rising in a swirling, almost ethereal column that expands outwards before slowly descending. The dust particles are individually visible as they float, and the overall effect is one of delicate disintegration rather than shattering. Finally, only the empty pedestal and the intricately patterned marble floor remain.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE Close-up view of a sloth resting on a thick tree branch within a dense, sun-dappled forest. The sloth's body, initially clearly defined, begins to subtly disintegrate. The process starts with a light dusting of particles from its lower back and rump. This quickly intensifies, with a visible cloud of fine, sparkling dust billowing outwards as the sloth's form gradually vanishes. The dissolution proceeds in a wave-like manner, moving from rear to front. The head and arms are the last parts to disappear, leaving only scattered motes of dust that slowly disperse amongst the leaves, blending seamlessly with the forest environment. The overall effect is dreamlike and ethereal.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete digital dissolution of an orange Porsche 911 GT3 RS within a garage environment. The car's dissolution proceeds in three discernible stages: (1) Initial shimmering along the car's edges and body panels, creating a subtle, high-frequency displacement effect. (2) Rapid disintegration of the vehicle into a dense cloud of primarily orange and black particles, varying in size and opacity; particle motion exhibits both outward and swirling movements. (3) Complete disappearance of the car, leaving behind only a remaining, smaller, seemingly fiery-textured rubber duck model. The overall effect resembles a controlled explosion or rapid combustion, creating a dynamic, visually complex transformation. The garage's lighting and shadows remain consistent throughout the dissolution, providing clear visual contrast.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete disintegration of a white origami crane. The disintegration process is initiated at the head of the crane and proceeds in a generally downward direction. The disintegration manifests as the rapid breakdown of paper fibers into a cloud of fine particulate matter. The particle size appears consistent, with a texture similar to very fine powder. The rate of disintegration increases over time, resulting in a visually dynamic and texturally complex effect. The background consists of a dark-stained wooden surface, providing a high-contrast setting that highlights the white particles' dispersal and movement. The final state shows only residual particulate matter scattered sparsely on the surface.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE A meticulously detailed, antique-style vase, featuring mottled beige and brown hues and two small handles, sits centrally on a dark brown circular pedestal. The vase, seemingly made of clay or porcelain, begins to dissolve from the bottom up. The disintegration process is rapid but not explosive, with a cloud of fine, light tan dust forming and rising in a swirling, almost ethereal column that expands outwards before slowly descending. The dust particles are individually visible as they float, and the overall effect is one of delicate disintegration rather than shattering. Finally, only the empty pedestal and the intricately patterned marble floor remain.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE Close-up view of a sloth resting on a thick tree branch within a dense, sun-dappled forest. The sloth's body, initially clearly defined, begins to subtly disintegrate. The process starts with a light dusting of particles from its lower back and rump. This quickly intensifies, with a visible cloud of fine, sparkling dust billowing outwards as the sloth's form gradually vanishes. The dissolution proceeds in a wave-like manner, moving from rear to front. The head and arms are the last parts to disappear, leaving only scattered motes of dust that slowly disperse amongst the leaves, blending seamlessly with the forest environment. The overall effect is dreamlike and ethereal.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete digital dissolution of an orange Porsche 911 GT3 RS within a garage environment. The car's dissolution proceeds in three discernible stages: (1) Initial shimmering along the car's edges and body panels, creating a subtle, high-frequency displacement effect. (2) Rapid disintegration of the vehicle into a dense cloud of primarily orange and black particles, varying in size and opacity; particle motion exhibits both outward and swirling movements. (3) Complete disappearance of the car, leaving behind only a remaining, smaller, seemingly fiery-textured rubber duck model. The overall effect resembles a controlled explosion or rapid combustion, creating a dynamic, visually complex transformation. The garage's lighting and shadows remain consistent throughout the dissolution, providing clear visual contrast.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete disintegration of a white origami crane. The disintegration process is initiated at the head of the crane and proceeds in a generally downward direction. The disintegration manifests as the rapid breakdown of paper fibers into a cloud of fine particulate matter. The particle size appears consistent, with a texture similar to very fine powder. The rate of disintegration increases over time, resulting in a visually dynamic and texturally complex effect. The background consists of a dark-stained wooden surface, providing a high-contrast setting that highlights the white particles' dispersal and movement. The final state shows only residual particulate matter scattered sparsely on the surface.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..1eb0d38b60d261f53ba6b5a4da4c4f2bb3e2b3ad --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/ltx_video/crush_smol_lora/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/ltx_video/crush_smol_lora/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_2 +) + +# Model arguments +model_cmd=( + --model_name "ltx_video" + --pretrained_model_name_or_path "a-r-r-o-w/LTX-Video-diffusers" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~50 videos. Since we're using 2 GPUs, we precompute +# embeddings of 25 dataset items per GPU. Also, we're using a very small dataset for finetuning, so +# we are okay with precomputing embeddings once and re-using them without having to worry about disk +# space. Currently, however, every new training run performs precomputation even if it's not required +# (which is something we've to improve [TODO(aryan)]) +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 10 + --enable_precomputation + --precomputation_items 25 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 5000 + --rank 32 + --lora_alpha 32 + --target_modules "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 5e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-ltxvideo" + --output_dir "/raid/aryan/ltx-video" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train_multires.sh b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train_multires.sh new file mode 100755 index 0000000000000000000000000000000000000000..c95dbe44d52dce24f9f8844aad32415aff5d0ce5 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/train_multires.sh @@ -0,0 +1,171 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/ltx_video/crush_smol_lora/training_multires.json" +VALIDATION_DATASET_FILE="examples/training/sft/ltx_video/crush_smol_lora/validation_multires.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_2 +) + +# Model arguments +model_cmd=( + --model_name "ltx_video" + --pretrained_model_name_or_path "a-r-r-o-w/LTX-Video-diffusers" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~50 videos. Since we're using 2 GPUs, we precompute +# embeddings of 25 dataset items per GPU. Also, we're using a very small dataset for finetuning, so +# we are okay with precomputing embeddings once and re-using them without having to worry about disk +# space. Currently, however, every new training run performs precomputation even if it's not required +# (which is something we've to improve [TODO(aryan)]) +# Note: This is a copy of `train.sh` file from the same directory. For multi-resolution training, we +# are utilizing the same dataset with different total frame counts. We do this with 4 copies of the +# dataset, so also multiply the precomputation items with 4, in order to make sure all required embeddings +# are precomputed at once. +# We also bump the shuffle buffer size to have ample diversity in different orders through training. +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 50 + --enable_precomputation + --precomputation_items 100 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +# Note: Since we're training on multiple resolutions, it might take the model more steps +# to converge. We bump up the training steps to 7500 just to be save experimentation time +# can choose the best performing checkpoint manually. +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 7500 + --rank 32 + --lora_alpha 32 + --target_modules "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 1000 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 5e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-ltxvideo" + --output_dir "/raid/aryan/ltx-video" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training.json b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training.json new file mode 100644 index 0000000000000000000000000000000000000000..0c3f8aa31fa14f612c819695dd97130bbc368f8f --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training.json @@ -0,0 +1,14 @@ +{ + "datasets": [ + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [49, 512, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training_multires.json b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training_multires.json new file mode 100644 index 0000000000000000000000000000000000000000..fb6b6ec85de2852a296144aa783fe4fadbd2ba24 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/training_multires.json @@ -0,0 +1,44 @@ +{ + "datasets": [ + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [49, 512, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [81, 512, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [121, 512, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "finetrainers/crush-smol", + "dataset_type": "video", + "id_token": "PIKA_CRUSH", + "video_resolution_buckets": [ + [161, 512, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..e76c9adba60145944c963b073ffca3d177c89a82 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation.json @@ -0,0 +1,44 @@ +{ + "data": [ + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation_multires.json b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation_multires.json new file mode 100644 index 0000000000000000000000000000000000000000..85f469fb61621e369ef4e8b1d1557959400c6c68 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/ltx_video/crush_smol_lora/validation_multires.json @@ -0,0 +1,84 @@ +{ + "data": [ + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 49, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 81, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 121, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 161, + "frame_rate": 25 + }, + { + "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 512, + "width": 768, + "num_frames": 161, + "frame_rate": 25 + } + ] +} diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..58abdab5092e9266742641adb592bdecd89b3619 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/train.sh @@ -0,0 +1,163 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=2 +CUDA_VISIBLE_DEVICES="2,3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/wan/3dgs_dissolve/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/wan/3dgs_dissolve/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_2 +) + +# Model arguments +model_cmd=( + --model_name "wan" + --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +) + +# Dataset arguments +# Here, we know that the dataset size if about ~100 videos. Since we're using 2 GPUs, we precompute +# embeddings of 50 dataset items per GPU. Also, we're using a very small dataset for finetuning, so +# we are okay with precomputing embeddings once and re-using them without having to worry about disk +# space. Currently, however, every new training run performs precomputation even if it's not required +# (which is something we've to improve [TODO(aryan)]) +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 10 + --enable_precomputation + --precomputation_items 100 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 5000 + --rank 32 + --lora_alpha 32 + --target_modules "blocks.*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 500 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 5e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-wan" + --output_dir "/raid/aryan/wan" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/training.json b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/training.json new file mode 100644 index 0000000000000000000000000000000000000000..42239774a4ab333eea284789b348a7390021d692 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/training.json @@ -0,0 +1,24 @@ +{ + "datasets": [ + { + "data_root": "finetrainers/3dgs-dissolve", + "dataset_type": "video", + "id_token": "3DGS_DISSOLVE", + "video_resolution_buckets": [ + [49, 480, 832] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "finetrainers/3dgs-dissolve", + "dataset_type": "video", + "id_token": "3DGS_DISSOLVE", + "video_resolution_buckets": [ + [81, 480, 832] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..7f30a38abda6f7815f1a926877636b9440a04aff --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan/3dgs_dissolve/validation.json @@ -0,0 +1,58 @@ +{ + "data": [ + { + "caption": "A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 49 + }, + { + "caption": "3DGS_DISSOLVE A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 49 + }, + { + "caption": "3DGS_DISSOLVE A spacecraft, rendered in a 3D appearance, ascends into the night sky, leaving behind a trail of fiery exhaust. As it climbs higher, the exhaust gradually transforms into a burst of red sparks, creating a dramatic and dynamic visual effect against the dark backdrop.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "3DGS_DISSOLVE A vintage-style treasure chest, rendered in a 3D appearance, stands prominently against a dark background. As the scene progresses, the chest begins to emit a glowing light, which intensifies until it evaporates into a burst of red sparks, creating a dramatic and mysterious atmosphere.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 49 + }, + { + "caption": "3DGS_DISSOLVE A glowing, fiery cube in a 3D appearance begins to spin and rotate, its edges shimmering with intense light. As it continues to spin, the cube gradually evaporates into a burst of red sparks that scatter across the screen, creating a dynamic and mesmerizing visual effect against the dark background.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 49 + }, + { + "caption": "3DGS_DISSOLVE A dynamic explosion unfolds in a 3D appearance, beginning as a concentrated burst of intense orange flames. As the fire intensifies, it rapidly expands outward, transitioning into a vibrant display of red sparks that scatter across the frame. The sparks continue to evolve, evaporating into a burst of red sparks against the dark backdrop, creating a mesmerizing visual spectacle.", + "image_path": null, + "video_path": null, + "num_inference_steps": 50, + "height": 480, + "width": 832, + "num_frames": 49 + } + ] +} diff --git a/docs/finetrainers/examples_training_wan__train.sh b/docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/train.sh old mode 100644 new mode 100755 similarity index 99% rename from docs/finetrainers/examples_training_wan__train.sh rename to docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/train.sh index 83490b75ba375575697deb57f61282817b82941f..fe51eda1d8a59492df5196b2e8f7e3052c1c18ae --- a/docs/finetrainers/examples_training_wan__train.sh +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/train.sh @@ -49,6 +49,7 @@ model_cmd=( dataset_cmd=( --dataset_config $TRAINING_DATASET_CONFIG --dataset_shuffle_buffer_size 10 + --enable_precomputation --precomputation_items 25 --precomputation_once ) diff --git a/docs/finetrainers/examples_training_wan__training.json b/docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/training.json similarity index 100% rename from docs/finetrainers/examples_training_wan__training.json rename to docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/training.json diff --git a/docs/finetrainers/examples_training_wan__validation.json b/docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/validation.json similarity index 100% rename from docs/finetrainers/examples_training_wan__validation.json rename to docs/finetrainers-src-codebase/examples/training/sft/wan/crush_smol_lora/validation.json diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/train.sh b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..7213180969774d37a4ad112c9df16bcaa7e55c0c --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/train.sh @@ -0,0 +1,176 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Download the validation dataset +if [ ! -d "examples/training/sft/wan_i2v/3dgs_dissolve/validation_dataset" ]; then + echo "Downloading validation dataset..." + huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/training/sft/wan_i2v/3dgs_dissolve/validation_dataset +else + echo "Validation dataset already exists. Skipping download." +fi + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 1 GPU on a 4-GPU node for training +NUM_GPUS=1 +CUDA_VISIBLE_DEVICES="3" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/wan_i2v/3dgs_dissolve/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/wan_i2v/3dgs_dissolve/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $DDP_1 +) + +# Model arguments +model_cmd=( + --model_name "wan" + --pretrained_model_name_or_path "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + --compile_modules text_encoder image_encoder transformer vae + --compile_scopes regional +) + +# Dataset arguments +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 32 + --enable_precomputation + --precomputation_items 50 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. +# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 1000 + --rank 16 + --lora_alpha 16 + --target_modules "blocks.*(to_q|to_k|to_v|to_out.0)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 501 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 1e-4 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 100 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 101 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-wan-i2v" + --output_dir "/raid/aryan/wan-i2v" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Torch config arguments +torch_config_cmd=( + --allow_tf32 + --float32_matmul_precision high +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" \ + "${torch_config_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" \ + "${torch_config_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/training.json b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/training.json new file mode 100644 index 0000000000000000000000000000000000000000..2f9fd437cec63aabb31b88ed41dcaf0cbff615a7 --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/training.json @@ -0,0 +1,14 @@ +{ + "datasets": [ + { + "data_root": "finetrainers/3dgs-dissolve", + "dataset_type": "video", + "id_token": "3DGS_DISSOLVE", + "video_resolution_buckets": [ + [49, 480, 832] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/validation.json b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/validation.json new file mode 100644 index 0000000000000000000000000000000000000000..925e5c2a09c16ba41d4d653772f00e92108624df --- /dev/null +++ b/docs/finetrainers-src-codebase/examples/training/sft/wan_i2v/3dgs_dissolve/validation.json @@ -0,0 +1,36 @@ +{ + "data": [ + { + "caption": "3DGS_DISSOLVE A vibrant green Mustang GT parked in a parking lot. The car is positioned at an angle, showcasing its sleek design and black rims. The car's hood is black, contrasting with the green body. The car gradually transforms and bursts into red sparks, creating a dramatic and dynamic visual effect against a dark backdrop.", + "video_path": "examples/training/sft/wan_i2v/3dgs_dissolve/validation_dataset/0.mp4", + "num_inference_steps": 30, + "num_frames": 49, + "height": 480, + "width": 832 + }, + { + "caption": "3DGS_DISSOLVE A cooking tutorial featuring a man in a kitchen. He is wearing a white t-shirt and a black apron. As the scene progresses, light starts to emanate from the man and he burst into a fiery flame of red sparks.", + "video_path": "examples/training/control/wan/image_condition/validation_dataset/1.mp4", + "num_inference_steps": 30, + "num_frames": 49, + "height": 480, + "width": 832 + }, + { + "caption": "3DGS_DISSOLVE A man in a suit and tie, standing against a blue background with a digital pattern. He appears to be speaking or presenting, as suggested by his open mouth and focused expression. Suddenly, the man starts to dissolve into thin air with a bright fiery flame of red sparks.", + "video_path": "examples/training/control/wan/image_condition/validation_dataset/2.mp4", + "num_inference_steps": 30, + "num_frames": 49, + "height": 480, + "width": 832 + }, + { + "caption": "3DGS_DISSOLVE A man in a workshop, dressed in a black shirt and a beige hat, with a beard and glasses. He is holding a hammer and a metal object, possibly a piece of iron or a tool. The scene erupts with a bright fiery flame of red sparks.", + "video_path": "examples/training/control/wan/image_condition/validation_dataset/3.mp4", + "num_inference_steps": 30, + "num_frames": 49, + "height": 480, + "width": 832 + } + ] +} diff --git a/docs/finetrainers-src-codebase/finetrainers/__init__.py b/docs/finetrainers-src-codebase/finetrainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26acf49de21bb13c355d8d299a4a204d691c7536 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/__init__.py @@ -0,0 +1,8 @@ +from .args import BaseArgs +from .config import ModelType, TrainingType +from .logging import get_logger +from .models import ModelSpecification +from .trainer import ControlTrainer, SFTTrainer + + +__version__ = "0.2.0.dev0" diff --git a/docs/finetrainers-src-codebase/finetrainers/_metadata.py b/docs/finetrainers-src-codebase/finetrainers/_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..7fccac48904cd2f59782b2d3c0a738dc354db709 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/_metadata.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass, field +from typing import Dict, ForwardRef, List, Optional, Type, Union + + +ParamIdentifierType = ForwardRef("ParamIdentifier") +ContextParallelInputMetadataType = ForwardRef("ContextParallelInputMetadata") +ContextParallelOutputMetadataType = ForwardRef("ContextParallelOutputMetadata") + +_ContextParallelInputType = Dict[ + ParamIdentifierType, Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]] +] +_ContextParallelOutputType = List[ContextParallelOutputMetadataType] +ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType] + + +@dataclass(frozen=True) +class ParamId: + """ + A class to identify a parameter of a method. + + Atleast one of `name` or `index` must be provided. + + Attributes: + name (`str`, *optional*): + The name of the parameter. + index (`int`, *optional*): + The index of the parameter in the method signature. Indexing starts at 0 (ignore + the `self` parameter for instance methods). + """ + + name: Optional[str] = None + index: Optional[int] = None + + def __post_init__(self): + if self.name is None and self.index is None: + raise ValueError("At least one of `name` or `index` must be provided.") + + +@dataclass(frozen=True) +class CPInput: + split_dim: int + expected_dims: Optional[int] = None + split_output: bool = False + + +@dataclass(frozen=True) +class CPOutput: + gather_dim: int + expected_dims: Optional[int] = None + + +@dataclass +class TransformerMetadata: + # Mapping of FQN to mapping of input name to ContextParallelModelPlan + cp_plan: Dict[str, ContextParallelModelPlan] = field(default_factory=dict) + + # tp_plan # TODO(aryan) + + +class TransformerRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] diff --git a/docs/finetrainers-src-codebase/finetrainers/args.py b/docs/finetrainers-src-codebase/finetrainers/args.py new file mode 100644 index 0000000000000000000000000000000000000000..81db52ba19702a74351415ea76ad7836e3503c63 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/args.py @@ -0,0 +1,1032 @@ +import argparse +import os +import pathlib +import sys +from typing import Any, Dict, List, Literal, Optional, Union + +import torch + +from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType +from .logging import get_logger +from .parallel import ParallelBackendEnum +from .utils import ArgsConfigMixin, get_non_null_items + + +logger = get_logger() + +# fmt: off +# Must match src/finetrainers/models/attention_dispatch.py +AttentionProviderTraining = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "xformers"] +AttentionProviderInference = Literal["flash", "flash_varlen", "flex", "native", "_native_cudnn", "_native_efficient", "_native_flash", "_native_math", "sage", "sage_varlen", "_sage_qk_int8_pv_fp8_cuda", "_sage_qk_int8_pv_fp8_cuda_sm90", "_sage_qk_int8_pv_fp16_cuda", "_sage_qk_int8_pv_fp16_triton", "xformers"] + +# We do a union because every ArgsConfigMixin registered to BaseArgs can be looked up using the `__getattribute__` override +BaseArgsType = Union["BaseArgs", "AttentionProviderArgs"] +# fmt: on + + +class AttentionProviderArgs(ArgsConfigMixin): + """ + Args: + attn_provider_training (`List[str]`, defaults to `None`): + Must be a string of the form `":"`. For example, if you want to use + flash varlen attention implementation on the `transformer` module, you can set this argument to + `"transformer:flash_varlen"`. The attention provider will be used for both training and validation. + Options for `` are: + flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math, xformers + attn_provider_inference (`List[str]`, defaults to `None`): + Must be a string of the form `":"`. For example, if you want to use + flash varlen attention implementation on the `transformer` module, you can set this argument to + `"transformer:flash_varlen"`. The attention provider will be used for both training and validation. + Options for `` are: + flash, flash_varlen, flex, native, _native_cudnn, _native_efficient, _native_flash, _native_math, + _native_math, sage, sage_varlen, _sage_qk_int8_pv_fp8_cuda, _sage_qk_int8_pv_fp8_cuda_sm90, + _sage_qk_int8_pv_fp16_cuda, _sage_qk_int8_pv_fp16_triton, xformers + """ + + attn_provider_training: List[AttentionProviderTraining] = None + attn_provider_inference: List[AttentionProviderInference] = None + + def add_args(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--attn_provider_training", + type=str, + default=None, + nargs="+", + help="Attention provider for training. Must be a string of the form `:`.", + ) + parser.add_argument( + "--attn_provider_inference", + type=str, + default=None, + nargs="+", + help="Attention provider for inference. Must be a string of the form `:`.", + ) + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + attn_training = argparse_args.attn_provider_training + attn_inference = argparse_args.attn_provider_inference + if attn_training is None: + attn_training = [] + if attn_inference is None: + attn_inference = [] + mapped_args.attn_provider_training = attn_training + mapped_args.attn_provider_inference = attn_inference + + def validate_args(self, args: "BaseArgs"): + pass + + def to_dict(self) -> Dict[str, Any]: + return { + "attn_provider_training": self.attn_provider_training, + "attn_provider_inference": self.attn_provider_inference, + } + + +class BaseArgs: + """ + The arguments for the finetrainers training script. + + For helpful information about arguments, run `python train.py --help`. + + TODO(aryan): add `python train.py --recommend_configs --model_name ` to recommend + good training configs for a model after extensive testing. + TODO(aryan): add `python train.py --memory_requirements --model_name ` to show + memory requirements per model, per training type with sensible training settings. + + PARALLEL ARGUMENTS + ------------------ + parallel_backend (`str`, defaults to `accelerate`): + The parallel backend to use for training. Choose between ['accelerate', 'ptd']. + pp_degree (`int`, defaults to `1`): + The degree of pipeline parallelism. + dp_degree (`int`, defaults to `1`): + The degree of data parallelism (number of model replicas). + dp_shards (`int`, defaults to `-1`): + The number of data parallel shards (number of model partitions). + cp_degree (`int`, defaults to `1`): + The degree of context parallelism. + + MODEL ARGUMENTS + --------------- + model_name (`str`): + Name of model to train. To get a list of models, run `python train.py --list_models`. + pretrained_model_name_or_path (`str`): + Path to pretrained model or model identifier from https://huggingface.co/models. The model should be + loadable based on specified `model_name`. + revision (`str`, defaults to `None`): + If provided, the model will be loaded from a specific branch of the model repository. + variant (`str`, defaults to `None`): + Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk + storage requirements. + cache_dir (`str`, defaults to `None`): + The directory where the downloaded models and datasets will be stored, or loaded from. + tokenizer_id (`str`, defaults to `None`): + Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_2_id (`str`, defaults to `None`): + Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_3_id (`str`, defaults to `None`): + Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + text_encoder_id (`str`, defaults to `None`): + Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_2_id (`str`, defaults to `None`): + Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_3_id (`str`, defaults to `None`): + Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + transformer_id (`str`, defaults to `None`): + Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`. + vae_id (`str`, defaults to `None`): + Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`. + text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder when generating text embeddings. + text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 2 when generating text embeddings. + text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the text encoder 3 when generating text embeddings. + transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the transformer model. + vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type for the VAE model. + layerwise_upcasting_modules (`List[str]`, defaults to `[]`): + Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer']. + layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`): + Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2']. + layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`): + Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision + naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers + by default, and recommend adding more layers to the default list based on the model architecture. + + DATASET ARGUMENTS + ----------------- + dataset_config (`str`): + File to a dataset file containing information about training data. This file can contain information about one or + more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each + dictionary must contain the following keys: + - "data_root": (`str`) + The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided. + - "dataset_file": (`str`) + Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter + must be provided if `data_root` is not provided. + - "dataset_type": (`str`) + Type of dataset. Choose between ['image', 'video']. + - "id_token": (`str`) + Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training + for single subject/concept/style training, but is not necessary. + - "image_resolution_buckets": (`List[Tuple[int, int]]`) + Resolution buckets for image. This should be a list of tuples containing 2 values, where each tuple + represents the resolution (height, width). All images will be resized to the nearest bucket resolution. + This parameter must be provided if `dataset_type` is 'image'. + - "video_resolution_buckets": (`List[Tuple[int, int, int]]`) + Resolution buckets for video. This should be a list of tuples containing 3 values, where each tuple + represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket + resolution. This parameter must be provided if `dataset_type` is 'video'. + - "reshape_mode": (`str`) + All input images/videos are reshaped using this mode. Choose between the following: + ["center_crop", "random_crop", "bicubic"]. + - "remove_common_llm_caption_prefixes": (`boolean`) + Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes. + dataset_shuffle_buffer_size (`int`, defaults to `1`): + The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default + value of `1` means that the dataset will not be shuffled. + enable_precomputation (`bool`, defaults to `False`): + Whether or not to precompute the embeddings for the dataset. This is useful for faster training. If set to `True`, + the embeddings will be precomputed and saved to disk and loaded as required. + precomputation_items (`int`, defaults to `512`): + Number of data samples to precompute at once for memory-efficient training. The higher this value, + the more disk memory will be used to save the precomputed samples (conditions and latents). + precomputation_dir (`str`, defaults to `None`): + The directory where the precomputed samples will be stored. If not provided, the precomputed samples + will be stored in a temporary directory of the output directory. + precomputation_once (`bool`, defaults to `False`): + Precompute embeddings from all datasets at once before training. This is useful to save time during training + with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during + training when required (that is, computing embeddings of more data samples once `precomputation_items` of them + have been exhausted across all distributed ranks). Make sure to set `precomputation_items` to a reasonable value + in line with the size of your dataset(s). + precomputation_reuse (`bool`, defaults to `False`): + Reuse precomputed embeddings from previous training runs. This is useful to save time during training + with medium/large datasets. By default, old precomputed embeddings that exist in the specified precomputation + directory, or default precomputation dir `{output_dir}/precomputed` will be deleted if this is not set to `True`. + This flag is ignored if `enable_precomputation` is `False`. The topology of the distributed training run must be + the same as the one used to precompute the embeddings for this to work correctly (this limitation will be + addressed in the future). + + DATALOADER_ARGUMENTS + -------------------- + See https://pytorch.org/docs/stable/data.html for more information. + + dataloader_num_workers (`int`, defaults to `0`): + Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner + on the main process. + pin_memory (`bool`, defaults to `False`): + Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading. + + DIFFUSION ARGUMENTS + ------------------- + flow_resolution_shifting (`bool`, defaults to `False`): + Resolution-dependent shifting of timestep schedules. + [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206). + TODO(aryan): We don't support this yet. + flow_base_seq_len (`int`, defaults to `256`): + Base number of tokens for images/video when applying resolution-dependent shifting. + flow_max_seq_len (`int`, defaults to `4096`): + Maximum number of tokens for images/video when applying resolution-dependent shifting. + flow_base_shift (`float`, defaults to `0.5`): + Base shift for timestep schedules when applying resolution-dependent shifting. + flow_max_shift (`float`, defaults to `1.15`): + Maximum shift for timestep schedules when applying resolution-dependent shifting. + flow_shift (`float`, defaults to `1.0`): + Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma). + Setting it higher is helpful when trying to train models for high-resolution generation or to produce better + samples in lower number of inference steps. + flow_weighting_scheme (`str`, defaults to `none`): + We default to the "none" weighting scheme for uniform sampling and uniform loss. + Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none']. + flow_logit_mean (`float`, defaults to `0.0`): + Mean to use when using the `'logit_normal'` weighting scheme. + flow_logit_std (`float`, defaults to `1.0`): + Standard deviation to use when using the `'logit_normal'` weighting scheme. + flow_mode_scale (`float`, defaults to `1.29`): + Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. + + TRAINING ARGUMENTS + ------------------ + training_type (`str`, defaults to `None`): + Type of training to perform. Choose between ['lora']. + seed (`int`, defaults to `42`): + A seed for reproducible training. + batch_size (`int`, defaults to `1`): + Per-device batch size. + train_steps (`int`, defaults to `1000`): + Total number of training steps to perform. + max_data_samples (`int`, defaults to `2**64`): + Maximum number of data samples observed during training training. If lesser than that required by `train_steps`, + the training will stop early. + gradient_accumulation_steps (`int`, defaults to `1`): + Number of gradients steps to accumulate before performing an optimizer step. + gradient_checkpointing (`bool`, defaults to `False`): + Whether or not to use gradient/activation checkpointing to save memory at the expense of slower + backward pass. + checkpointing_steps (`int`, defaults to `500`): + Save a checkpoint of the training state every X training steps. These checkpoints can be used both + as final checkpoints in case they are better than the last checkpoint, and are also suitable for + resuming training using `resume_from_checkpoint`. + checkpointing_limit (`int`, defaults to `None`): + Max number of checkpoints to store. + resume_from_checkpoint (`str`, defaults to `None`): + Can be an integer or the string `"latest"`. If an integer is provided, training will resume from that step if a + checkpoint corresponding to it exists. If `"latest"` is provided, training will resume from the latest checkpoint + in the `--output_dir`. + + OPTIMIZER ARGUMENTS + ------------------- + optimizer (`str`, defaults to `adamw`): + The optimizer type to use. Choose between the following: + - Torch optimizers: ["adam", "adamw"] + - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"] + lr (`float`, defaults to `1e-4`): + Initial learning rate (after the potential warmup period) to use. + lr_scheduler (`str`, defaults to `cosine_with_restarts`): + The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', + 'constant', 'constant_with_warmup']. + lr_warmup_steps (`int`, defaults to `500`): + Number of steps for the warmup in the lr scheduler. + lr_num_cycles (`int`, defaults to `1`): + Number of hard resets of the lr in cosine_with_restarts scheduler. + lr_power (`float`, defaults to `1.0`): + Power factor of the polynomial scheduler. + beta1 (`float`, defaults to `0.9`): + beta2 (`float`, defaults to `0.95`): + beta3 (`float`, defaults to `0.999`): + weight_decay (`float`, defaults to `0.0001`): + Penalty for large weights in the model. + epsilon (`float`, defaults to `1e-8`): + Small value to avoid division by zero in the optimizer. + max_grad_norm (`float`, defaults to `1.0`): + Maximum gradient norm to clip the gradients. + + VALIDATION ARGUMENTS + -------------------- + validation_dataset_file (`str`, defaults to `None`): + Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain atleast the + "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path" + will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path" + will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary. + The validation dataset file may contain other attributes specific to inference/validation such as: + - "height" and "width" and "num_frames": Resolution + - "num_inference_steps": Number of inference steps + - "guidance_scale": Classifier-free Guidance Scale + - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be + invoked with the sample dictionary to validate the sample.) + validation_steps (`int`, defaults to `500`): + Number of training steps after which a validation step is performed. + enable_model_cpu_offload (`bool`, defaults to `False`): + Whether or not to offload different modeling components to CPU during validation. + + MISCELLANEOUS ARGUMENTS + ----------------------- + tracker_name (`str`, defaults to `finetrainers`): + Name of the tracker/project to use for logging training metrics. + push_to_hub (`bool`, defaults to `False`): + Whether or not to push the model to the Hugging Face Hub. + hub_token (`str`, defaults to `None`): + The API token to use for pushing the model to the Hugging Face Hub. + hub_model_id (`str`, defaults to `None`): + The model identifier to use for pushing the model to the Hugging Face Hub. + output_dir (`str`, defaults to `None`): + The directory where the model checkpoints and logs will be stored. + logging_dir (`str`, defaults to `logs`): + The directory where the logs will be stored. + logging_steps (`int`, defaults to `1`): + Training logs will be tracked every `logging_steps` steps. + nccl_timeout (`int`, defaults to `1800`): + Timeout for the NCCL communication. + report_to (`str`, defaults to `wandb`): + The name of the logger to use for logging training metrics. Choose between ['wandb']. + verbose (`int`, defaults to `1`): + Whether or not to print verbose logs. + - 0: Diffusers/Transformers warning logging on local main process only + - 1: Diffusers/Transformers info logging on local main process only + - 2: Diffusers/Transformers debug logging on local main process only + - 3: Diffusers/Transformers debug logging on all processes + + TORCH CONFIG ARGUMENTS + ---------------------- + compile_modules (`List[str]`, defaults to `[]`): + Modules that should be regionally compiled with `torch.compile`. + compile_scopes (`str`, defaults to `None`): + The scope of compilation for each `--compile_modules`. Choose between ['regional', 'full']. Must have the same length as + `--compile_modules`. If `None`, will default to `regional` for all modules. + allow_tf32 (`bool`, defaults to `False`): + Whether or not to allow the use of TF32 matmul on compatible hardware. + float32_matmul_precision (`str`, defaults to `highest`): + The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium']. + """ + + # Parallel arguments + parallel_backend = ParallelBackendEnum.ACCELERATE + pp_degree: int = 1 + dp_degree: int = 1 + dp_shards: int = 1 + cp_degree: int = 1 + tp_degree: int = 1 + + # Model arguments + model_name: str = None + pretrained_model_name_or_path: str = None + revision: Optional[str] = None + variant: Optional[str] = None + cache_dir: Optional[str] = None + tokenizer_id: Optional[str] = None + tokenizer_2_id: Optional[str] = None + tokenizer_3_id: Optional[str] = None + text_encoder_id: Optional[str] = None + text_encoder_2_id: Optional[str] = None + text_encoder_3_id: Optional[str] = None + transformer_id: Optional[str] = None + vae_id: Optional[str] = None + text_encoder_dtype: torch.dtype = torch.bfloat16 + text_encoder_2_dtype: torch.dtype = torch.bfloat16 + text_encoder_3_dtype: torch.dtype = torch.bfloat16 + transformer_dtype: torch.dtype = torch.bfloat16 + vae_dtype: torch.dtype = torch.bfloat16 + layerwise_upcasting_modules: List[str] = [] + layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn + # fmt: off + layerwise_upcasting_skip_modules_pattern: List[str] = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "time_embed", "^proj_in$", "^proj_out$", "norm"] + # fmt: on + + # Dataset arguments + dataset_config: str = None + dataset_shuffle_buffer_size: int = 1 + enable_precomputation: bool = False + precomputation_items: int = 512 + precomputation_dir: Optional[str] = None + precomputation_once: bool = False + precomputation_reuse: bool = False + + # Dataloader arguments + dataloader_num_workers: int = 0 + pin_memory: bool = False + + # Diffusion arguments + flow_resolution_shifting: bool = False + flow_base_seq_len: int = 256 + flow_max_seq_len: int = 4096 + flow_base_shift: float = 0.5 + flow_max_shift: float = 1.15 + flow_shift: float = 1.0 + flow_weighting_scheme: str = "none" + flow_logit_mean: float = 0.0 + flow_logit_std: float = 1.0 + flow_mode_scale: float = 1.29 + + # Training arguments + training_type: str = None + seed: int = 42 + batch_size: int = 1 + train_steps: int = 1000 + max_data_samples: int = 2**64 + gradient_accumulation_steps: int = 1 + gradient_checkpointing: bool = False + checkpointing_steps: int = 500 + checkpointing_limit: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + enable_slicing: bool = False + enable_tiling: bool = False + + # Optimizer arguments + optimizer: str = "adamw" + lr: float = 1e-4 + lr_scheduler: str = "cosine_with_restarts" + lr_warmup_steps: int = 0 + lr_num_cycles: int = 1 + lr_power: float = 1.0 + beta1: float = 0.9 + beta2: float = 0.95 + beta3: float = 0.999 + weight_decay: float = 0.0001 + epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + + # Validation arguments + validation_dataset_file: Optional[str] = None + validation_steps: int = 500 + enable_model_cpu_offload: bool = False + + # Miscellaneous arguments + tracker_name: str = "finetrainers" + push_to_hub: bool = False + hub_token: Optional[str] = None + hub_model_id: Optional[str] = None + output_dir: str = None + logging_dir: Optional[str] = "logs" + logging_steps: int = 1 + init_timeout: int = 300 # 5 minutes + nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed + report_to: str = "wandb" + verbose: int = 1 + + # Torch config arguments + compile_modules: List[str] = [] + compile_scopes: List[str] = None + allow_tf32: bool = False + float32_matmul_precision: str = "highest" + + # Attention provider arguments + attention_provider_args: AttentionProviderArgs = AttentionProviderArgs() + + _registered_config_mixins: List[ArgsConfigMixin] = [] + _arg_group_map: Dict[str, ArgsConfigMixin] = {} + + def __init__(self): + self._arg_group_map: Dict[str, ArgsConfigMixin] = { + "attention_provider_args": self.attention_provider_args, + } + + for arg_config_mixin in self._arg_group_map.values(): + self.register_args(arg_config_mixin) + + def to_dict(self) -> Dict[str, Any]: + parallel_arguments = { + "pp_degree": self.pp_degree, + "dp_degree": self.dp_degree, + "dp_shards": self.dp_shards, + "cp_degree": self.cp_degree, + "tp_degree": self.tp_degree, + } + + model_arguments = { + "model_name": self.model_name, + "pretrained_model_name_or_path": self.pretrained_model_name_or_path, + "revision": self.revision, + "variant": self.variant, + "cache_dir": self.cache_dir, + "tokenizer_id": self.tokenizer_id, + "tokenizer_2_id": self.tokenizer_2_id, + "tokenizer_3_id": self.tokenizer_3_id, + "text_encoder_id": self.text_encoder_id, + "text_encoder_2_id": self.text_encoder_2_id, + "text_encoder_3_id": self.text_encoder_3_id, + "transformer_id": self.transformer_id, + "vae_id": self.vae_id, + "text_encoder_dtype": self.text_encoder_dtype, + "text_encoder_2_dtype": self.text_encoder_2_dtype, + "text_encoder_3_dtype": self.text_encoder_3_dtype, + "transformer_dtype": self.transformer_dtype, + "vae_dtype": self.vae_dtype, + "layerwise_upcasting_modules": self.layerwise_upcasting_modules, + "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, + "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, + } + model_arguments = get_non_null_items(model_arguments) + + dataset_arguments = { + "dataset_config": self.dataset_config, + "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size, + "enable_precomputation": self.enable_precomputation, + "precomputation_items": self.precomputation_items, + "precomputation_dir": self.precomputation_dir, + "precomputation_once": self.precomputation_once, + "precomputation_reuse": self.precomputation_reuse, + } + dataset_arguments = get_non_null_items(dataset_arguments) + + dataloader_arguments = { + "dataloader_num_workers": self.dataloader_num_workers, + "pin_memory": self.pin_memory, + } + + diffusion_arguments = { + "flow_resolution_shifting": self.flow_resolution_shifting, + "flow_base_seq_len": self.flow_base_seq_len, + "flow_max_seq_len": self.flow_max_seq_len, + "flow_base_shift": self.flow_base_shift, + "flow_max_shift": self.flow_max_shift, + "flow_shift": self.flow_shift, + "flow_weighting_scheme": self.flow_weighting_scheme, + "flow_logit_mean": self.flow_logit_mean, + "flow_logit_std": self.flow_logit_std, + "flow_mode_scale": self.flow_mode_scale, + } + + training_arguments = { + "training_type": self.training_type, + "seed": self.seed, + "batch_size": self.batch_size, + "train_steps": self.train_steps, + "max_data_samples": self.max_data_samples, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "gradient_checkpointing": self.gradient_checkpointing, + "checkpointing_steps": self.checkpointing_steps, + "checkpointing_limit": self.checkpointing_limit, + "resume_from_checkpoint": self.resume_from_checkpoint, + "enable_slicing": self.enable_slicing, + "enable_tiling": self.enable_tiling, + } + training_arguments = get_non_null_items(training_arguments) + + optimizer_arguments = { + "optimizer": self.optimizer, + "lr": self.lr, + "lr_scheduler": self.lr_scheduler, + "lr_warmup_steps": self.lr_warmup_steps, + "lr_num_cycles": self.lr_num_cycles, + "lr_power": self.lr_power, + "beta1": self.beta1, + "beta2": self.beta2, + "beta3": self.beta3, + "weight_decay": self.weight_decay, + "epsilon": self.epsilon, + "max_grad_norm": self.max_grad_norm, + } + optimizer_arguments = get_non_null_items(optimizer_arguments) + + validation_arguments = { + "validation_dataset_file": self.validation_dataset_file, + "validation_steps": self.validation_steps, + "enable_model_cpu_offload": self.enable_model_cpu_offload, + } + validation_arguments = get_non_null_items(validation_arguments) + + miscellaneous_arguments = { + "tracker_name": self.tracker_name, + "push_to_hub": self.push_to_hub, + "hub_token": self.hub_token, + "hub_model_id": self.hub_model_id, + "output_dir": self.output_dir, + "logging_dir": self.logging_dir, + "logging_steps": self.logging_steps, + "init_timeout": self.init_timeout, + "nccl_timeout": self.nccl_timeout, + "report_to": self.report_to, + "verbose": self.verbose, + } + miscellaneous_arguments = get_non_null_items(miscellaneous_arguments) + + torch_config_arguments = { + "compile_modules": self.compile_modules, + "compile_scopes": self.compile_scopes, + "allow_tf32": self.allow_tf32, + "float32_matmul_precision": self.float32_matmul_precision, + } + + additional_arguments = {} + for config_mixin in self._registered_config_mixins: + additional_arguments[config_mixin.__class__.__name__] = config_mixin.to_dict() + + return { + "parallel_arguments": parallel_arguments, + "model_arguments": model_arguments, + "dataset_arguments": dataset_arguments, + "dataloader_arguments": dataloader_arguments, + "diffusion_arguments": diffusion_arguments, + "training_arguments": training_arguments, + "optimizer_arguments": optimizer_arguments, + "validation_arguments": validation_arguments, + "miscellaneous_arguments": miscellaneous_arguments, + "additional_arguments": additional_arguments, + "torch_config_arguments": torch_config_arguments, + } + + def register_args(self, config: ArgsConfigMixin) -> None: + if not hasattr(self, "_extended_add_arguments"): + self._extended_add_arguments = [] + self._extended_add_arguments.append((config.add_args, config.validate_args, config.map_args)) + self._registered_config_mixins.append(config) + + def parse_args(self): + _LIST_MODELS = "--list_models" + + parser = argparse.ArgumentParser() + + special_args = [_LIST_MODELS] + if any(arg in sys.argv for arg in special_args): + _add_helper_arguments(parser) + args = parser.parse_args() + _display_helper_messages(args) + sys.exit(0) + else: + _add_args(parser) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + add_fn, _, _ = extended_add_arg_fns + add_fn(parser) + + args, remaining_args = parser.parse_known_args() + logger.debug(f"Remaining unparsed arguments: {remaining_args}") + + mapped_args = _map_to_args_type(args) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, _, map_fn = extended_add_arg_fns + map_fn(args, mapped_args) + + _validate_args(mapped_args) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, validate_fn, _ = extended_add_arg_fns + validate_fn(mapped_args) + + return mapped_args + + def __getattribute__(self, name: str): + try: + return object.__getattribute__(self, name) + except AttributeError: + for arg_group in self._arg_group_map.values(): + if hasattr(arg_group, name): + return getattr(arg_group, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__: + object.__setattr__(self, name, value) + return + for arg_group in self._arg_group_map.values(): + if hasattr(arg_group, name): + setattr(arg_group, name, value) + return + object.__setattr__(self, name, value) + + +def _add_args(parser: argparse.ArgumentParser) -> None: + _add_parallel_arguments(parser) + _add_model_arguments(parser) + _add_dataset_arguments(parser) + _add_dataloader_arguments(parser) + _add_diffusion_arguments(parser) + _add_training_arguments(parser) + _add_optimizer_arguments(parser) + _add_validation_arguments(parser) + _add_miscellaneous_arguments(parser) + _add_torch_config_arguments(parser) + + +def _validate_args(args: BaseArgs): + _validate_model_args(args) + _validate_dataset_args(args) + _validate_validation_args(args) + + +def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--parallel_backend", + type=str, + default=ParallelBackendEnum.ACCELERATE, + choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD], + ) + parser.add_argument("--pp_degree", type=int, default=1) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--dp_shards", type=int, default=1) + parser.add_argument("--cp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + + +def _add_model_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()] + ) + parser.add_argument("--pretrained_model_name_or_path", type=str, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--variant", type=str, default=None) + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--tokenizer_id", type=str, default=None) + parser.add_argument("--tokenizer_2_id", type=str, default=None) + parser.add_argument("--tokenizer_3_id", type=str, default=None) + parser.add_argument("--text_encoder_id", type=str, default=None) + parser.add_argument("--text_encoder_2_id", type=str, default=None) + parser.add_argument("--text_encoder_3_id", type=str, default=None) + parser.add_argument("--transformer_id", type=str, default=None) + parser.add_argument("--vae_id", type=str, default=None) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16") + parser.add_argument("--transformer_dtype", type=str, default="bf16") + parser.add_argument("--vae_dtype", type=str, default="bf16") + parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"]) + parser.add_argument( + "--layerwise_upcasting_storage_dtype", + type=str, + default="float8_e4m3fn", + choices=["float8_e4m3fn", "float8_e5m2"], + ) + parser.add_argument( + "--layerwise_upcasting_skip_modules_pattern", + type=str, + default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], + nargs="+", + ) + + +def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--dataset_config", type=str, required=True) + parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1) + parser.add_argument("--enable_precomputation", action="store_true") + parser.add_argument("--precomputation_items", type=int, default=512) + parser.add_argument("--precomputation_dir", type=str, default=None) + parser.add_argument("--precomputation_once", action="store_true") + parser.add_argument("--precomputation_reuse", action="store_true") + + +def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--dataloader_num_workers", type=int, default=0) + parser.add_argument("--pin_memory", action="store_true") + + +def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--flow_resolution_shifting", action="store_true") + parser.add_argument("--flow_base_seq_len", type=int, default=256) + parser.add_argument("--flow_max_seq_len", type=int, default=4096) + parser.add_argument("--flow_base_shift", type=float, default=0.5) + parser.add_argument("--flow_max_shift", type=float, default=1.15) + parser.add_argument("--flow_shift", type=float, default=1.0) + parser.add_argument( + "--flow_weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument("--flow_logit_mean", type=float, default=0.0) + parser.add_argument("--flow_logit_std", type=float, default=1.0) + parser.add_argument("--flow_mode_scale", type=float, default=1.29) + + +def _add_training_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True + ) + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--train_steps", type=int, default=1000) + parser.add_argument("--max_data_samples", type=int, default=2**64) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_checkpointing", action="store_true") + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--checkpointing_limit", type=int, default=None) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--enable_slicing", action="store_true") + parser.add_argument("--enable_tiling", action="store_true") + + +def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="constant") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--lr_num_cycles", type=int, default=1) + parser.add_argument("--lr_power", type=float, default=1.0) + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"], + ) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument("--beta3", type=float, default=None) + parser.add_argument("--weight_decay", type=float, default=1e-04) + parser.add_argument("--epsilon", type=float, default=1e-8) + parser.add_argument("--max_grad_norm", default=1.0, type=float) + + +def _add_validation_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--validation_dataset_file", type=str, default=None) + parser.add_argument("--validation_steps", type=int, default=500) + parser.add_argument("--enable_model_cpu_offload", action="store_true") + + +def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default="finetrainers") + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--output_dir", type=str, default="finetrainers-training") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--init_timeout", type=int, default=300) + parser.add_argument("--nccl_timeout", type=int, default=600) + parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"]) + parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3]) + + +def _add_torch_config_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--compile_modules", type=str, default=[], nargs="+") + parser.add_argument("--compile_scopes", type=str, default=None, nargs="+") + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument( + "--float32_matmul_precision", + type=str, + default="highest", + choices=["highest", "high", "medium"], + help="The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].", + ) + + +def _add_helper_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--list_models", action="store_true") + + +_DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e5m2": torch.float8_e5m2, +} + + +def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs: + result_args = BaseArgs() + + # Parallel arguments + result_args.parallel_backend = args.parallel_backend + result_args.pp_degree = args.pp_degree + result_args.dp_degree = args.dp_degree + result_args.dp_shards = args.dp_shards + result_args.cp_degree = args.cp_degree + result_args.tp_degree = args.tp_degree + + # Model arguments + compile_scopes = args.compile_scopes + if len(args.compile_modules) > 0: + if compile_scopes is None: + compile_scopes = "regional" + if isinstance(compile_scopes, list) and len(compile_scopes) == 1: + compile_scopes = compile_scopes[0] + if isinstance(compile_scopes, str): + compile_scopes = [compile_scopes] * len(args.compile_modules) + else: + compile_scopes = [] + + result_args.model_name = args.model_name + result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path + result_args.revision = args.revision + result_args.variant = args.variant + result_args.cache_dir = args.cache_dir + result_args.tokenizer_id = args.tokenizer_id + result_args.tokenizer_2_id = args.tokenizer_2_id + result_args.tokenizer_3_id = args.tokenizer_3_id + result_args.text_encoder_id = args.text_encoder_id + result_args.text_encoder_2_id = args.text_encoder_2_id + result_args.text_encoder_3_id = args.text_encoder_3_id + result_args.transformer_id = args.transformer_id + result_args.vae_id = args.vae_id + result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype] + result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype] + result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] + result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype] + result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype] + result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules + result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype] + result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern + + # Dataset arguments + result_args.dataset_config = args.dataset_config + result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size + result_args.enable_precomputation = args.enable_precomputation + result_args.precomputation_items = args.precomputation_items + result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed") + result_args.precomputation_once = args.precomputation_once + result_args.precomputation_reuse = args.precomputation_reuse + + # Dataloader arguments + result_args.dataloader_num_workers = args.dataloader_num_workers + result_args.pin_memory = args.pin_memory + + # Diffusion arguments + result_args.flow_resolution_shifting = args.flow_resolution_shifting + result_args.flow_base_seq_len = args.flow_base_seq_len + result_args.flow_max_seq_len = args.flow_max_seq_len + result_args.flow_base_shift = args.flow_base_shift + result_args.flow_max_shift = args.flow_max_shift + result_args.flow_shift = args.flow_shift + result_args.flow_weighting_scheme = args.flow_weighting_scheme + result_args.flow_logit_mean = args.flow_logit_mean + result_args.flow_logit_std = args.flow_logit_std + result_args.flow_mode_scale = args.flow_mode_scale + + # Training arguments + result_args.training_type = args.training_type + result_args.seed = args.seed + result_args.batch_size = args.batch_size + result_args.train_steps = args.train_steps + result_args.max_data_samples = args.max_data_samples + result_args.gradient_accumulation_steps = args.gradient_accumulation_steps + result_args.gradient_checkpointing = args.gradient_checkpointing + result_args.checkpointing_steps = args.checkpointing_steps + result_args.checkpointing_limit = args.checkpointing_limit + result_args.resume_from_checkpoint = args.resume_from_checkpoint + result_args.enable_slicing = args.enable_slicing + result_args.enable_tiling = args.enable_tiling + + # Optimizer arguments + result_args.optimizer = args.optimizer or "adamw" + result_args.lr = args.lr or 1e-4 + result_args.lr_scheduler = args.lr_scheduler + result_args.lr_warmup_steps = args.lr_warmup_steps + result_args.lr_num_cycles = args.lr_num_cycles + result_args.lr_power = args.lr_power + result_args.beta1 = args.beta1 + result_args.beta2 = args.beta2 + result_args.beta3 = args.beta3 + result_args.weight_decay = args.weight_decay + result_args.epsilon = args.epsilon + result_args.max_grad_norm = args.max_grad_norm + + # Validation arguments + result_args.validation_dataset_file = args.validation_dataset_file + result_args.validation_steps = args.validation_steps + result_args.enable_model_cpu_offload = args.enable_model_cpu_offload + + # Miscellaneous arguments + result_args.tracker_name = args.tracker_name + result_args.push_to_hub = args.push_to_hub + result_args.hub_token = args.hub_token + result_args.hub_model_id = args.hub_model_id + result_args.output_dir = args.output_dir + result_args.logging_dir = args.logging_dir + result_args.logging_steps = args.logging_steps + result_args.init_timeout = args.init_timeout + result_args.nccl_timeout = args.nccl_timeout + result_args.report_to = args.report_to + result_args.verbose = args.verbose + + # Torch config arguments + result_args.compile_modules = args.compile_modules + result_args.compile_scopes = compile_scopes + result_args.allow_tf32 = args.allow_tf32 + result_args.float32_matmul_precision = args.float32_matmul_precision + + return result_args + + +def _validate_model_args(args: BaseArgs): + if args.training_type == "full-finetune": + assert "transformer" not in args.layerwise_upcasting_modules, ( + "Layerwise upcasting is not supported for full-finetune training" + ) + if len(args.compile_modules) > 0: + assert len(args.compile_modules) == len(args.compile_scopes) and all( + x in ["regional", "full"] for x in args.compile_scopes + ), ( + "Compile modules and compile scopes must be of the same length and compile scopes must be either 'regional' or 'full'" + ) + + +def _validate_dataset_args(args: BaseArgs): + dataset_config = pathlib.Path(args.dataset_config) + if not dataset_config.exists(): + raise ValueError(f"Dataset config file {args.dataset_config} does not exist.") + if args.dataset_shuffle_buffer_size < 1: + raise ValueError("Dataset shuffle buffer size must be greater than 0.") + if args.precomputation_items < 1: + raise ValueError("Precomputation items must be greater than 0.") + + +def _validate_validation_args(args: BaseArgs): + if args.enable_model_cpu_offload: + if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]): + raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.") + + +def _display_helper_messages(args: argparse.Namespace): + if args.list_models: + print("Supported models:") + for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()): + print(f" {index + 1}. {model_name}") diff --git a/docs/finetrainers-src-codebase/finetrainers/config.py b/docs/finetrainers-src-codebase/finetrainers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..46e713e9b6ea9314a63d994875900d2a5facf3bd --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/config.py @@ -0,0 +1,75 @@ +from enum import Enum +from typing import Type + +from .models import ModelSpecification +from .models.cogvideox import CogVideoXModelSpecification +from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification +from .models.flux import FluxModelSpecification +from .models.hunyuan_video import HunyuanVideoModelSpecification +from .models.ltx_video import LTXVideoModelSpecification +from .models.wan import WanControlModelSpecification, WanModelSpecification + + +class ModelType(str, Enum): + COGVIDEOX = "cogvideox" + COGVIEW4 = "cogview4" + FLUX = "flux" + HUNYUAN_VIDEO = "hunyuan_video" + LTX_VIDEO = "ltx_video" + WAN = "wan" + + +class TrainingType(str, Enum): + # SFT + LORA = "lora" + FULL_FINETUNE = "full-finetune" + + # Control + CONTROL_LORA = "control-lora" + CONTROL_FULL_FINETUNE = "control-full-finetune" + + +SUPPORTED_MODEL_CONFIGS = { + # TODO(aryan): autogenerate this + # SFT + ModelType.COGVIDEOX: { + TrainingType.LORA: CogVideoXModelSpecification, + TrainingType.FULL_FINETUNE: CogVideoXModelSpecification, + }, + ModelType.COGVIEW4: { + TrainingType.LORA: CogView4ModelSpecification, + TrainingType.FULL_FINETUNE: CogView4ModelSpecification, + TrainingType.CONTROL_LORA: CogView4ControlModelSpecification, + TrainingType.CONTROL_FULL_FINETUNE: CogView4ControlModelSpecification, + }, + ModelType.FLUX: { + TrainingType.LORA: FluxModelSpecification, + TrainingType.FULL_FINETUNE: FluxModelSpecification, + }, + ModelType.HUNYUAN_VIDEO: { + TrainingType.LORA: HunyuanVideoModelSpecification, + TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, + }, + ModelType.LTX_VIDEO: { + TrainingType.LORA: LTXVideoModelSpecification, + TrainingType.FULL_FINETUNE: LTXVideoModelSpecification, + }, + ModelType.WAN: { + TrainingType.LORA: WanModelSpecification, + TrainingType.FULL_FINETUNE: WanModelSpecification, + TrainingType.CONTROL_LORA: WanControlModelSpecification, + TrainingType.CONTROL_FULL_FINETUNE: WanControlModelSpecification, + }, +} + + +def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]: + if model_name not in SUPPORTED_MODEL_CONFIGS: + raise ValueError( + f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" + ) + if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: + raise ValueError( + f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" + ) + return SUPPORTED_MODEL_CONFIGS[model_name][training_type] diff --git a/docs/finetrainers-src-codebase/finetrainers/constants.py b/docs/finetrainers-src-codebase/finetrainers/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..bd45d2925af9541608170b8f73244f27b13471d2 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/constants.py @@ -0,0 +1,87 @@ +import os + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} + +FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO") +FINETRAINERS_ATTN_PROVIDER = os.environ.get("FINETRAINERS_ATTN_PROVIDER", "native") +FINETRAINERS_ATTN_CHECKS = os.getenv("FINETRAINERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES +FINETRAINERS_ENABLE_TIMING = os.getenv("FINETRAINERS_ENABLE_TIMING", "1") in ENV_VARS_TRUE_VALUES + +DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +DEFAULT_FRAME_BUCKETS = [49] + +DEFAULT_IMAGE_RESOLUTION_BUCKETS = [] +for height in DEFAULT_HEIGHT_BUCKETS: + for width in DEFAULT_WIDTH_BUCKETS: + DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width)) + +DEFAULT_VIDEO_RESOLUTION_BUCKETS = [] +for frames in DEFAULT_FRAME_BUCKETS: + for height in DEFAULT_HEIGHT_BUCKETS: + for width in DEFAULT_WIDTH_BUCKETS: + DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width)) + +PRECOMPUTED_DIR_NAME = "precomputed" +PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions" +PRECOMPUTED_LATENTS_DIR_NAME = "latents" + +MODEL_DESCRIPTION = r""" +\# {model_id} {training_type} finetune + + + +\#\# Model Description + +This model is a {training_type} of the `{model_id}` model. + +This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers). + +\#\# Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +\#\# Usage + +Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed. + +```python +{model_example} +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +\#\# License + +Please adhere to the license of the base model. +""".strip() + +_COMMON_BEGINNING_PHRASES = ( + "This video", + "The video", + "This clip", + "The clip", + "The animation", + "This image", + "The image", + "This picture", + "The picture", +) +_COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents") + +COMMON_LLM_START_PHRASES = ( + "In the video,", + "In this video,", + "In this video clip,", + "In the clip,", + "Caption:", + *( + f"{beginning} {continuation}" + for beginning in _COMMON_BEGINNING_PHRASES + for continuation in _COMMON_CONTINUATION_WORDS + ), +) + +SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png") +SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov") diff --git a/docs/finetrainers-src-codebase/finetrainers/data/__init__.py b/docs/finetrainers-src-codebase/finetrainers/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2025f19f9fde243cbbc998cbb58d330d70f9544 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/__init__.py @@ -0,0 +1,26 @@ +from ._artifact import ImageArtifact, VideoArtifact +from .dataloader import DPDataLoader +from .dataset import ( + ImageCaptionFilePairDataset, + ImageFileCaptionFileListDataset, + ImageFolderDataset, + ImageWebDataset, + ValidationDataset, + VideoCaptionFilePairDataset, + VideoFileCaptionFileListDataset, + VideoFolderDataset, + VideoWebDataset, + combine_datasets, + initialize_dataset, + wrap_iterable_dataset_for_preprocessing, +) +from .precomputation import ( + InMemoryDataIterable, + InMemoryDistributedDataPreprocessor, + InMemoryOnceDataIterable, + PrecomputedDataIterable, + PrecomputedDistributedDataPreprocessor, + PrecomputedOnceDataIterable, + initialize_preprocessor, +) +from .sampler import ResolutionSampler diff --git a/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py b/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..400f25d143f5062d77ed6391ca9862654d295de7 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py @@ -0,0 +1,29 @@ +# ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT ===== + +from dataclasses import dataclass +from typing import Any, List + +from PIL.Image import Image + + +@dataclass +class Artifact: + type: str + value: Any + file_extension: str + + +@dataclass +class ImageArtifact(Artifact): + value: Image + + def __init__(self, value: Image): + super().__init__(type="image", value=value, file_extension="png") + + +@dataclass +class VideoArtifact(Artifact): + value: List[Image] + + def __init__(self, value: List[Image]): + super().__init__(type="video", value=value, file_extension="mp4") diff --git a/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py b/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..752229489de69a684b395c79e5a2799c3f747596 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py @@ -0,0 +1,40 @@ +import pickle +from typing import Any, Dict + +import torch.distributed.checkpoint.stateful +import torchdata.stateful_dataloader + +from finetrainers.logging import get_logger + + +logger = get_logger() + + +class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful): + def __init__( + self, + rank: int, + dataset: torch.utils.data.IterableDataset, + batch_size: int = 1, + num_workers: int = 0, + collate_fn=None, + ) -> None: + super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn) + + self._dp_rank = rank + self._rank_id = f"dp_rank_{rank}" + + def state_dict(self) -> Dict[str, Any]: + # Store state only for dp rank to avoid replicating the same state across other dimensions + return {self._rank_id: pickle.dumps(super().state_dict())} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # State being empty is valid + if not state_dict: + return + + if self._rank_id not in state_dict: + logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}") + return + + super().load_state_dict(pickle.loads(state_dict[self._rank_id])) diff --git a/docs/finetrainers-src-codebase/finetrainers/data/dataset.py b/docs/finetrainers-src-codebase/finetrainers/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5416e6151a52404d174a9939279420d46dbe232e --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/dataset.py @@ -0,0 +1,1040 @@ +import pathlib +import random +from typing import Any, Dict, List, Optional, Tuple, Union + +import datasets +import datasets.data_files +import datasets.distributed +import datasets.exceptions +import huggingface_hub +import huggingface_hub.errors +import numpy as np +import PIL.Image +import PIL.JpegImagePlugin +import torch +import torch.distributed.checkpoint.stateful +import torchvision +from diffusers.utils import load_image, load_video +from huggingface_hub import list_repo_files, repo_exists, snapshot_download +from tqdm.auto import tqdm + +from finetrainers import constants +from finetrainers import functional as FF +from finetrainers.logging import get_logger +from finetrainers.utils import find_files +from finetrainers.utils.import_utils import is_datasets_version + + +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger() + + +# fmt: off +MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024 +COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"] +COMMON_VIDEO_FILES = ["video.txt", "videos.txt"] +COMMON_IMAGE_FILES = ["image.txt", "images.txt"] +COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"] +# fmt: on + + +class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + caption_files = sorted(find_files(self.root.as_posix(), "*.txt", depth=0)) + for caption_file in caption_files: + data_file = self._find_data_file(caption_file) + if data_file: + data.append( + { + "caption": (self.root / caption_file).as_posix(), + "image": (self.root / data_file).as_posix(), + } + ) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("image", datasets.Image(mode="RGB")) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["caption"] = _read_caption_from_file(sample["caption"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + def _find_data_file(self, caption_file: str) -> str: + caption_file = pathlib.Path(caption_file) + data_file = None + found_data = 0 + + for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS: + image_filename = caption_file.with_suffix(f".{extension}") + if image_filename.exists(): + found_data += 1 + data_file = image_filename + + if found_data == 0: + return False + elif found_data > 1: + raise ValueError( + f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data " + f"file per caption file. The following extensions are supported:\n" + f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n" + ) + + return data_file.as_posix() + + +class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + caption_files = sorted(find_files(self.root.as_posix(), "*.txt", depth=0)) + for caption_file in caption_files: + data_file = self._find_data_file(caption_file) + if data_file: + data.append( + { + "caption": (self.root / caption_file).as_posix(), + "video": (self.root / data_file).as_posix(), + } + ) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("video", datasets.Video()) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["caption"] = _read_caption_from_file(sample["caption"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + def _find_data_file(self, caption_file: str) -> str: + caption_file = pathlib.Path(caption_file) + data_file = None + found_data = 0 + + for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: + video_filename = caption_file.with_suffix(f".{extension}") + if video_filename.exists(): + found_data += 1 + data_file = video_filename + + if found_data == 0: + return False + elif found_data > 1: + raise ValueError( + f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data " + f"file per caption file. The following extensions are supported:\n" + f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n" + ) + + return data_file.as_posix() + + +class ImageFileCaptionFileListDataset( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] + VALID_IMAGE_FILES = ["image.txt", "images.txt"] + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] + existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()] + + if len(existing_caption_files) == 0: + raise FileNotFoundError( + f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_image_files) == 0: + raise FileNotFoundError( + f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" + ) + if len(existing_caption_files) > 1: + raise ValueError( + f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_image_files) > 1: + raise ValueError( + f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" + ) + + caption_file = existing_caption_files[0] + image_file = existing_image_files[0] + + with open((self.root / caption_file).as_posix(), "r") as f: + captions = f.read().splitlines() + with open((self.root / image_file).as_posix(), "r") as f: + images = f.read().splitlines() + images = [(self.root / image).as_posix() for image in images] + + if len(captions) != len(images): + raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})") + + for caption, image in zip(captions, images): + data.append({"caption": caption, "image": image}) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("image", datasets.Image(mode="RGB")) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoFileCaptionFileListDataset( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] + VALID_VIDEO_FILES = ["video.txt", "videos.txt"] + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] + existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()] + + if len(existing_caption_files) == 0: + raise FileNotFoundError( + f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_video_files) == 0: + raise FileNotFoundError( + f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}" + ) + if len(existing_caption_files) > 1: + raise ValueError( + f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_video_files) > 1: + raise ValueError( + f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}" + ) + + caption_file = existing_caption_files[0] + video_file = existing_video_files[0] + + with open((self.root / caption_file).as_posix(), "r") as f: + captions = f.read().splitlines() + with open((self.root / video_file).as_posix(), "r") as f: + videos = f.read().splitlines() + videos = [(self.root / video).as_posix() for video in videos] + + if len(captions) != len(videos): + raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})") + + for caption, video in zip(captions, videos): + data.append({"caption": caption, "video": video}) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("video", datasets.Video()) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train") + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train") + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__( + self, + dataset_name: str, + infinite: bool = False, + column_names: Union[str, List[str]] = "__auto__", + weights: Dict[str, float] = -1, + **kwargs, + ) -> None: + super().__init__() + + assert weights == -1 or isinstance(weights, dict), ( + "`weights` must be a dictionary of probabilities for each caption column" + ) + + self.dataset_name = dataset_name + self.infinite = infinite + + data = datasets.load_dataset(dataset_name, split="train", streaming=True) + + if column_names == "__auto__": + if weights == -1: + caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES] + if len(caption_columns) == 0: + raise ValueError( + f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}. " + f"Available columns are: {data.column_names}" + ) + weights = [1] * len(caption_columns) + else: + caption_columns = list(weights.keys()) + weights = list(weights.values()) + if not all(column in data.column_names for column in caption_columns): + raise ValueError( + f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}" + ) + else: + if isinstance(column_names, str): + if column_names not in data.column_names: + raise ValueError( + f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}" + ) + caption_columns = [column_names] + weights = [1] if weights == -1 else [weights.get(column_names)] + elif isinstance(column_names, list): + if not all(column in data.column_names for column in column_names): + raise ValueError( + f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}" + ) + caption_columns = column_names + weights = [1] if weights == -1 else [weights.get(column) for column in column_names] + else: + raise ValueError(f"Unsupported type for column_name: {type(column_names)}") + + for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS: + if column_names in data.column_names: + data = data.cast_column(column_names, datasets.Image(mode="RGB")) + data = data.rename_column(column_names, "image") + break + + self._data = data + self._sample_index = 0 + self._precomputable_once = False + self._caption_columns = caption_columns + self._weights = weights + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0] + sample["caption"] = sample[caption_column] + yield sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_index = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__( + self, + dataset_name: str, + infinite: bool = False, + column_names: Union[str, List[str]] = "__auto__", + weights: Dict[str, float] = -1, + **kwargs, + ) -> None: + super().__init__() + + assert weights == -1 or isinstance(weights, dict), ( + "`weights` must be a dictionary of probabilities for each caption column" + ) + + self.dataset_name = dataset_name + self.infinite = infinite + + data = datasets.load_dataset(dataset_name, split="train", streaming=True) + + if column_names == "__auto__": + if weights == -1: + caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES] + if len(caption_columns) == 0: + raise ValueError( + f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}" + ) + weights = [1] * len(caption_columns) + else: + caption_columns = list(weights.keys()) + weights = list(weights.values()) + if not all(column in data.column_names for column in caption_columns): + raise ValueError( + f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}" + ) + else: + if isinstance(column_names, str): + if column_names not in data.column_names: + raise ValueError( + f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}" + ) + caption_columns = [column_names] + weights = [1] if weights == -1 else [weights.get(column_names)] + elif isinstance(column_names, list): + if not all(column in data.column_names for column in column_names): + raise ValueError( + f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}" + ) + caption_columns = column_names + weights = [1] if weights == -1 else [weights.get(column) for column in column_names] + else: + raise ValueError(f"Unsupported type for column_name: {type(column_names)}") + + for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: + if column_names in data.column_names: + data = data.cast_column(column_names, datasets.Video()) + data = data.rename_column(column_names, "video") + break + + self._data = data + self._sample_index = 0 + self._precomputable_once = False + self._caption_columns = caption_columns + self._weights = weights + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0] + sample["caption"] = sample[caption_column] + yield sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_index = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ValidationDataset(torch.utils.data.IterableDataset): + def __init__(self, filename: str): + super().__init__() + + self.filename = pathlib.Path(filename) + + if not self.filename.exists(): + raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist") + + if self.filename.suffix == ".csv": + data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train") + elif self.filename.suffix == ".json": + data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data") + elif self.filename.suffix == ".parquet": + data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train") + elif self.filename.suffix == ".arrow": + data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train") + else: + _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"] + raise ValueError( + f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}" + ) + + self._data = data.to_iterable_dataset() + + def __iter__(self): + for sample in self._data: + # For consistency reasons, we mandate that "caption" is always present in the validation dataset. + # However, since the model specifications use "prompt", we create an alias here. + sample["prompt"] = sample["caption"] + + # Load image or video if the path is provided + # TODO(aryan): need to handle custom columns here for control conditions + sample["image"] = None + sample["video"] = None + + if sample.get("image_path", None) is not None: + image_path = sample["image_path"] + if not pathlib.Path(image_path).is_file() and not image_path.startswith("http"): + logger.warning(f"Image file {image_path.as_posix()} does not exist.") + else: + sample["image"] = load_image(sample["image_path"]) + + if sample.get("video_path", None) is not None: + video_path = sample["video_path"] + if not pathlib.Path(video_path).is_file() and not video_path.startswith("http"): + logger.warning(f"Video file {video_path.as_posix()} does not exist.") + else: + sample["video"] = load_video(sample["video_path"]) + + if sample.get("control_image_path", None) is not None: + control_image_path = sample["control_image_path"] + if not pathlib.Path(control_image_path).is_file() and not control_image_path.startswith("http"): + logger.warning(f"Control Image file {control_image_path.as_posix()} does not exist.") + else: + sample["control_image"] = load_image(sample["control_image_path"]) + + if sample.get("control_video_path", None) is not None: + control_video_path = sample["control_video_path"] + if not pathlib.Path(control_video_path).is_file() and not control_video_path.startswith("http"): + logger.warning(f"Control Video file {control_video_path.as_posix()} does not exist.") + else: + sample["control_video"] = load_video(sample["control_video_path"]) + + sample = {k: v for k, v in sample.items() if v is not None} + yield sample + + +class IterableDatasetPreprocessingWrapper( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__( + self, + dataset: torch.utils.data.IterableDataset, + dataset_type: str, + id_token: Optional[str] = None, + image_resolution_buckets: List[Tuple[int, int]] = None, + video_resolution_buckets: List[Tuple[int, int, int]] = None, + rename_columns: Optional[Dict[str, str]] = None, + drop_columns: Optional[List[str]] = None, + reshape_mode: str = "bicubic", + remove_common_llm_caption_prefixes: bool = False, + **kwargs, + ): + super().__init__() + + self.dataset = dataset + self.dataset_type = dataset_type + self.id_token = id_token + self.image_resolution_buckets = image_resolution_buckets + self.video_resolution_buckets = video_resolution_buckets + self.rename_columns = rename_columns or {} + self.drop_columns = drop_columns or [] + self.reshape_mode = reshape_mode + self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes + + logger.info( + f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n" + f" - Dataset Type: {dataset_type}\n" + f" - ID Token: {id_token}\n" + f" - Image Resolution Buckets: {image_resolution_buckets}\n" + f" - Video Resolution Buckets: {video_resolution_buckets}\n" + f" - Rename Columns: {rename_columns}\n" + f" - Reshape Mode: {reshape_mode}\n" + f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n" + ) + + def __iter__(self): + logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset") + for sample in iter(self.dataset): + for column in self.drop_columns: + sample.pop(column, None) + + sample = {self.rename_columns.get(k, k): v for k, v in sample.items()} + + for key in sample.keys(): + if isinstance(sample[key], PIL.Image.Image): + sample[key] = _preprocess_image(sample[key]) + elif isinstance(sample[key], (decord.VideoReader, torchvision.io.video_reader.VideoReader)): + sample[key] = _preprocess_video(sample[key]) + + if self.dataset_type == "image": + if self.image_resolution_buckets: + sample["_original_num_frames"] = 1 + sample["_original_height"] = sample["image"].size(1) + sample["_original_width"] = sample["image"].size(2) + sample["image"] = FF.resize_to_nearest_bucket_image( + sample["image"], self.image_resolution_buckets, self.reshape_mode + ) + elif self.dataset_type == "video": + if self.video_resolution_buckets: + sample["_original_num_frames"] = sample["video"].size(0) + sample["_original_height"] = sample["video"].size(2) + sample["_original_width"] = sample["video"].size(3) + sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video( + sample["video"], self.video_resolution_buckets, self.reshape_mode + ) + if _first_frame_only: + msg = ( + "The number of frames in the video is less than the minimum bucket size " + "specified. The first frame is being used as a single frame video. This " + "message is logged at the first occurence and for every 128th occurence " + "after that." + ) + logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128) + sample["video"] = sample["video"][:1] + + caption = sample["caption"] + if isinstance(caption, list): + caption = caption[0] + if caption.startswith("b'") and caption.endswith("'"): + caption = FF.convert_byte_str_to_str(caption) + if self.remove_common_llm_caption_prefixes: + caption = FF.remove_prefix(caption, constants.COMMON_LLM_START_PHRASES) + if self.id_token is not None: + caption = f"{self.id_token} {caption}" + sample["caption"] = caption + + yield sample + + def load_state_dict(self, state_dict): + self.dataset.load_state_dict(state_dict["dataset"]) + + def state_dict(self): + return {"dataset": self.dataset.state_dict()} + + +class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False): + super().__init__() + + self.datasets = datasets + self.buffer_size = buffer_size + self.shuffle = shuffle + + logger.info( + f"Initializing IterableCombinedDataset with the following configuration:\n" + f" - Number of Datasets: {len(datasets)}\n" + f" - Buffer Size: {buffer_size}\n" + f" - Shuffle: {shuffle}\n" + ) + + def __iter__(self): + logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets") + iterators = [iter(dataset) for dataset in self.datasets] + buffer = [] + per_iter = max(1, self.buffer_size // len(iterators)) + + for index, it in enumerate(iterators): + for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"): + try: + buffer.append((it, next(it))) + except StopIteration: + continue + + while len(buffer) > 0: + idx = 0 + if self.shuffle: + idx = random.randint(0, len(buffer) - 1) + current_it, sample = buffer.pop(idx) + yield sample + try: + buffer.append((current_it, next(current_it))) + except StopIteration: + pass + + def load_state_dict(self, state_dict): + for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]): + dataset.load_state_dict(dataset_state_dict) + + def state_dict(self): + return {"datasets": [dataset.state_dict() for dataset in self.datasets]} + + +# TODO(aryan): maybe write a test for this +def initialize_dataset( + dataset_name_or_root: str, + dataset_type: str = "video", + streaming: bool = True, + infinite: bool = False, + *, + _caption_options: Optional[Dict[str, Any]] = None, +) -> torch.utils.data.IterableDataset: + assert dataset_type in ["image", "video"] + + try: + does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset") + except huggingface_hub.errors.HFValidationError: + does_repo_exist_on_hub = False + + if does_repo_exist_on_hub: + return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options) + else: + return _initialize_local_dataset( + dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options + ) + + +def combine_datasets( + datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False +) -> torch.utils.data.IterableDataset: + return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle) + + +def wrap_iterable_dataset_for_preprocessing( + dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any] +) -> torch.utils.data.IterableDataset: + return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config) + + +def _initialize_local_dataset( + dataset_name_or_root: str, + dataset_type: str, + infinite: bool = False, + *, + _caption_options: Optional[Dict[str, Any]] = None, +): + root = pathlib.Path(dataset_name_or_root) + supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"] + metadata_files = [root / metadata_file for metadata_file in supported_metadata_files] + metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()] + + if len(metadata_files) > 1: + raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.") + + if len(metadata_files) == 1: + if dataset_type == "image": + dataset = ImageFolderDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoFolderDataset(root.as_posix(), infinite=infinite) + return dataset + + file_list = find_files(root.as_posix(), "*", depth=100) + has_tar_or_parquet_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in file_list) + if has_tar_or_parquet_files: + return _initialize_webdataset(root.as_posix(), dataset_type, infinite, _caption_options=_caption_options) + + if _has_data_caption_file_pairs(root, remote=False): + if dataset_type == "image": + dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite) + elif _has_data_file_caption_file_lists(root, remote=False): + if dataset_type == "image": + dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite) + else: + raise ValueError( + f"Could not find any supported dataset structure in the directory {root}. Please open an issue at " + f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will " + f"help you set it up." + ) + + return dataset + + +def _initialize_hub_dataset( + dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None +): + repo_file_list = list_repo_files(dataset_name, repo_type="dataset") + if _has_data_caption_file_pairs(repo_file_list, remote=True): + return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) + elif _has_data_file_caption_file_lists(repo_file_list, remote=True): + return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) + + has_tar_or_parquet_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list) + if has_tar_or_parquet_files: + return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options) + + # TODO(aryan): This should be improved + caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")] + if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT: + try: + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + dataset = ImageFolderDataset(dataset_root, infinite=infinite) + else: + dataset = VideoFolderDataset(dataset_root, infinite=infinite) + return dataset + except Exception: + pass + + raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub") + + +def _initialize_data_caption_file_dataset_from_hub( + dataset_name: str, dataset_type: str, infinite: bool = False +) -> torch.utils.data.IterableDataset: + logger.info(f"Downloading dataset {dataset_name} from the HF Hub") + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + return ImageCaptionFilePairDataset(dataset_root, infinite=infinite) + else: + return VideoCaptionFilePairDataset(dataset_root, infinite=infinite) + + +def _initialize_data_file_caption_file_dataset_from_hub( + dataset_name: str, dataset_type: str, infinite: bool = False +) -> torch.utils.data.IterableDataset: + logger.info(f"Downloading dataset {dataset_name} from the HF Hub") + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite) + else: + return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite) + + +def _initialize_webdataset( + dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None +) -> torch.utils.data.IterableDataset: + logger.info(f"Streaming webdataset {dataset_name} from the HF Hub") + _caption_options = _caption_options or {} + if dataset_type == "image": + return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options) + else: + return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options) + + +def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: + # TODO(aryan): this logic can be improved + if not remote: + caption_files = find_files(root.as_posix(), "*.txt", depth=0) + for caption_file in caption_files: + caption_file = pathlib.Path(caption_file) + for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: + data_filename = caption_file.with_suffix(f".{extension}") + if data_filename.exists(): + return True + return False + else: + caption_files = [file for file in root if file.endswith(".txt")] + for caption_file in caption_files: + caption_file = pathlib.Path(caption_file) + for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: + data_filename = caption_file.with_suffix(f".{extension}").name + if data_filename in root: + return True + return False + + +def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: + # TODO(aryan): this logic can be improved + if not remote: + file_list = {x.name for x in root.iterdir()} + has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES) + has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES) + has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES) + return has_caption_files and (has_video_files or has_image_files) + else: + has_caption_files = any(file in root for file in COMMON_CAPTION_FILES) + has_video_files = any(file in root for file in COMMON_VIDEO_FILES) + has_image_files = any(file in root for file in COMMON_IMAGE_FILES) + return has_caption_files and (has_video_files or has_image_files) + + +def _read_caption_from_file(filename: str) -> str: + with open(filename, "r") as f: + return f.read().strip() + + +def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor: + image = image.convert("RGB") + image = np.array(image).astype(np.float32) + image = torch.from_numpy(image) + image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0 + return image + + +if is_datasets_version("<", "3.4.0"): + + def _preprocess_video(video: decord.VideoReader) -> torch.Tensor: + video = video.get_batch(list(range(len(video)))) + video = video.permute(0, 3, 1, 2).contiguous() + video = video.float() / 127.5 - 1.0 + return video + +else: + # Hardcode max frames for now. Ideally, we should allow user to set this and handle it in IterableDatasetPreprocessingWrapper + MAX_FRAMES = 4096 + + def _preprocess_video(video: torchvision.io.video_reader.VideoReader) -> torch.Tensor: + frames = [] + # Error driven data loading! torchvision does not expose length of video + try: + for _ in range(MAX_FRAMES): + frames.append(next(video)["data"]) + except StopIteration: + pass + video = torch.stack(frames) + video = video.float() / 127.5 - 1.0 + return video diff --git a/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py b/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py new file mode 100644 index 0000000000000000000000000000000000000000..3a33a80603558befd85ca43abf0a1ce39c6d94cd --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py @@ -0,0 +1,420 @@ +import pathlib +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +import torch +from tqdm.auto import tqdm + +from finetrainers.logging import get_logger +from finetrainers.utils import delete_files + + +logger = get_logger() + +PRECOMPUTED_DATA_DIR = "finetrainers-precomputed-data" + + +def initialize_preprocessor( + rank: int, + world_size: int, + num_items: int, + processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]], + save_dir: Optional[str] = None, + enable_precomputation: bool = False, + enable_reuse: bool = False, +) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]: + if enable_precomputation: + return PrecomputedDistributedDataPreprocessor( + rank, world_size, num_items, processor_fn, save_dir, enable_reuse + ) + return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn) + + +class DistributedDataProcessorMixin: + def consume(self, *args, **kwargs): + raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.") + + def consume_once(self, *args, **kwargs): + raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.") + + @property + def requires_data(self): + raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.") + + +class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin): + def __init__( + self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] + ) -> None: + super().__init__() + + self._rank = rank + self._num_items = num_items + self._processor_fn = processor_fn + + self._cached_samples = [] + self._buffer = InMemoryDataBuffer(num_items) + self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None + + def consume( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + for i in range(self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + self._buffer.add(data_type, item) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + + self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer) + return iter(self._preprocessed_iterator) + + def consume_once( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + for i in range(self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + self._buffer.add(data_type, item) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + + self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer) + return iter(self._preprocessed_iterator) + + @property + def requires_data(self): + if self._preprocessed_iterator is None: + return True + return self._preprocessed_iterator.requires_data + + +class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin): + def __init__( + self, + rank: int, + world_size: int, + num_items: int, + processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]], + save_dir: str, + enable_reuse: bool = False, + ) -> None: + super().__init__() + + self._rank = rank + self._world_size = world_size + self._num_items = num_items + self._processor_fn = processor_fn + self._save_dir = pathlib.Path(save_dir) / PRECOMPUTED_DATA_DIR + self._enable_reuse = enable_reuse + + self._cached_samples = [] + self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None + + if enable_reuse: + if not self._save_dir.exists() or not self._save_dir.is_dir(): + raise RuntimeError( + f"The directory '{self._save_dir}' does not exist or is not a directory, but is required when enabling reuse of precomputed data." + ) + logger.info(f"Reusing precomputed data from {self._save_dir}.") + else: + subdirectories = [] if not self._save_dir.exists() else [f for f in self._save_dir.iterdir() if f.is_dir()] + if len(subdirectories) > 0: + raise RuntimeError( + "The current directory contains subdirectories other than the saved precomputed files. Please remove them or enable precomputation reuse." + ) + # NOTE: this should be safe since we are adding `PRECOMPUTED_DATA_DIR` to the path, but be careful since + # delete_files can seriously mess up your filesystem if used incorrectly. + delete_files([self._save_dir]) + self._save_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Cleaned up any existing precomputed data in {self._save_dir} and created a fresh directory.") + + def consume( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + if not self._enable_reuse: + for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + index = self._rank * self._num_items + i + _save_item(item, index, self._save_dir, data_type) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + + if self._enable_reuse: + data_iterator = PrecomputedOnceDataIterable(self._rank, self._world_size, self._save_dir, data_type) + else: + data_iterator = PrecomputedDataIterable(self._rank, self._world_size, self._save_dir, data_type) + self._preprocessed_iterator = data_iterator + return iter(data_iterator) + + def consume_once( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + if not self._enable_reuse: + for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + index = self._rank * self._num_items + i + _save_item(item, index, self._save_dir, data_type) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + + self._preprocessed_iterator = PrecomputedOnceDataIterable( + self._rank, self._world_size, self._save_dir, data_type + ) + return iter(self._preprocessed_iterator) + + @property + def requires_data(self): + if self._preprocessed_iterator is None: + return True + return self._preprocessed_iterator.requires_data + + +class InMemoryDataIterable: + """ + An iterator that loads data items from an in-memory buffer. Once all the data is consumed, + `requires_data` is set to True, indicating that the more data is required and the preprocessor's + consume method should be called again. + """ + + def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None: + self._rank = rank + self._data_type = data_type + self._buffer = buffer + + self._requires_data = False + + def __iter__(self) -> Iterable[Dict[str, Any]]: + while (length := self._buffer.get_length(self._data_type)) > 0: + if length <= 1: + self._requires_data = True + yield self._buffer.get(self._data_type) + + def __len__(self) -> int: + return self._buffer.get_length(self._data_type) + + @property + def requires_data(self): + return self._requires_data + + +class InMemoryOnceDataIterable: + """ + An iterator that loads data items from an in-memory buffer. This iterator will never set + `requires_data` to True, as it is assumed that all the data was configured to be preprocessed + by the user. The data will indefinitely be cycled from the buffer. + """ + + def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None: + self._rank = rank + self._data_type = data_type + self._buffer = buffer + + self._requires_data = False + + def __iter__(self) -> Iterable[Dict[str, Any]]: + assert len(self) > 0, "No data available in the buffer." + while True: + item = self._buffer.get(self._data_type) + yield item + self._buffer.add(self._data_type, item) + + def __len__(self) -> int: + return self._buffer.get_length(self._data_type) + + @property + def requires_data(self): + return self._requires_data + + +class PrecomputedDataIterable: + """ + An iterator that loads preconfigured number of data items from disk. Once all the data is + loaded, `requires_data` is set to True, indicating that the more data is required and + the preprocessor's consume method should be called again. + """ + + def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None: + self._rank = rank + self._world_size = world_size + self._save_dir = pathlib.Path(save_dir) + self._data_type = data_type + self._requires_data = False + + self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt"))) + + def __iter__(self) -> Iterable[Dict[str, Any]]: + map_location = torch.device(self._rank) + for i in range(self._num_items): + if i == self._num_items - 1: + self._requires_data = True + index = self._rank * self._num_items + i + yield _load_item(index, self._save_dir, self._data_type, map_location) + + def __len__(self) -> int: + return self._num_items + + @property + def requires_data(self): + return self._requires_data + + +class PrecomputedOnceDataIterable: + """ + An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator + will never set `requires_data` to True, as it is assumed that all the data was configured to + be preprocessed by the user. + """ + + def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None: + self._rank = rank + self._world_size = world_size + self._save_dir = pathlib.Path(save_dir) + self._data_type = data_type + self._requires_data = False + + self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt"))) + if self._num_items <= self._rank: + raise ValueError( + f"Precomputed data directory is empty or does not contain enough items (required {self._rank + 1}, found {self._num_items})." + ) + self._num_items_per_rank = max(1, self._num_items // world_size) + + def __iter__(self) -> Iterable[Dict[str, Any]]: + map_location = torch.device(self._rank) + i = 0 + while True: + index = self._rank * self._num_items_per_rank + i + yield _load_item(index, self._save_dir, self._data_type, map_location) + i = (i + 1) % self._num_items_per_rank + + def __len__(self) -> int: + return self._num_items_per_rank + + @property + def requires_data(self): + return self._requires_data + + +class InMemoryDataBuffer: + def __init__(self, max_limit: int = -1) -> None: + self.max_limit = max_limit + self.buffer: Dict[str, List[str]] = {} + + def add(self, data_type: str, item: Dict[str, Any]) -> None: + if data_type not in self.buffer: + self.buffer[data_type] = [] + if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit: + logger.log_freq( + "WARN", + "IN_MEMORY_DATA_BUFFER_FULL", + "Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.", + 64, + ) + self.buffer[data_type].pop(0) + self.buffer[data_type].append(item) + + def get(self, data_type: str) -> Dict[str, Any]: + return self.buffer[data_type].pop(0) + + def get_length(self, data_type: str) -> int: + return len(self.buffer[data_type]) + + +def _save_item(item: Dict[str, Any], index: int, directory: pathlib.Path, data_type: str) -> None: + filename = directory / f"{data_type}-{index}.pt" + torch.save(item, filename.as_posix()) + + +def _load_item(index: int, directory: pathlib.Path, data_type: str, map_location=None) -> Dict[str, Any]: + filename = directory / f"{data_type}-{index}.pt" + return torch.load(filename.as_posix(), map_location=map_location, weights_only=True) diff --git a/docs/finetrainers-src-codebase/finetrainers/data/sampler.py b/docs/finetrainers-src-codebase/finetrainers/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d650e1d610e8ce91b4168a9960479cfcfe8f7 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/data/sampler.py @@ -0,0 +1,58 @@ +from typing import Any, Dict, List, Tuple + +import torch + + +class ResolutionSampler: + def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None: + self.batch_size = batch_size + self.dim_keys = dim_keys + assert dim_keys is not None, "dim_keys must be provided" + + self._chosen_leader_key = None + self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {} + self._satisfied_buckets: List[Dict[Any, Any]] = [] + + def consume(self, *dict_items: Dict[Any, Any]) -> None: + if self._chosen_leader_key is None: + self._determine_leader_item(*dict_items) + self._update_buckets(*dict_items) + + def get_batch(self) -> List[Dict[str, Any]]: + return list(zip(*self._satisfied_buckets.pop(-1))) + + @property + def is_ready(self) -> bool: + return len(self._satisfied_buckets) > 0 + + def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None: + num_observed = 0 + for dict_item in dict_items: + for key in self.dim_keys.keys(): + if key in dict_item.keys(): + self._chosen_leader_key = key + if not torch.is_tensor(dict_item[key]): + raise ValueError(f"Leader key {key} must be a tensor") + num_observed += 1 + if num_observed > 1: + raise ValueError( + f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys" + ) + if self._chosen_leader_key is None: + raise ValueError("No leader key found in provided list of data dictionaries") + + def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None: + chosen_value = [ + dict_item[self._chosen_leader_key] + for dict_item in dict_items + if self._chosen_leader_key in dict_item.keys() + ] + if len(chosen_value) == 0: + raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries") + chosen_value = chosen_value[0] + dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key]) + if dims not in self._unsatisfied_buckets: + self._unsatisfied_buckets[dims] = [] + self._unsatisfied_buckets[dims].append(dict_items) + if len(self._unsatisfied_buckets[dims]) == self.batch_size: + self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims)) diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py b/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca7b9b22ae0578612e1ebc54b550d86b04eba99c --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py @@ -0,0 +1,17 @@ +from .diffusion import flow_match_target, flow_match_xt +from .image import ( + bicubic_resize_image, + center_crop_image, + find_nearest_resolution_image, + resize_crop_image, + resize_to_nearest_bucket_image, +) +from .normalization import normalize +from .text import convert_byte_str_to_str, dropout_caption, dropout_embeddings_to_zero, remove_prefix +from .video import ( + bicubic_resize_video, + center_crop_video, + find_nearest_video_resolution, + resize_crop_video, + resize_to_nearest_bucket_video, +) diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py b/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d553895c2fb251abf80f01f284049acf84f87d --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py @@ -0,0 +1,11 @@ +import torch + + +def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + r"""Forward process of flow matching.""" + return (1.0 - t) * x0 + t * n + + +def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: + r"""Loss target for flow matching.""" + return n - x0 diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/image.py b/docs/finetrainers-src-codebase/finetrainers/functional/image.py new file mode 100644 index 0000000000000000000000000000000000000000..be2b024be001045171ec897064ab51433f875e0e --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/image.py @@ -0,0 +1,56 @@ +from typing import List, Literal, Tuple + +import torch +import torch.nn.functional as F + + +def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_channels, height, width = image.shape + crop_h, crop_w = size + if height < crop_h or width < crop_w: + raise ValueError(f"Image size {(height, width)} is smaller than the target size {size}.") + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + return image[:, top : top + crop_h, left : left + crop_w] + + +def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_channels, height, width = image.shape + target_h, target_w = size + scale = max(target_h / height, target_w / width) + new_h, new_w = int(height * scale), int(width * scale) + image = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False) + return center_crop_image(image, size) + + +def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0] + + +def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]: + num_channels, height, width = image.shape + aspect_ratio = width / height + + def aspect_ratio_diff(bucket): + return abs((bucket[1] / bucket[0]) - aspect_ratio), (-bucket[0], -bucket[1]) + + return min(resolution_buckets, key=aspect_ratio_diff) + + +def resize_to_nearest_bucket_image( + image: torch.Tensor, + resolution_buckets: List[Tuple[int, int]], + resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", +) -> torch.Tensor: + target_size = find_nearest_resolution_image(image, resolution_buckets) + + if resize_mode == "center_crop": + return center_crop_image(image, target_size) + elif resize_mode == "resize_crop": + return resize_crop_image(image, target_size) + elif resize_mode == "bicubic": + return bicubic_resize_image(image, target_size) + else: + raise ValueError( + f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py b/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..b3433b7636dff5c3d89d17fb94e487d936658741 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch + + +def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor: + """ + Normalize a tensor to the range [min_val, max_val]. + + Args: + x (`torch.Tensor`): + The input tensor to normalize. + min (`float`, defaults to `-1.0`): + The minimum value of the normalized range. + max (`float`, defaults to `1.0`): + The maximum value of the normalized range. + dim (`int`, *optional*): + The dimension along which to normalize. If `None`, the entire tensor is normalized. + + Returns: + The normalized tensor of the same shape as `x`. + """ + if dim is None: + x_min = x.min() + x_max = x.max() + if torch.isclose(x_min, x_max).any(): + x = torch.full_like(x, min) + else: + x = min + (max - min) * (x - x_min) / (x_max - x_min) + else: + x_min = x.amin(dim=dim, keepdim=True) + x_max = x.amax(dim=dim, keepdim=True) + if torch.isclose(x_min, x_max).any(): + x = torch.full_like(x, min) + else: + x = min + (max - min) * (x - x_min) / (x_max - x_min) + return x diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/text.py b/docs/finetrainers-src-codebase/finetrainers/functional/text.py new file mode 100644 index 0000000000000000000000000000000000000000..dd319aba5437730be6b5b4d20c0de4de2ae9173c --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/text.py @@ -0,0 +1,40 @@ +import random +from typing import List, Union + +import torch + + +def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str: + """ + Extracts the actual string from a stringified bytes array (common in some webdatasets). + + Example: "b'hello world'" -> "hello world" + """ + try: + s = s[2:-1] + s = s.encode("utf-8").decode(encoding) + except (UnicodeDecodeError, UnicodeEncodeError, IndexError): + pass + return s + + +def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]: + if random.random() >= dropout_p: + return caption + if isinstance(caption, str): + return "" + return [""] * len(caption) + + +def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor: + if random.random() >= dropout_p: + return embed + embed = torch.zeros_like(embed) + return embed + + +def remove_prefix(text: str, prefixes: List[str]) -> str: + for prefix in prefixes: + if text.startswith(prefix): + return text.removeprefix(prefix).strip() + return text diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/video.py b/docs/finetrainers-src-codebase/finetrainers/functional/video.py new file mode 100644 index 0000000000000000000000000000000000000000..2fadf66373554b749bdb1d68e455932f158ad9b9 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/functional/video.py @@ -0,0 +1,96 @@ +from typing import List, Literal, Tuple + +import torch +import torch.nn.functional as F + + +def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + crop_h, crop_w = size + if height < crop_h or width < crop_w: + raise ValueError(f"Video size {(height, width)} is smaller than the target size {size}.") + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + return video[:, :, top : top + crop_h, left : left + crop_w] + + +def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + target_h, target_w = size + scale = max(target_h / height, target_w / width) + new_h, new_w = int(height * scale), int(width * scale) + video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False) + return center_crop_video(video, size) + + +def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + video = F.interpolate(video, size=size, mode="bicubic", align_corners=False) + return video + + +def find_nearest_video_resolution( + video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]] +) -> Tuple[int, int, int]: + num_frames, num_channels, height, width = video.shape + aspect_ratio = width / height + possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames] + + if not possible_buckets: + best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames)) + else: + best_frame_match = max(possible_buckets, key=lambda b: b[0]) + + frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]] + + def aspect_ratio_diff(bucket): + return abs((bucket[2] / bucket[1]) - aspect_ratio), (-bucket[1], -bucket[2]) + + return min(frame_filtered_buckets, key=aspect_ratio_diff) + + +def resize_to_nearest_bucket_video( + video: torch.Tensor, + resolution_buckets: List[Tuple[int, int, int]], + resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", +) -> torch.Tensor: + """ + Resizes a video tensor to the nearest resolution bucket using the specified mode. + - It first finds a frame match with <= T frames. + - Then, it selects the closest height/width bucket. + + Args: + video (`torch.Tensor`): + Input video tensor of shape `(B, T, C, H, W)`. + resolution_buckets (`List[Tuple[int, int, int]]`): + Available (num_frames, height, width) resolution buckets. + resize_mode (`str`): + One of ["center_crop", "resize_crop", "bicubic"]. + + Returns: + `torch.Tensor`: + Resized video tensor of the nearest bucket resolution. + """ + target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets) + + # Adjust frame count: only interpolate frames if no lesser/equal frame count exists + num_frames, num_channels, height, width = video.shape + _first_frame_only = False + if num_frames > target_frames: + # Downsample: Select frames evenly + indices = torch.linspace(0, num_frames - 1, target_frames).long() + video = video[indices, :, :, :] + elif num_frames < target_frames: + _first_frame_only = False + + # Resize spatial resolution + if resize_mode == "center_crop": + return center_crop_video(video, (target_h, target_w)), _first_frame_only + elif resize_mode == "resize_crop": + return resize_crop_video(video, (target_h, target_w)), _first_frame_only + elif resize_mode == "bicubic": + return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only + else: + raise ValueError( + f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/logging.py b/docs/finetrainers-src-codebase/finetrainers/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..66cf41b4cb943486067f74c550e1c53811e4290d --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/logging.py @@ -0,0 +1,139 @@ +import logging +import os +from typing import TYPE_CHECKING, Union + +import diffusers +import transformers + +from .constants import FINETRAINERS_LOG_LEVEL + + +if TYPE_CHECKING: + from .parallel import ParallelBackendType + + +class FinetrainersLoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None: + super().__init__(logger, {}) + self.parallel_backend = parallel_backend + self._log_freq = {} + self._log_freq_counter = {} + + def log( + self, + level, + msg, + *args, + main_process_only: bool = False, + local_main_process_only: bool = True, + in_order: bool = False, + **kwargs, + ): + # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice + kwargs.setdefault("stacklevel", 2) + + if not self.isEnabledFor(level): + return + + if self.parallel_backend is None: + if int(os.environ.get("RANK", 0)) == 0: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + if (main_process_only or local_main_process_only) and in_order: + raise ValueError( + "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True." + ) + + if (main_process_only and self.parallel_backend.is_main_process) or ( + local_main_process_only and self.parallel_backend.is_local_main_process + ): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + if in_order: + for i in range(self.parallel_backend.world_size): + if self.rank == i: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + self.parallel_backend.wait_for_everyone() + return + + if not main_process_only and not local_main_process_only: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + def log_freq( + self, + level: str, + name: str, + msg: str, + frequency: int, + *, + main_process_only: bool = False, + local_main_process_only: bool = True, + in_order: bool = False, + **kwargs, + ) -> None: + if frequency <= 0: + return + if name not in self._log_freq_counter: + self._log_freq[name] = frequency + self._log_freq_counter[name] = 0 + if self._log_freq_counter[name] % self._log_freq[name] == 0: + self.log( + level, + msg, + main_process_only=main_process_only, + local_main_process_only=local_main_process_only, + in_order=in_order, + **kwargs, + ) + self._log_freq_counter[name] += 1 + + +def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]: + global _logger + return _logger + + +def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter: + _logger.parallel_backend = parallel_backend + + +_logger = logging.getLogger("finetrainers") +_logger.setLevel(FINETRAINERS_LOG_LEVEL) +_console_handler = logging.StreamHandler() +_console_handler.setLevel(FINETRAINERS_LOG_LEVEL) +_formatter = logging.Formatter("%(asctime)s - [%(levelname)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +_console_handler.setFormatter(_formatter) +_logger.addHandler(_console_handler) +_logger.propagate = False +_logger = FinetrainersLoggerAdapter(_logger) + + +def set_dependency_log_level(verbose: int = 0, is_local_main_process: bool = False) -> None: + transformers_log_level = transformers.utils.logging.set_verbosity_error + diffusers_log_level = diffusers.utils.logging.set_verbosity_error + + if verbose == 0: + if is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_warning + diffusers_log_level = diffusers.utils.logging.set_verbosity_warning + elif verbose == 1: + if is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_info + diffusers_log_level = diffusers.utils.logging.set_verbosity_info + elif verbose == 2: + if is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_debug + diffusers_log_level = diffusers.utils.logging.set_verbosity_debug + else: + transformers_log_level = transformers.utils.logging.set_verbosity_debug + diffusers_log_level = diffusers.utils.logging.set_verbosity_debug + + transformers_log_level() + diffusers_log_level() diff --git a/docs/finetrainers-src-codebase/finetrainers/models/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec474ff2a5a786085cc5df72f295a590deee08fc --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/__init__.py @@ -0,0 +1,8 @@ +from .attention_dispatch import AttentionProvider, attention_dispatch, attention_provider +from .modeling_utils import ControlModelSpecification, ModelSpecification + + +from ._metadata.transformer import register_transformer_metadata # isort: skip + + +register_transformer_metadata() diff --git a/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5baeb109014d64f9bff2f102c28ee0a3da40f8 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py @@ -0,0 +1 @@ +from .transformer import register_transformer_metadata diff --git a/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..33de3148ca0b025d41be8602dc0c17c9b4eed4aa --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py @@ -0,0 +1,86 @@ +from diffusers import ( + CogVideoXTransformer3DModel, + CogView4Transformer2DModel, + FluxTransformer2DModel, + WanTransformer3DModel, +) + +from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry +from finetrainers.logging import get_logger + + +logger = get_logger() + + +def register_transformer_metadata(): + # CogVideoX + TransformerRegistry.register( + model_class=CogVideoXTransformer3DModel, + metadata=TransformerMetadata( + cp_plan={ + "": { + ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)], + }, + "transformer_blocks.0": { + ParamId("hidden_states", 0): CPInput(1, 3), + ParamId("encoder_hidden_states", 1): CPInput(1, 3), + }, + "proj_out": [CPOutput(1, 3)], + } + ), + ) + + # CogView4 + TransformerRegistry.register( + model_class=CogView4Transformer2DModel, + metadata=TransformerMetadata( + cp_plan={ + "patch_embed": { + ParamId(index=0): CPInput(1, 3, split_output=True), + ParamId(index=1): CPInput(1, 3, split_output=True), + }, + "rope": { + ParamId(index=0): CPInput(0, 2, split_output=True), + ParamId(index=1): CPInput(0, 2, split_output=True), + }, + "proj_out": [CPOutput(1, 3)], + } + ), + ) + + # Flux + TransformerRegistry.register( + model_class=FluxTransformer2DModel, + metadata=TransformerMetadata( + cp_plan={ + "": { + ParamId("hidden_states", 0): CPInput(1, 3), + ParamId("encoder_hidden_states", 1): CPInput(1, 3), + ParamId("img_ids", 4): CPInput(0, 2), + ParamId("txt_ids", 5): CPInput(0, 2), + }, + "proj_out": [CPOutput(1, 3)], + } + ), + ) + + # Wan2.1 + TransformerRegistry.register( + model_class=WanTransformer3DModel, + metadata=TransformerMetadata( + cp_plan={ + "rope": { + ParamId(index=0): CPInput(2, 4, split_output=True), + }, + "blocks.*": { + ParamId("encoder_hidden_states", 1): CPInput(1, 3), + }, + "blocks.0": { + ParamId("hidden_states", 0): CPInput(1, 3), + }, + "proj_out": [CPOutput(1, 3)], + } + ), + ) + + logger.debug("Metadata for transformer registered") diff --git a/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py b/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..e1438674db1f7bf242df1af0fd385a09a081eda6 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py @@ -0,0 +1,1812 @@ +import contextlib +import functools +import inspect +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +# Since we will be patching the `scaled_dot_product_attention` function with `attention_dispatch` to take +# control for dispatching to different attention providers, we need to import the original function +# to be able to use it and not go into infinite recursion when the dispatcher calls `scaled_dot_product_attention`. +import torch.autograd +from diffusers.utils.import_utils import OptionalDependencyNotAvailable +from torch.nn.functional import scaled_dot_product_attention as native_sdpa + +from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER +from finetrainers.logging import get_logger +from finetrainers.utils.import_utils import ( + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_version, + is_xformers_available, + is_xformers_version, +) + + +if is_flash_attn_available(): + if is_flash_attn_version("<", "2.6.3"): + raise OptionalDependencyNotAvailable( + "The `flash-attn` library version is too old. Please update it to at least 2.6.3." + ) + + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward +else: + flash_attn_func = None + flash_attn_varlen_func = None + _flash_attn_forward = None + _flash_attn_backward = None + + +if is_sageattention_available(): + if is_sageattention_version("<", "2.1.1"): + raise OptionalDependencyNotAvailable( + "The `sageattention` library version is too old. Please update it to at least 2.1.1." + ) + + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + import torch.nn.attention.flex_attention as flex_attention + + +if is_torch_version(">=", "2.6.0"): + from torch.distributed.tensor.experimental._attention import ( + _AttentionOp, + _cp_options, + _templated_ring_attention, + _templated_ring_attention_backward, + set_rotate_method, + ) +else: + _cp_options = None + _templated_ring_attention = None + set_rotate_method = None + + class _AttentionOp: + def __init__(self, *args, **kwargs): + raise OptionalDependencyNotAvailable( + "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0." + ) + + +if is_xformers_available(): + if is_xformers_version("<", "0.0.29"): + raise OptionalDependencyNotAvailable( + "The `xformers` library version is too old. Please update it to at least 0.0.29." + ) + + import xformers.ops as xops +else: + xops = None + + +logger = get_logger() + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +# ===== Custom operator implementations/wrappers ===== + + +def _finetrainers_scaled_dot_product_efficient_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = False, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Wrapper for https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 + # See: https://github.com/pytorch/pytorch/issues/152942 + seqlen_q = query.shape[-2] + out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + # LSE is aligned to the next nearest multiple of 32. This is a workaround to return the lse without alignment so that pytorch + # ring attention does not error out with shape mismatch + if compute_log_sumexp: + assert lse.ndim == 3 + lse = lse[:, :, :seqlen_q] # .contiguous() + + return out, lse, philox_seed, philox_offset + + +# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) +def _finetrainers_scaled_dot_product_efficient_attention_backward( + grad_out_: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + dropout_p: float, + grad_input_mask: List[bool], + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(grad_input_mask) == 4 + # https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113 + kAlignLSE = 32 + + logsumexp = torch.nn.functional.pad( + logsumexp, (0, kAlignLSE - (logsumexp.shape[-1] % kAlignLSE)), value=float("inf") + ) + + grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward( + grad_out_=grad_out_, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=dropout_p, + grad_input_mask=grad_input_mask, + is_causal=is_causal, + scale=scale, + ) + + return grad_query, grad_key, grad_value, grad_attn_bias + + +# This function wraps the actual _flash_attn_forward call to return LSE at index 1 to be compatible with pytorch's native ring attention +def _finetrainers_flash_attn_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, +): + query, key, value = ( + x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value) + ) # [B, N, S, D] -> [B, S, N, D] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax + ) + out = out.permute(0, 2, 1, 3).contiguous() # [B, S, N, D] -> [B, N, S, D] + return out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state + + +# This function wraps the actual _flash_attn_backward call as the counterpart of the _finetrainers_flash_attn_forward function +def _finetrainers_flash_attn_backward( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, # Needs a different names than the one used in flash-attn because _templated_ring_attention_backward assumes name is logsumexp + dropout_p: float, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + rng_state: Optional[torch.Tensor] = None, + _permute_outputs: bool = True, +): + dq, dk, dv = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D] + + dq, dk, dv, softmax_d = _flash_attn_backward( + grad_out, + query, + key, + value, + out, + logsumexp, + dq, + dk, + dv, + dropout_p, + scale, + is_causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state, + ) + + # Head dimension may have been padded + dq = dq[..., : grad_out.shape[-1]] + dk = dk[..., : grad_out.shape[-1]] + dv = dv[..., : grad_out.shape[-1]] + + if _permute_outputs: + dq, dk, dv = (x.permute(0, 2, 1, 3).contiguous() for x in (dq, dk, dv)) # [B, S, N, D] -> [B, N, S, D] + return dq, dk, dv + + +# ===== Attention provider ===== + + +class AttentionProvider(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionProviderRegistry: + _providers = {} + _constraints = {} + _supports_cp = {} + _supported_arg_names = {} + + _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER) + _checks_enabled = FINETRAINERS_ATTN_CHECKS + + # Context parallel attributes + _mesh: torch.distributed.device_mesh.DeviceMesh = None + _convert_to_fp32: bool = None + _rotate_method: Literal["allgather", "alltoall"] = None + + @classmethod + def register( + cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False + ): + logger.debug(f"Registering attention provider: {provider}") + + def decorator(func): + cls._providers[provider] = func + cls._constraints[provider] = constraints or [] + cls._supports_cp[provider] = supports_cp + cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_provider(cls): + return cls._active_provider, cls._providers[cls._active_provider] + + @classmethod + def list_providers(cls): + return list(cls._providers.keys()) + + @classmethod + def supports_context_parallel(cls, provider: AttentionProvider): + if provider not in cls._providers: + raise ValueError(f"Provider {provider} is not registered.") + return cls._supports_cp.get(provider, False) + + @classmethod + def context_parallel_enabled(cls): + return cls._mesh is not None + + @classmethod + def _set_context_parallel( + cls, + mesh: torch.distributed.device_mesh.DeviceMesh = None, + convert_to_fp32: bool = None, + rotate_method: str = None, + *, + reset: bool = False, + ): + if reset: + mesh = convert_to_fp32 = rotate_method = None + cls._mesh = mesh + cls._convert_to_fp32 = convert_to_fp32 + cls._rotate_method = rotate_method + + @classmethod + def _raise_cp_error_if_mesh_not_set(cls): + if cls._mesh is None: + raise ValueError( + "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods." + ) + + +@contextlib.contextmanager +def attention_provider( + provider: AttentionProvider = AttentionProvider.NATIVE, + *, + mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + convert_to_fp32: bool = True, + rotate_method: str = "allgather", +): + """Context manager to set the active attention provider and possibly enable context parallelism.""" + + if provider not in _AttentionProviderRegistry._providers: + raise ValueError(f"Provider {provider} is not registered.") + if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider): + raise ValueError(f"Provider {provider} does not support context parallelism.") + + old_provider = _AttentionProviderRegistry._active_provider + _AttentionProviderRegistry._active_provider = provider + + _AttentionProviderRegistry._mesh = mesh + _AttentionProviderRegistry._convert_to_fp32 = convert_to_fp32 + _AttentionProviderRegistry._rotate_method = rotate_method + if mesh is not None: + _convert_to_f32 = _cp_options.convert_to_f32 + _enable_load_balance = _cp_options.enable_load_balance + _rotate_method = _cp_options.rotate_method + + try: + yield + finally: + _AttentionProviderRegistry._active_provider = old_provider + + _AttentionProviderRegistry._mesh = None + _AttentionProviderRegistry._convert_to_fp32 = None + _AttentionProviderRegistry._rotate_method = None + if mesh is not None: + _cp_options.convert_to_f32 = _convert_to_f32 + _cp_options.enable_load_balance = _enable_load_balance + _cp_options.rotate_method = _rotate_method + + +def attention_dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionProviderRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) + if removed_kwargs: + log_freq = 512 + msg = ( + f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This " + f"message will be logged every {log_freq} calls." + ) + logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq) + for check in _AttentionProviderRegistry._constraints.get(provider_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} + + if _AttentionProviderRegistry.context_parallel_enabled(): + _set_context_parallel_options(**kwargs) + + return provider_fn(**kwargs) + + +# ===== Helper functions ===== + + +# @torch.compiler.assume_constant_result +def _set_context_parallel_options(is_causal: bool, **kwargs): + _cp_options.enable_load_balance = is_causal + _cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32 + set_rotate_method(_AttentionProviderRegistry._rotate_method) + + +def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: + if attn_mask is not None: + raise ValueError("Attention mask must be None for this provider.") + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." + ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, +) -> None: + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + if attn_mask is None: + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) + else: + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. + """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== Attention provider implementations ===== + + +# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 +class _flash_attn_flash_attention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_softmax: bool = False, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward( + query=q, + key=k, + value=v, + dropout_p=dropout_p, + scale=softmax_scale, + is_causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + + ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) + + return (out, lse) if return_softmax else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + q, k, v, out, lse, rng_state = ctx.saved_tensors + + grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward( + grad_out=grad_out, + query=q, + key=k, + value=v, + out=out, + logsumexp=lse, + dropout_p=ctx.dropout_p, + scale=ctx.softmax_scale, + is_causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + rng_state=rng_state, + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 +class _native_ring_flash_attn_flash_attention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_softmax: bool = False, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + # For ring flash attention using the flash-attn repo, we want the LSE but flash-attn only supports it if dropout_p > 0 + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=_finetrainers_flash_attn_forward, + query=q, + key=k, + value=v, + dropout_p=dropout_p, + scale=softmax_scale, + is_causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=True, + ) + + ctx.save_for_backward(q, k, v, out_padded, lse, rng_state) + + return (out, lse) if return_softmax else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + q, k, v, out, lse, rng_state = ctx.saved_tensors + lse = lse.permute(0, 2, 1).contiguous() # [B, N, S] -> [B, S, N] + + grad_query, grad_key, grad_value = _templated_ring_attention_backward( + mesh=_AttentionProviderRegistry._mesh, + # This needs to be 1 because q, k, v, out_padded returned from forward are BSND instead of BNSD + # The grad_out permutation is handled in _finetrainers_flash_attn_backward, and the outputs from that are expected to have + # shape BSND instead of BNSD (requirement of _templated_ring_attention_backward), so we need to set seq_dim=1 and permute the + # returned outputs + seq_dim=1, + op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False), + grad_out=grad_out, + grad_out_name="grad_out", + query=q, + key=k, + value=v, + out=out, + logsumexp=lse, + dropout_p=ctx.dropout_p, + scale=ctx.softmax_scale, + is_causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + rng_state=rng_state, + ) + grad_query, grad_key, grad_value = ( + x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value) + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_cp=True, +) +def flash_attn_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + dispatch_fn = ( + _native_ring_flash_attn_flash_attention + if _AttentionProviderRegistry.context_parallel_enabled() + else _flash_attn_flash_attention + ) + return dispatch_fn.apply( + query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_cp=False, +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + if _AttentionProviderRegistry.context_parallel_enabled(): + return_attn_probs = True + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + + rest = None + if return_attn_probs: + out, *rest = out + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() + if return_attn_probs: + return out, *rest[:1] + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], + supports_cp=False, +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, num_heads, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = flex_attention.create_block_mask( + _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + return flex_attention.flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=None, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.NATIVE, + constraints=[_check_device, _check_shape], + supports_cp=False, +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + return native_sdpa( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +class _native_cudnn_attention(torch.autograd.Function): + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 + # forward declaration: + # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + # backward declaration: + # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + torch.ops.aten._scaled_dot_product_cudnn_attention( + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( + grad_out=grad_out, + query=query, + key=key, + value=value, + out=out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=ctx.attn_mask, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None + + +class _native_ring_native_cudnn_attention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + _templated_ring_attention( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=torch.ops.aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_query, grad_key, grad_value = _templated_ring_attention_backward( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=ctx.attn_mask, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_cp=True, +) +def native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + dispatch_fn = ( + _native_ring_native_cudnn_attention + if _AttentionProviderRegistry.context_parallel_enabled() + else _native_cudnn_attention + ) + return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse) + + +class _native_efficient_attention(torch.autograd.Function): + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946 + # forward declaration: + # aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + # backward declaration: + # aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + + # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. + out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward( + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors + + # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. + grad_query, grad_key, grad_value, grad_attn_bias = ( + _finetrainers_scaled_dot_product_efficient_attention_backward( + grad_out_=grad_out, + query=query, + key=key, + value=value, + attn_bias=ctx.attn_mask, + out=out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=ctx.dropout_p, + grad_input_mask=[True, True, True, False], + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None + + +class _native_ring_native_efficient_attention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.attn_mask = attn_mask + + # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. + out, lse, philox_seed, philox_offset = _templated_ring_attention( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=_finetrainers_scaled_dot_product_efficient_attention_forward, + query=query, + key=key, + value=value, + attn_bias=attn_mask, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors + + # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details. + grad_query, grad_key, grad_value, grad_attn_bias = _templated_ring_attention_backward( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=_finetrainers_scaled_dot_product_efficient_attention_backward, + grad_out=grad_out, + grad_out_name="grad_out_", + query=query, + key=key, + value=value, + attn_bias=ctx.attn_mask, + out=out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=ctx.dropout_p, + grad_input_mask=[True, True, True, False], + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], + supports_cp=True, +) +def native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> torch.Tensor: + dispatch_fn = ( + _native_ring_native_efficient_attention + if _AttentionProviderRegistry.context_parallel_enabled() + else _native_efficient_attention + ) + return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale) + + +class _native_flash_attention(torch.autograd.Function): + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910 + # forward declaration: + # aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + # backward declaration: + # aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + torch.ops.aten._scaled_dot_product_flash_attention( + query=query, + key=key, + value=value, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward( + grad_out=grad_out, + query=query, + key=key, + value=value, + out=out, + logsumexp=lse, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + philox_seed=philox_seed, + philox_offset=philox_offset, + scale=ctx.scale, + ) + + return grad_query, grad_key, grad_value, None, None, None, None + + +class _native_ring_native_flash_attention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + _templated_ring_attention( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=torch.ops.aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + ) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args: torch.Tensor, + ): + _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set() + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward( + mesh=_AttentionProviderRegistry._mesh, + seq_dim=2, + op=torch.ops.aten._scaled_dot_product_flash_attention_backward, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=lse, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + philox_seed=philox_seed, + philox_offset=philox_offset, + ) + + return grad_query, grad_key, grad_value, None, None, None, None + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_cp=True, +) +def native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + dispatch_fn = ( + _native_ring_native_flash_attention + if _AttentionProviderRegistry.context_parallel_enabled() + else _native_flash_attention + ) + return dispatch_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse) + + +# class _native_math_attention(torch.autograd.Function): +# # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901 +# # forward declaration: +# # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) +# # backward declaration: +# # does not exist +# @staticmethod +# def forward( +# ctx: torch.autograd.function.FunctionCtx, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# attn_mask: Optional[torch.Tensor] = None, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# dropout_mask: Optional[torch.Tensor] = None, +# scale: Optional[float] = None, +# enable_gqa: bool = False, +# return_scores: bool = False, +# ): +# ctx.dropout_p = dropout_p +# ctx.is_causal = is_causal +# ctx.scale = scale +# ctx.enable_gqa = enable_gqa + +# print(f"query.shape: {query.shape}") +# with torch.enable_grad(): +# out, scores = torch.ops.aten._scaled_dot_product_attention_math( +# query=query, +# key=key, +# value=value, +# attn_mask=attn_mask, +# dropout_p=dropout_p, +# is_causal=is_causal, +# dropout_mask=dropout_mask, +# scale=scale, +# enable_gqa=enable_gqa, +# ) + +# ctx.save_for_backward(query, key, value, out) + +# return (out, scores) if return_scores else out + +# @staticmethod +# def backward( +# ctx: torch.autograd.function.FunctionCtx, +# grad_out: torch.Tensor, +# ): +# raise NotImplementedError("Backward pass for native math attention is not implemented.") + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_MATH, + constraints=[_check_device, _check_shape], + supports_cp=False, +) +def native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + return native_sdpa( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_cp=False, +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + if _AttentionProviderRegistry.context_parallel_enabled(): + return_lse = True + + kwargs = { + "q": query, + "k": key, + "v": value, + "tensor_layout": "HND", + "is_causal": is_causal, + "sm_scale": scale, + "return_lse": return_lse, + } + out = sageattn(**kwargs) + + rest = None + if return_lse: + out, *rest = out + if return_lse: + return out, *rest[:1] + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if enable_gqa: + # TODO + pass + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous() + + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], + supports_cp=False, +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], + supports_cp=False, +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], + supports_cp=False, +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], + supports_cp=False, +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="HND", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, num_heads_q, seq_len_q, _ = query.shape + _, num_heads_kv, seq_len_kv, _ = key.shape + + # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers + # query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + if enable_gqa: + out = out.flatten(2, 3) + + out = out.permute(0, 2, 1, 3) # .contiguous() + return out diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f9a84073541b0e764877bac0335637f03d32ca --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py @@ -0,0 +1 @@ +from .base_specification import CogVideoXModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0e6210c47f448cfda570f3f07d324f7f980a71 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py @@ -0,0 +1,410 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from PIL.Image import Image +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer + +from finetrainers.data import VideoArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.models.utils import DiagonalGaussianDistribution +from finetrainers.processors import ProcessorMixin, T5Processor +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + +from .utils import prepare_rotary_positional_embeddings + + +logger = get_logger() + + +class CogVideoXLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the CogVideoX VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLCogVideoX, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W] + return {self.output_names[0]: latents} + + +class CogVideoXModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])] + if latent_model_processors is None: + latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (1, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = T5Tokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKLCogVideoX.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKLCogVideoX.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = CogVideoXTransformer3DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = CogVideoXTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = CogVideoXDDIMScheduler.from_pretrained( + self.pretrained_model_name_or_path, subfolder="scheduler", **common_kwargs + ) + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[CogVideoXTransformer3DModel] = None, + vae: Optional[AutoencoderKLCogVideoX] = None, + scheduler: Optional[CogVideoXDDIMScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> CogVideoXPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = CogVideoXPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + caption: str, + max_sequence_length: int = 226, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + conditions.pop("prompt_attention_mask", None) + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLCogVideoX, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself. + VAE_SPATIAL_SCALE_FACTOR = 8 + rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR + rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR + patch_size = self.transformer_config.patch_size + patch_size_t = getattr(self.transformer_config, "patch_size_t", None) + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2) + latents = posterior.sample(generator=generator) + del posterior + + if not getattr(self.vae_config, "invert_scale_latents", False): + latents = latents * self.vae_config.scaling_factor + + if patch_size_t is not None: + latents = self._pad_frames(latents, patch_size_t) + + timesteps = (sigmas.flatten() * 1000.0).long() + + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + + batch_size, num_frames, num_channels, height, width = latents.shape + ofs_emb = ( + None + if getattr(self.transformer_config, "ofs_embed_dim", None) is None + else latents.new_full((batch_size,), fill_value=2.0) + ) + + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SPATIAL_SCALE_FACTOR, + width=width * VAE_SPATIAL_SCALE_FACTOR, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR, + patch_size=patch_size, + patch_size_t=patch_size_t, + attention_head_dim=self.transformer_config.attention_head_dim, + device=transformer.device, + base_height=rope_base_height, + base_width=rope_base_width, + ) + if self.transformer_config.use_rotary_positional_embeddings + else None + ) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + latent_model_conditions["image_rotary_emb"] = image_rotary_emb + latent_model_conditions["ofs"] = ofs_emb + + velocity = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same + # code paths as scheduler.get_velocity(), which can be confusing to understand. + pred = scheduler.get_velocity(velocity, noisy_latents, timesteps) + target = latents + + return pred, target, sigmas + + def validation( + self, + pipeline: CogVideoXPipeline, + prompt: str, + image: Optional[Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + # TODO(aryan): add support for more parameters + if image is not None: + pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline) + + generation_kwargs = { + "prompt": prompt, + "image": image, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + CogVideoXPipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: CogVideoXTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + @staticmethod + def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor: + num_frames = latents.size(1) + additional_frames = patch_size_t - (num_frames % patch_size_t) + if additional_frames > 0: + last_frame = latents[:, -1:] + padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1) + latents = torch.cat([latents, padding_frames], dim=1) + return latents diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd98c1f3653dbe23a6f53fa54dfe3e7073ea9b99 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py @@ -0,0 +1,51 @@ +from typing import Optional, Tuple + +import torch +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + patch_size_t: int = None, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + if patch_size_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6afde3a5f16247cc6f47fc16561186e31a22ad --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py @@ -0,0 +1,2 @@ +from .base_specification import CogView4ModelSpecification +from .control_specification import CogView4ControlModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..f89eb21d878f475f361d6def12591c5037b248cc --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py @@ -0,0 +1,385 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKL, + CogView4Pipeline, + CogView4Transformer2DModel, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoTokenizer, GlmModel + +import finetrainers.functional as FF +from finetrainers.data import ImageArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + + +logger = get_logger() + + +class CogView4LatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the LTX VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - original_size: The original size of the input image/video. + - target_size: The target size of the input image/video. + - crop_coords: The top-left crop coordinates of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + + self.output_names = output_names + assert len(self.output_names) == 4 + + def forward( + self, + vae: AutoencoderKL, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + _original_height: Optional[int] = None, + _original_width: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if video is not None: + # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly + image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W] + + assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor" + image = image.to(device=device, dtype=vae.dtype) + + if compute_posterior: + latents = vae.encode(image).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and image.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(image) + latents = moments.to(dtype=dtype) + + batch_size = latents.size(0) + target_height = image.size(2) + target_width = image.size(3) + original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat( + batch_size, 1 + ) + target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1) + crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1) + + return { + self.output_names[0]: latents, + self.output_names[1]: original_size, + self.output_names[2]: target_size, + self.output_names[3]: crop_coords, + } + + +class CogView4ModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "THUDM/CogView4-6B", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])] + if latent_model_processors is None: + latent_model_processors = [ + CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"]) + ] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = GlmModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = GlmModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKL.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = CogView4Transformer2DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = CogView4Transformer2DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_encoder: Optional[GlmModel] = None, + transformer: Optional[CogView4Transformer2DModel] = None, + vae: Optional[AutoencoderKL] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> CogView4Pipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + # Load the scheduler based on CogView4's config instead of using the default initialization being used for training + # "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = CogView4Pipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + caption: str, + max_sequence_length: int = 1024, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKL, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + _original_height: Optional[int] = None, + _original_width: Optional[int] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + "_original_height": _original_height, + "_original_width": _original_width, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: CogView4Transformer2DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + base_image_sequence_length = 256 + base_shift = 0.25 + max_shift = 0.75 + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + if getattr(self.vae_config, "shift_factor", None) is not None: + latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor + else: + latents = latents * self.vae_config.scaling_factor + + noise = torch.zeros_like(latents).normal_(generator=generator) + timesteps = (sigmas.flatten() * 1000.0).long() + + image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2 + mu = (image_sequence_length / base_image_sequence_length) ** 0.5 + mu = mu * max_shift + base_shift + shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0) + noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation + # but let's keep it this way for now. Longer training runs should reveal more insights. + # return pred, target, sigmas + return pred, target, shifted_sigmas + + def validation( + self, + pipeline: CogView4Pipeline, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + image = pipeline(**generation_kwargs).images[0] + return [ImageArtifact(value=image)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + CogView4Pipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: CogView4Transformer2DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = CogView4Transformer2DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..16f359fa4a59f6b218aebfa217c5a22bfd6afdb2 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py @@ -0,0 +1,375 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from transformers import AutoTokenizer, GlmModel + +import finetrainers.functional as FF +from finetrainers.data import ImageArtifact +from finetrainers.models.modeling_utils import ControlModelSpecification +from finetrainers.models.utils import DiagonalGaussianDistribution, _expand_linear_with_zeroed_weights +from finetrainers.patches.dependencies.diffusers.control import control_channel_concat +from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + +from .base_specification import CogView4LatentEncodeProcessor + + +class CogView4ControlModelSpecification(ControlModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "THUDM/CogView4-6B", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + control_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])] + if latent_model_processors is None: + latent_model_processors = [ + CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"]) + ] + if control_model_processors is None: + control_model_processors = [ + CogView4LatentEncodeProcessor(["control_latents", "original_size", "target_size", "crop_coords"]) + ] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + self.control_model_processors = control_model_processors + + @property + def control_injection_layer_name(self): + return "patch_embed.proj" + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = GlmModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = GlmModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKL.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = CogView4Transformer2DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = CogView4Transformer2DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + actual_new_in_features = new_in_features * transformer.config.patch_size**2 + transformer.patch_embed.proj = _expand_linear_with_zeroed_weights( + transformer.patch_embed.proj, new_in_features=actual_new_in_features + ) + transformer.register_to_config(in_channels=new_in_features) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_encoder: Optional[GlmModel] = None, + transformer: Optional[CogView4Transformer2DModel] = None, + vae: Optional[AutoencoderKL] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> CogView4Pipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + # Load the scheduler based on CogView4's config instead of using the default initialization being used for training + # "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = CogView4Pipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + caption: str, + max_sequence_length: int = 1024, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKL, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + control_image: Optional[torch.Tensor] = None, + control_video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + _original_height: Optional[int] = None, + _original_width: Optional[int] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + common_kwargs = { + "vae": vae, + "generator": generator, + "compute_posterior": compute_posterior, + "_original_height": _original_height, + "_original_width": _original_width, + **kwargs, + } + conditions = {"image": image, "video": video, **common_kwargs} + input_keys = set(conditions.keys()) + conditions = ControlModelSpecification.prepare_latents(self, self.latent_model_processors, **conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + + control_conditions = {"image": control_image, "video": control_video, **common_kwargs} + input_keys = set(control_conditions.keys()) + control_conditions = ControlModelSpecification.prepare_latents( + self, self.control_model_processors, **control_conditions + ) + control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys} + + return {**control_conditions, **conditions} + + def forward( + self, + transformer: CogView4Transformer2DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + base_image_sequence_length = 256 + base_shift = 0.25 + max_shift = 0.75 + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + control_latents = latent_model_conditions.pop("control_latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + control_posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("control_latents")) + control_latents = control_posterior.sample(generator=generator) + del control_posterior + + if getattr(self.vae_config, "shift_factor") is not None: + latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor + control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor + + noise = torch.zeros_like(latents).normal_(generator=generator) + timesteps = (sigmas.flatten() * 1000.0).long() + + image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2 + mu = (image_sequence_length / base_image_sequence_length) ** 0.5 + mu = mu * max_shift + base_shift + shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0) + noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas) + noisy_latents = torch.cat([noisy_latents, control_latents], dim=1) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation + # but let's keep it this way for now. Longer training runs should reveal more insights. + # return pred, target, sigmas + return pred, target, shifted_sigmas + + def validation( + self, + pipeline: CogView4Pipeline, + prompt: str, + control_image: torch.Tensor, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + with torch.no_grad(): + dtype = pipeline.vae.dtype + device = pipeline._execution_device + in_channels = self.transformer_config.in_channels # We need to use the original in_channels + latents = pipeline.prepare_latents(1, in_channels, height, width, dtype, device, generator) + control_image = pipeline.image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(device=device, dtype=dtype) + control_latents = pipeline.vae.encode(control_image).latent_dist.sample(generator=generator) + if getattr(self.vae_config, "shift_factor") is not None: + control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor + + generation_kwargs = { + "latents": latents, + "prompt": prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + + with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]): + image = pipeline(**generation_kwargs).images[0] + + return [ImageArtifact(value=image)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + norm_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + CogView4Pipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if norm_state_dict is not None: + safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: CogView4Transformer2DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = CogView4Transformer2DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + @property + def _original_control_layer_in_features(self): + return self.transformer_config.in_channels + + @property + def _original_control_layer_out_features(self): + return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim + + @property + def _qk_norm_identifiers(self): + return ["attn1.norm_q", "attn1.norm_k"] diff --git a/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1d172114ae4f79a2f89e5196ecbdc8be279e3a --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py @@ -0,0 +1 @@ +from .base_specification import FluxModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3ea1e167cfd72e5ee7c158a9a715ee6b7ad09f --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py @@ -0,0 +1,411 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +import finetrainers.functional as FF +from finetrainers.data import ImageArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.processors import CLIPPooledProcessor, ProcessorMixin, T5Processor +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + + +logger = get_logger() + + +class FluxLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the Flux VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKL, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if video is not None: + # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly + image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W] + + assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor" + image = image.to(device=device, dtype=vae.dtype) + + if compute_posterior: + latents = vae.encode(image).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and image.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(image) + latents = moments.to(dtype=dtype) + + return {self.output_names[0]: latents} + + +class FluxModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "black-forest-labs/FLUX.1-dev", + tokenizer_id: Optional[str] = None, + tokenizer_2_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + text_encoder_2_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + tokenizer_2_id=tokenizer_2_id, + text_encoder_id=text_encoder_id, + text_encoder_2_id=text_encoder_2_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [ + CLIPPooledProcessor(["pooled_projections"]), + T5Processor( + ["encoder_hidden_states", "prompt_attention_mask"], + input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"}, + ), + ] + if latent_model_processors is None: + latent_model_processors = [FluxLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = CLIPTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.tokenizer_2_id is not None: + tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs) + else: + tokenizer_2 = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = CLIPTextModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = CLIPTextModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + if self.text_encoder_2_id is not None: + text_encoder_2 = T5EncoderModel.from_pretrained( + self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs + ) + else: + text_encoder_2 = T5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=self.text_encoder_2_dtype, + **common_kwargs, + ) + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKL.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = FluxTransformer2DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = FluxTransformer2DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder: Optional[CLIPTextModel] = None, + text_encoder_2: Optional[T5EncoderModel] = None, + transformer: Optional[FluxTransformer2DModel] = None, + vae: Optional[AutoencoderKL] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> FluxPipeline: + components = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "transformer": transformer, + "vae": vae, + # Load the scheduler based on Flux's config instead of using the default initialization being used for training + # "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = FluxPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.text_encoder_2.to(self.text_encoder_2_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + tokenizer_2: CLIPTokenizer, + text_encoder: CLIPTextModel, + text_encoder_2: T5EncoderModel, + caption: str, + max_sequence_length: int = 512, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKL, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: FluxTransformer2DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + if getattr(self.vae_config, "shift_factor", None) is not None: + latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor + else: + latents = latents * self.vae_config.scaling_factor + + noise = torch.zeros_like(latents).normal_(generator=generator) + timesteps = (sigmas.flatten() * 1000.0).long() + img_ids = FluxPipeline._prepare_latent_image_ids( + latents.size(0), latents.size(2) // 2, latents.size(3) // 2, latents.device, latents.dtype + ) + text_ids = latents.new_zeros(condition_model_conditions["encoder_hidden_states"].shape[1], 3) + + if self.transformer_config.guidance_embeds: + guidance_scale = 1.0 + guidance = latents.new_full((1,), guidance_scale).expand(latents.shape[0]) + else: + guidance = None + + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + noisy_latents = FluxPipeline._pack_latents(noisy_latents, *latents.shape) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + condition_model_conditions.pop("prompt_attention_mask", None) + + spatial_compression_ratio = 2 ** len(self.vae_config.block_out_channels) + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps / 1000.0, + guidance=guidance, + img_ids=img_ids, + txt_ids=text_ids, + return_dict=False, + )[0] + pred = FluxPipeline._unpack_latents( + pred, + latents.size(2) * spatial_compression_ratio, + latents.size(3) * spatial_compression_ratio, + spatial_compression_ratio, + ) + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: FluxPipeline, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + image = pipeline(**generation_kwargs).images[0] + return [ImageArtifact(value=image)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + FluxPipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: FluxTransformer2DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = FluxTransformer2DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..518a42865f0cee30a534da458ec63b08c1a8d7e4 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py @@ -0,0 +1 @@ +from .base_specification import HunyuanVideoModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..80d02c931dc54fe8e5578a383b239cb0091d26f2 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py @@ -0,0 +1,391 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel + +import finetrainers.functional as FF +from finetrainers.data import VideoArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + + +logger = get_logger() + + +class HunyuanLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the HunyuanVideo VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + return {self.output_names[0]: latents} + + +class HunyuanVideoModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo", + tokenizer_id: Optional[str] = None, + tokenizer_2_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + text_encoder_2_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + tokenizer_2_id=tokenizer_2_id, + text_encoder_id=text_encoder_id, + text_encoder_2_id=text_encoder_2_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [ + LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]), + CLIPPooledProcessor( + ["pooled_projections"], + input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"}, + ), + ] + if latent_model_processors is None: + latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.tokenizer_2_id is not None: + tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs) + else: + tokenizer_2 = CLIPTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = LlamaModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = LlamaModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + if self.text_encoder_2_id is not None: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs + ) + else: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=self.text_encoder_2_dtype, + **common_kwargs, + ) + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder: Optional[LlamaModel] = None, + text_encoder_2: Optional[CLIPTextModel] = None, + transformer: Optional[HunyuanVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLHunyuanVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> HunyuanVideoPipeline: + components = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = HunyuanVideoPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.text_encoder_2.to(self.text_encoder_2_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + tokenizer_2: CLIPTokenizer, + text_encoder: LlamaModel, + text_encoder_2: CLIPTextModel, + caption: str, + max_sequence_length: int = 256, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: HunyuanVideoTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + guidance: float = 1.0, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + latents = latents * self.vae_config.scaling_factor + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + timesteps = (sigmas.flatten() * 1000.0).long() + guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0 + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + latent_model_conditions["guidance"] = guidance + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: HunyuanVideoPipeline, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + HunyuanVideoPipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: HunyuanVideoTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4e3550d54bb33fac80dd2d075ad2846eeeed46 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py @@ -0,0 +1 @@ +from .base_specification import LTXVideoModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..c8eaa5e420a604e2f13d9f3adee07e7ef1d02dee --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py @@ -0,0 +1,504 @@ +import functools +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXImageToVideoPipeline, + LTXPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from PIL.Image import Image +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer + +import finetrainers.functional as FF +from finetrainers.data import VideoArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.parallel import ParallelBackendEnum +from finetrainers.processors import ProcessorMixin, T5Processor +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function + + +logger = get_logger() + + +class LTXLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the LTX VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - num_frames: The number of frames in the input video. + - height: The height of the input image/video. + - width: The width of the input image/video. + - latents_mean: The latent channel means from the VAE state dict. + - latents_std: The latent channel standard deviations from the VAE state dict. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 6 + + def forward( + self, + vae: AutoencoderKLLTXVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + _, _, num_frames, height, width = latents.shape + + return { + self.output_names[0]: latents, + self.output_names[1]: num_frames, + self.output_names[2]: height, + self.output_names[3]: width, + self.output_names[4]: vae.latents_mean, + self.output_names[5]: vae.latents_std, + } + + +class LTXVideoModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "Lightricks/LTX-Video", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])] + if latent_model_processors is None: + latent_model_processors = [ + LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"]) + ] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = T5Tokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKLLTXVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKLLTXVideo.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = LTXVideoTransformer3DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = LTXVideoTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[LTXVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLLTXVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> LTXPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = LTXPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling) + if not training: + pipe.transformer.to(self.transformer_dtype) + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + caption: str, + max_sequence_length: int = 128, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLLTXVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: LTXVideoTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + # TODO(aryan): make this configurable? Should it be? + first_frame_conditioning_p = 0.1 + min_first_frame_sigma = 0.25 + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + latents_mean = latent_model_conditions.pop("latents_mean") + latents_std = latent_model_conditions.pop("latents_std") + + latents = self._normalize_latents(latents, latents_mean, latents_std) + noise = torch.zeros_like(latents).normal_(generator=generator) + + if random.random() < first_frame_conditioning_p: + # Based on Section 2.4 of the paper, it mentions that the first frame timesteps should be a small random value. + # Making as estimated guess, we limit the sigmas to be at least 0.2. + # torch.rand_like returns values in [0, 1). We want to make sure that the first frame sigma is <= actual sigmas + # for image conditioning. In order to do this, we rescale by multiplying with sigmas so the range is [0, sigmas). + first_frame_sigma = torch.rand_like(sigmas) * sigmas + first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma)) + + latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:] + noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma) + noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas) + noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2) + else: + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + patch_size = self.transformer_config.patch_size + patch_size_t = self.transformer_config.patch_size_t + + latents = self._pack_latents(latents, patch_size, patch_size_t) + noise = self._pack_latents(noise, patch_size, patch_size_t) + noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t) + sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1) + timesteps = (sigmas * 1000.0).long() + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + + # TODO(aryan): make this configurable + frame_rate = 25 + temporal_compression_ratio = 8 + vae_spatial_compression_ratio = 32 + latent_frame_rate = frame_rate / temporal_compression_ratio + + rope_interpolation_scale = [ + 1 / latent_frame_rate, + vae_spatial_compression_ratio, + vae_spatial_compression_ratio, + ] + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: LTXPipeline, + prompt: str, + image: Optional[Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + frame_rate: int = 25, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + if image is not None: + pipeline = LTXImageToVideoPipeline.from_pipe(pipeline) + + generation_kwargs = { + "prompt": prompt, + "image": image, + "height": height, + "width": width, + "num_frames": num_frames, + "frame_rate": frame_rate, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + LTXPipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: LTXVideoTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def apply_tensor_parallel( + self, + backend: ParallelBackendEnum, + device_mesh: torch.distributed.DeviceMesh, + transformer: LTXVideoTransformer3DModel, + **kwargs, + ) -> None: + if backend == ParallelBackendEnum.PTD: + _apply_tensor_parallel_ptd(device_mesh, transformer) + else: + raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification") + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + batch_size = latents.shape[0] + latents_mean = latents_mean.view(batch_size, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(batch_size, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents) + return latents + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +def _apply_tensor_parallel_ptd( + device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel +) -> None: + from torch.distributed.tensor.parallel import parallelize_module + from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel + + transformer_plan = { + # ===== Condition embeddings ===== + # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(), + # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)), + # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()), + # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())), + # "caption_projection.linear_1": ColwiseParallel(), + # "caption_projection.linear_2": RowwiseParallel(), + # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), + # ===== ===== + } + + for block in transformer.transformer_blocks: + block_plan = {} + + # ===== Attention ===== + # 8 all-to-all, 3 all-reduce + # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.norm_q"] = SequenceParallel() + # block_plan["attn1.norm_k"] = SequenceParallel() + # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.norm_q"] = SequenceParallel() + # block_plan["attn2.norm_k"] = SequenceParallel() + # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # ===== ===== + + block_plan["ff.net.0.proj"] = ColwiseParallel() + block_plan["ff.net.2"] = RowwiseParallel() + + parallelize_module(block, device_mesh, block_plan) + + parallelize_module(transformer, device_mesh, transformer_plan) diff --git a/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py b/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b96599801feb3abe16acbada78dd0b4dfb9b9c7 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py @@ -0,0 +1,388 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from PIL.Image import Image + +from finetrainers.logging import get_logger +from finetrainers.parallel import ParallelBackendEnum +from finetrainers.processors import ProcessorMixin +from finetrainers.typing import ArtifactType, SchedulerType, TokenizerType +from finetrainers.utils import resolve_component_cls + + +if TYPE_CHECKING: + from finetrainers.trainer.control_trainer.config import FrameConditioningType + +logger = get_logger() + +# TODO(aryan): we most likely don't need this. take a look after refactoring more +# fmt: off +IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"} +# fmt: on + + +class ModelSpecification: + r""" + The ModelSpecification class is an interface to be used for Diffusion training recipes. It provides + loose structure about how to organize the code for training. The trainer implementations will + make use of this interface to load models, prepare conditions, prepare latents, forward pass, etc. + """ + + def __init__( + self, + pretrained_model_name_or_path: Optional[str] = None, + tokenizer_id: Optional[str] = None, + tokenizer_2_id: Optional[str] = None, + tokenizer_3_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + text_encoder_2_id: Optional[str] = None, + text_encoder_3_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + text_encoder_2_dtype: torch.dtype = torch.bfloat16, + text_encoder_3_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: str = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + ) -> None: + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.tokenizer_id = tokenizer_id + self.tokenizer_2_id = tokenizer_2_id + self.tokenizer_3_id = tokenizer_3_id + self.text_encoder_id = text_encoder_id + self.text_encoder_2_id = text_encoder_2_id + self.text_encoder_3_id = text_encoder_3_id + self.transformer_id = transformer_id + self.vae_id = vae_id + self.text_encoder_dtype = text_encoder_dtype + self.text_encoder_2_dtype = text_encoder_2_dtype + self.text_encoder_3_dtype = text_encoder_3_dtype + self.transformer_dtype = transformer_dtype + self.vae_dtype = vae_dtype + self.revision = revision + self.cache_dir = cache_dir + self.condition_model_processors = condition_model_processors or [] + self.latent_model_processors = latent_model_processors or [] + + self.transformer_config: Dict[str, Any] = None + self.vae_config: Dict[str, Any] = None + + self._load_configs() + + def _trainer_init(self, *args, **kwargs): + pass + + # TODO(aryan): revisit how to do this better without user having to worry about it + @property + def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]: + raise NotImplementedError( + f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}" + ) + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + raise NotImplementedError( + f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}" + ) + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + raise NotImplementedError( + f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}" + ) + + def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]: + raise NotImplementedError( + f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}" + ) + + def load_pipeline( + self, + tokenizer: Optional[TokenizerType] = None, + tokenizer_2: Optional[TokenizerType] = None, + tokenizer_3: Optional[TokenizerType] = None, + text_encoder: Optional[torch.nn.Module] = None, + text_encoder_2: Optional[torch.nn.Module] = None, + text_encoder_3: Optional[torch.nn.Module] = None, + transformer: Optional[torch.nn.Module] = None, + vae: Optional[torch.nn.Module] = None, + scheduler: Optional[SchedulerType] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> DiffusionPipeline: + raise NotImplementedError( + f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}" + ) + + def prepare_conditions(self, processors: Optional[ProcessorMixin] = None, **kwargs) -> Dict[str, Any]: + if processors is None: + processors = self.condition_model_processors + for processor in processors: + result = processor(**kwargs) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(kwargs.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + kwargs.update(result) + return kwargs + + def prepare_latents(self, processors: Optional[ProcessorMixin] = None, **kwargs) -> Dict[str, Any]: + if processors is None: + processors = self.latent_model_processors + for processor in processors: + result = processor(**kwargs) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(kwargs.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + kwargs.update(result) + return kwargs + + def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + keys = list(data[0].keys()) + collated_data = {} + for key in keys: + if key in IGNORE_KEYS_FOR_COLLATION: + collated_data[key] = data[0][key] + continue + collated_d = [d[key] for d in data] + if isinstance(collated_d[0], torch.Tensor): + collated_d = torch.cat(collated_d) + collated_data[key] = collated_d + return collated_data + + def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + keys = list(data[0].keys()) + collated_data = {} + for key in keys: + if key in IGNORE_KEYS_FOR_COLLATION: + collated_data[key] = data[0][key] + continue + collated_d = [d[key] for d in data] + # TODO(aryan): Support multi-resolution collation + if isinstance(collated_d[0], torch.Tensor): + collated_d = torch.cat(collated_d) + collated_data[key] = collated_d + return collated_data + + def forward( + self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs + ) -> Dict[str, torch.Tensor]: + raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}") + + def validation( + self, + pipeline: DiffusionPipeline, + prompt: Optional[str] = None, + image: Optional[Image] = None, + video: Optional[List[Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + frame_rate: Optional[int] = None, + generator: Optional[torch.Generator] = None, + ) -> List[ArtifactType]: + raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}") + + def _save_lora_weights( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + r""" + Save the lora state dicts of the model to the given directory. + + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError( + f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}" + ) + + def _save_model( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + r""" + Save the state dicts to the given directory. + + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}") + + def apply_tensor_parallel( + self, + backend: ParallelBackendEnum, + device_mesh: torch.distributed.DeviceMesh, + text_encoder: torch.nn.Module, + text_encoder_2: torch.nn.Module, + text_encoder_3: torch.nn.Module, + transformer: torch.nn.Module, + vae: torch.nn.Module, + ) -> None: + raise NotImplementedError( + f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}" + ) + + def _load_configs(self) -> None: + self._load_transformer_config() + self._load_vae_config() + + def _load_transformer_config(self) -> None: + if self.transformer_id is not None: + transformer_cls = resolve_component_cls( + self.transformer_id, + component_name="_class_name", + filename="config.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = transformer_cls.load_config( + self.transformer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + transformer_cls = resolve_component_cls( + self.pretrained_model_name_or_path, + component_name="transformer", + filename="model_index.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = transformer_cls.load_config( + self.pretrained_model_name_or_path, + subfolder="transformer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = FrozenDict(**self.transformer_config) + + def _load_vae_config(self) -> None: + if self.vae_id is not None: + vae_cls = resolve_component_cls( + self.vae_id, + component_name="_class_name", + filename="config.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir) + else: + vae_cls = resolve_component_cls( + self.pretrained_model_name_or_path, + component_name="vae", + filename="model_index.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.vae_config = vae_cls.load_config( + self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir + ) + self.vae_config = FrozenDict(**self.vae_config) + + +class ControlModelSpecification(ModelSpecification): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.frame_conditioning_type: "FrameConditioningType" = None + self.frame_conditioning_index: int = None + self.frame_conditioning_concatenate_mask: bool = False + + def _trainer_init( + self, frame_conditioning_type: "FrameConditioningType", frame_conditioning_index: int, concatenate_mask: bool + ) -> None: + self.frame_conditioning_type = frame_conditioning_type + self.frame_conditioning_index = frame_conditioning_index + self.frame_conditioning_concatenate_mask = concatenate_mask + + @property + def control_injection_layer_name(self): + r"""Must return the FQN (fully-qualified name) of the control injection layer.""" + raise NotImplementedError( + f"ControlModelSpecification::control_injection_layer_name is not implemented for {self.__class__.__name__}" + ) + + def load_diffusion_models(self, new_in_features: int) -> Dict[str, Union[torch.nn.Module]]: + raise NotImplementedError( + f"ControlModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}" + ) + + def _save_lora_weights( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + norm_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + r""" + Save the lora state dicts of the model to the given directory. + + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError( + f"ControlModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}" + ) + + def _save_model( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + r""" + Save the state dicts to the given directory. + + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError( + f"ControlModelSpecification::save_model is not implemented for {self.__class__.__name__}" + ) + + @property + def _original_control_layer_in_features(self): + """ + Original in_features of the input projection layer where control is injected. + """ + raise NotImplementedError( + f"ControlModelSpecification::_original_control_layer_in_features is not implemented for {self.__class__.__name__}" + ) + + @property + def _original_control_layer_out_features(self): + """ + Original out_features of the input projection layer where control is injected. + + This will be used as the rank for control injection layer when performing low-rank training and unused otherwise. + """ + raise NotImplementedError( + f"ControlModelSpecification::_original_control_layer_out_features is not implemented for {self.__class__.__name__}" + ) + + @property + def _qk_norm_identifiers(self): + raise NotImplementedError( + f"ControlModelSpecification::_qk_norm_identifiers is not implemented for {self.__class__.__name__}" + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/models/utils.py b/docs/finetrainers-src-codebase/finetrainers/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea9370146c26b5fa3a9de95e4ed4bc9805ce318 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/utils.py @@ -0,0 +1,109 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +from diffusers.utils.torch_utils import randn_tensor + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1): + # Note: _dim is the new argument added here after copying from diffusers + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +@torch.no_grad() +def _expand_linear_with_zeroed_weights( + module: torch.nn.Linear, new_in_features: Optional[int] = None, new_out_features: Optional[int] = None +) -> torch.nn.Linear: + if new_in_features is None: + new_in_features = module.in_features + if new_out_features is None: + new_out_features = module.out_features + bias = getattr(module, "bias", None) + new_module = torch.nn.Linear(new_in_features, new_out_features, bias=bias is not None) + new_module.to(device=module.weight.device, dtype=module.weight.dtype) + new_module.weight.zero_() + new_module.weight.data[: module.weight.data.shape[0], : module.weight.data.shape[1]].copy_(module.weight.data) + if bias is not None: + new_module.bias.zero_() + new_module.bias.data[: bias.data.shape[0]].copy_(bias.data) + return new_module + + +@torch.no_grad() +def _expand_conv3d_with_zeroed_weights( + module: torch.nn.Linear, new_in_channels: Optional[int] = None, new_out_channels: Optional[int] = None +) -> torch.nn.Conv3d: + if new_in_channels is None: + new_in_channels = module.in_channels + if new_out_channels is None: + new_out_channels = module.out_channels + bias = getattr(module, "bias", None) + new_module = torch.nn.Conv3d( + new_in_channels, + new_out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=bias is not None, + ) + new_module.to(device=module.weight.device, dtype=module.weight.dtype) + new_module.weight.zero_() + new_module.weight.data[: module.weight.data.shape[0], : module.weight.data.shape[1]].copy_(module.weight.data) + if bias is not None: + new_module.bias.zero_() + new_module.bias.data[: bias.data.shape[0]].copy_(bias.data) + return new_module diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb19b5b4f0440a4c620e35efbcda0a3d18b1cbab --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py @@ -0,0 +1,2 @@ +from .base_specification import WanModelSpecification +from .control_specification import WanControlModelSpecification diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..633d532f8ab0f94357ecd33990e91530702544a6 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py @@ -0,0 +1,577 @@ +import functools +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanImageToVideoPipeline, + WanPipeline, + WanTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +import finetrainers.functional as FF +from finetrainers.data import VideoArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ModelSpecification +from finetrainers.processors import ProcessorMixin, T5Processor +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import get_non_null_items, safetensors_torch_save_function + + +logger = get_logger() + + +class WanLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the Wan VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - latents_mean: The channel-wise mean of the latent space. + - latents_std: The channel-wise standard deviation of the latent space. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 3 + + def forward( + self, + vae: AutoencoderKLWan, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + # TODO(aryan): refactor in diffusers to have use_slicing attribute + # if vae.use_slicing and video.shape[0] > 1: + # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + # moments = torch.cat(encoded_slices) + # else: + # moments = vae._encode(video) + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + latents_mean = torch.tensor(vae.config.latents_mean) + latents_std = 1.0 / torch.tensor(vae.config.latents_std) + + return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std} + + +class WanImageConditioningLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the Wan VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - latents_mean: The channel-wise mean of the latent space. + - latents_std: The channel-wise standard deviation of the latent space. + - mask: The conditioning frame mask for the input image/video. + """ + + def __init__(self, output_names: List[str], *, use_last_frame: bool = False): + super().__init__() + self.output_names = output_names + self.use_last_frame = use_last_frame + assert len(self.output_names) == 4 + + def forward( + self, + vae: AutoencoderKLWan, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + num_frames = video.size(2) + if not self.use_last_frame: + first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:] + video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2) + else: + first_frame, remaining_frames, last_frame = video[:, :, :1], video[:, :, 1:-1], video[:, :, -1:] + video = torch.cat([first_frame, torch.zeros_like(remaining_frames), last_frame], dim=2) + + # Image conditioning uses argmax sampling, so we use "mode" here + if compute_posterior: + latents = vae.encode(video).latent_dist.mode() + latents = latents.to(dtype=dtype) + else: + # TODO(aryan): refactor in diffusers to have use_slicing attribute + # if vae.use_slicing and video.shape[0] > 1: + # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + # moments = torch.cat(encoded_slices) + # else: + # moments = vae._encode(video) + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + latents_mean = torch.tensor(vae.config.latents_mean) + latents_std = 1.0 / torch.tensor(vae.config.latents_std) + + temporal_downsample = 2 ** sum(vae.temperal_downsample) if getattr(self, "vae", None) else 4 + mask = latents.new_ones(latents.shape[0], 1, num_frames, latents.shape[3], latents.shape[4]) + if not self.use_last_frame: + mask[:, :, 1:] = 0 + else: + mask[:, :, 1:-1] = 0 + first_frame_mask = mask[:, :, :1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_downsample) + mask = torch.cat([first_frame_mask, mask[:, :, 1:]], dim=2) + mask = mask.view(latents.shape[0], -1, temporal_downsample, latents.shape[3], latents.shape[4]) + mask = mask.transpose(1, 2) + + return { + self.output_names[0]: latents, + self.output_names[1]: latents_mean, + self.output_names[2]: latents_std, + self.output_names[3]: mask, + } + + +class WanImageEncodeProcessor(ProcessorMixin): + r""" + Processor to encoding image conditioning for Wan I2V training. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - image_embeds: The CLIP vision model image embeddings of the input image. + """ + + def __init__(self, output_names: List[str], *, use_last_frame: bool = False): + super().__init__() + self.output_names = output_names + self.use_last_frame = use_last_frame + assert len(self.output_names) == 1 + + def forward( + self, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + device = image_encoder.device + dtype = image_encoder.dtype + last_image = None + + # We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but + # the processor expects it to be in the range [0, 1]. + image = image if video is None else video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame) + image = FF.normalize(image, min=0.0, max=1.0, dim=1) + assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor" + + if self.use_last_frame: + last_image = image if video is None else video[:, -1] + last_image = FF.normalize(last_image, min=0.0, max=1.0, dim=1) + image = torch.stack([image, last_image], dim=0) + + image = image_processor(images=image.float(), do_rescale=False, do_convert_rgb=False, return_tensors="pt") + image = image.to(device=device, dtype=dtype) + image_embeds = image_encoder(**image, output_hidden_states=True) + image_embeds = image_embeds.hidden_states[-2] + return {self.output_names[0]: image_embeds} + + +class WanModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + use_last_frame = self.transformer_config.get("pos_embed_seq_len", None) is not None + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])] + if latent_model_processors is None: + latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])] + + if self.transformer_config.get("image_dim", None) is not None: + latent_model_processors.append( + WanImageConditioningLatentEncodeProcessor( + ["latent_condition", "__drop__", "__drop__", "latent_condition_mask"], + use_last_frame=use_last_frame, + ) + ) + latent_model_processors.append( + WanImageEncodeProcessor(["encoder_hidden_states_image"], use_last_frame=use_last_frame) + ) + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = UMT5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKLWan.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + models = {"vae": vae} + if self.transformer_config.get("image_dim", None) is not None: + # TODO(aryan): refactor the trainer to be able to support these extra models from CLI args more easily + image_encoder = CLIPVisionModel.from_pretrained( + self.pretrained_model_name_or_path, subfolder="image_encoder", torch_dtype=torch.bfloat16 + ) + image_processor = CLIPImageProcessor.from_pretrained( + self.pretrained_model_name_or_path, subfolder="image_processor" + ) + models["image_encoder"] = image_encoder + models["image_processor"] = image_processor + + return models + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = WanTransformer3DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = WanTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_encoder: Optional[UMT5EncoderModel] = None, + transformer: Optional[WanTransformer3DModel] = None, + vae: Optional[AutoencoderKLWan] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + image_encoder: Optional[CLIPVisionModel] = None, + image_processor: Optional[CLIPImageProcessor] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> Union[WanPipeline, WanImageToVideoPipeline]: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "image_encoder": image_encoder, + "image_processor": image_processor, + } + components = get_non_null_items(components) + + if self.transformer_config.get("image_dim", None) is not None: + pipe = WanPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + else: + pipe = WanImageToVideoPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + # TODO(aryan): add support in diffusers + # if enable_slicing: + # pipe.vae.enable_slicing() + # if enable_tiling: + # pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + caption: str, + max_sequence_length: int = 512, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLWan, + image_encoder: Optional[CLIPVisionModel] = None, + image_processor: Optional[CLIPImageProcessor] = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image_encoder": image_encoder, + "image_processor": image_processor, + "image": image, + "video": video, + "generator": generator, + # We must force this to False because the latent normalization should be done before + # the posterior is computed. The VAE does not handle this any more: + # https://github.com/huggingface/diffusers/pull/10998 + "compute_posterior": False, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: WanTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + compute_posterior = False # See explanation in prepare_latents + latent_condition = latent_condition_mask = None + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + latent_condition = latent_model_conditions.pop("latent_condition", None) + latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None) + else: + latents = latent_model_conditions.pop("latents") + latents_mean = latent_model_conditions.pop("latents_mean") + latents_std = latent_model_conditions.pop("latents_std") + latent_condition = latent_model_conditions.pop("latent_condition", None) + latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None) + + mu, logvar = torch.chunk(latents, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + latents = torch.cat([mu, logvar], dim=1) + + posterior = DiagonalGaussianDistribution(latents) + latents = posterior.sample(generator=generator) + + if latent_condition is not None: + mu, logvar = torch.chunk(latent_condition, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + latent_condition = torch.cat([mu, logvar], dim=1) + + posterior = DiagonalGaussianDistribution(latent_condition) + latent_condition = posterior.mode() + + del posterior + + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + timesteps = (sigmas.flatten() * 1000.0).long() + + if self.transformer_config.get("image_dim", None) is not None: + noisy_latents = torch.cat([noisy_latents, latent_condition_mask, latent_condition], dim=1) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: Union[WanPipeline, WanImageToVideoPipeline], + prompt: str, + image: Optional[PIL.Image.Image] = None, + last_image: Optional[PIL.Image.Image] = None, + video: Optional[List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + if self.transformer_config.get("image_dim", None) is not None: + if image is None and video is None: + raise ValueError("Either image or video must be provided for Wan I2V validation.") + image = image if image is not None else video[0] + generation_kwargs["image"] = image + if self.transformer_config.get("pos_embed_seq_len", None) is not None: + last_image = last_image if last_image is not None else image if video is None else video[-1] + generation_kwargs["last_image"] = last_image + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + pipeline_cls = ( + WanImageToVideoPipeline if self.transformer_config.get("image_dim", None) is not None else WanPipeline + ) + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + pipeline_cls.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: WanTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = WanTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..78a4de9b1cd92e9414227da992731f1facc1c137 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py @@ -0,0 +1,437 @@ +import functools +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import safetensors +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanPipeline, + WanTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel + +import finetrainers.functional as FF +from finetrainers.data import VideoArtifact +from finetrainers.logging import get_logger +from finetrainers.models.modeling_utils import ControlModelSpecification +from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights +from finetrainers.patches.dependencies.diffusers.control import control_channel_concat +from finetrainers.processors import ProcessorMixin, T5Processor +from finetrainers.typing import ArtifactType, SchedulerType +from finetrainers.utils import get_non_null_items, safetensors_torch_save_function + +from .base_specification import WanLatentEncodeProcessor + + +if TYPE_CHECKING: + from finetrainers.trainer.control_trainer.config import FrameConditioningType + +logger = get_logger() + + +class WanControlModelSpecification(ControlModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + control_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])] + if latent_model_processors is None: + latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])] + if control_model_processors is None: + control_model_processors = [WanLatentEncodeProcessor(["control_latents", "__drop__", "__drop__"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + self.control_model_processors = control_model_processors + + @property + def control_injection_layer_name(self) -> str: + return "patch_embedding" + + @property + def _resolution_dim_keys(self): + return {"latents": (2, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs + ) + else: + text_encoder = UMT5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + **common_kwargs, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.vae_id is not None: + vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) + else: + vae = AutoencoderKLWan.from_pretrained( + self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs + ) + + return {"vae": vae} + + def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]: + common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} + + if self.transformer_id is not None: + transformer = WanTransformer3DModel.from_pretrained( + self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs + ) + else: + transformer = WanTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + **common_kwargs, + ) + + transformer.patch_embedding = _expand_conv3d_with_zeroed_weights( + transformer.patch_embedding, new_in_channels=new_in_features + ) + transformer.register_to_config(in_channels=new_in_features) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_encoder: Optional[UMT5EncoderModel] = None, + transformer: Optional[WanTransformer3DModel] = None, + vae: Optional[AutoencoderKLWan] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> WanPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = WanPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + # TODO(aryan): add support in diffusers + # if enable_slicing: + # pipe.vae.enable_slicing() + # if enable_tiling: + # pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + caption: str, + max_sequence_length: int = 512, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLWan, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + control_image: Optional[torch.Tensor] = None, + control_video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + common_kwargs = { + "vae": vae, + "generator": generator, + # We must force this to False because the latent normalization should be done before + # the posterior is computed. The VAE does not handle this any more: + # https://github.com/huggingface/diffusers/pull/10998 + "compute_posterior": False, + **kwargs, + } + conditions = {"image": image, "video": video, **common_kwargs} + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + + control_conditions = {"image": control_image, "video": control_video, **common_kwargs} + input_keys = set(control_conditions.keys()) + control_conditions = ControlModelSpecification.prepare_latents( + self, self.control_model_processors, **control_conditions + ) + control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys} + + return {**control_conditions, **conditions} + + def forward( + self, + transformer: WanTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents + + compute_posterior = False # See explanation in prepare_latents + if compute_posterior: + latents = latent_model_conditions.pop("latents") + control_latents = latent_model_conditions.pop("control_latents") + else: + latents = latent_model_conditions.pop("latents") + control_latents = latent_model_conditions.pop("control_latents") + latents_mean = latent_model_conditions.pop("latents_mean") + latents_std = latent_model_conditions.pop("latents_std") + + mu, logvar = torch.chunk(latents, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + latents = torch.cat([mu, logvar], dim=1) + + mu, logvar = torch.chunk(control_latents, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + control_latents = torch.cat([mu, logvar], dim=1) + + posterior = DiagonalGaussianDistribution(latents) + latents = posterior.mode() + del posterior + + control_posterior = DiagonalGaussianDistribution(control_latents) + control_latents = control_posterior.mode() + del control_posterior + + noise = torch.zeros_like(latents).normal_(generator=generator) + timesteps = (sigmas.flatten() * 1000.0).long() + + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + control_latents = apply_frame_conditioning_on_latents( + control_latents, + noisy_latents.shape[2], + channel_dim=1, + frame_dim=2, + frame_conditioning_type=self.frame_conditioning_type, + frame_conditioning_index=self.frame_conditioning_index, + concatenate_mask=self.frame_conditioning_concatenate_mask, + ) + noisy_latents = torch.cat([noisy_latents, control_latents], dim=1) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: WanPipeline, + prompt: str, + control_image: Optional[torch.Tensor] = None, + control_video: Optional[torch.Tensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + frame_conditioning_type: "FrameConditioningType" = "full", + frame_conditioning_index: int = 0, + **kwargs, + ) -> List[ArtifactType]: + from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents + + with torch.no_grad(): + dtype = pipeline.vae.dtype + device = pipeline._execution_device + in_channels = self.transformer_config.in_channels # We need to use the original in_channels + latents = pipeline.prepare_latents(1, in_channels, height, width, num_frames, dtype, device, generator) + latents_mean = ( + torch.tensor(self.vae_config.latents_mean) + .view(1, self.vae_config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_config.latents_std).view(1, self.vae_config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if control_image is not None: + control_video = pipeline.video_processor.preprocess( + control_image, height=height, width=width + ).unsqueeze(2) + else: + control_video = pipeline.video_processor.preprocess_video(control_video, height=height, width=width) + + control_video = control_video.to(device=device, dtype=dtype) + control_latents = pipeline.vae.encode(control_video).latent_dist.mode() + control_latents = self._normalize_latents(control_latents, latents_mean, latents_std) + control_latents = apply_frame_conditioning_on_latents( + control_latents, + latents.shape[2], + channel_dim=1, + frame_dim=2, + frame_conditioning_type=frame_conditioning_type, + frame_conditioning_index=frame_conditioning_index, + concatenate_mask=self.frame_conditioning_concatenate_mask, + ) + + generation_kwargs = { + "latents": latents, + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + + with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]): + video = pipeline(**generation_kwargs).frames[0] + + return [VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + norm_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + metadata: Optional[Dict[str, str]] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + WanPipeline.save_lora_weights( + directory, + transformer_state_dict, + save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), + safe_serialization=True, + ) + if norm_state_dict is not None: + safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: WanTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = WanTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents + + @property + def _original_control_layer_in_features(self): + return self.transformer_config.in_channels + + @property + def _original_control_layer_out_features(self): + return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim + + @property + def _qk_norm_identifiers(self): + return ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] diff --git a/docs/finetrainers-src-codebase/finetrainers/optimizer.py b/docs/finetrainers-src-codebase/finetrainers/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..57da28e9377f2bf82b5307fae83338ad0b9ec385 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/optimizer.py @@ -0,0 +1,449 @@ +import functools +import math +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import torch +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_optimizer_state_dict, + set_optimizer_state_dict, +) +from torch.distributed.checkpoint.stateful import Stateful + +from .parallel import ParallelBackendEnum +from .utils.import_utils import is_bitsandbytes_available + + +class OptimizerWrapper(Stateful): + r""" + Optimizer wrapper that: + - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages + - saves/loading optimizer state_dict at checkpoint + """ + + def __init__( + self, + model_parts: List[torch.nn.Module], + optimizer_cls: Type[torch.optim.Optimizer], + optimizer_kwargs: Dict[str, Any], + ) -> None: + self.optimizer_cls = optimizer_cls + self.optimizer_kwargs = optimizer_kwargs + + self.optimizers = [] + self.model_parts = model_parts + + for model in self.model_parts: + optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) + self.optimizers.append(optimizer) + + def step(self) -> None: + for optimizer in self.optimizers: + optimizer.step() + + def zero_grad(self) -> None: + for optimizer in self.optimizers: + optimizer.zero_grad() + + def state_dict(self) -> Dict[str, Any]: + func = functools.partial( + get_optimizer_state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + func = functools.partial( + set_optimizer_state_dict, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + list(map(func, self.model_parts, self.optimizers)) + + +class SchedulerWrapper: + def __init__( + self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int + ) -> None: + self.schedulers = [] + for optimizer in optimizers: + self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)) + + def step(self) -> None: + for scheduler in self.schedulers: + scheduler.step() + + def get_last_lr(self) -> List[float]: + # TODO(aryan): look into this later. Currently calling it leads to NCCL hang????? + return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)} + + def get_lr_scheduler_state(self) -> Dict[str, Any]: + state_dict = {} + if len(self.schedulers) == 1: + state_dict["lr_scheduler"] = self.schedulers[0] + else: + # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. + # It should only support saving and loading a distributed checkpoint with the same number of pp ranks + for idx, lr_scheduler in enumerate(self.schedulers): + state_dict[f"lr_scheduler_{idx}"] = lr_scheduler + return state_dict + + +def get_optimizer( + parallel_backend: ParallelBackendEnum, + name: str, + model_parts: List[torch.nn.Module], + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.999, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + fused: bool = False, +) -> Union[torch.optim.Optimizer, OptimizerWrapper]: + name = name.lower() + + _raise_errors_if_packages_not_available(name) + + if name == "adam": + optimizer_cls = torch.optim.Adam + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + "fused": fused, + } + elif name == "adamw": + optimizer_cls = torch.optim.AdamW + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + "fused": fused, + } + elif name == "adam-bnb": + from bitsandbytes.optim import Adam + + optimizer_cls = Adam + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adamw-bnb": + from bitsandbytes.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adam-bnb-8bit": + from bitsandbytes.optim import Adam8bit + + optimizer_cls = Adam8bit + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adamw-bnb-8bit": + from bitsandbytes.optim import AdamW8bit + + optimizer_cls = AdamW8bit + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + # TODO(aryan): handle bitsandbytes and torchao + else: + raise ValueError(f"Unsupported optimizer: {name}") + + if parallel_backend == ParallelBackendEnum.ACCELERATE: + return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs) + elif parallel_backend == ParallelBackendEnum.PTD: + return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs) + + +def get_optimizer_accelerate( + model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] +) -> torch.optim.Optimizer: + params = [param for model in model_parts for param in model.parameters() if param.requires_grad] + optimizer = optimizer_cls(params, **optimizer_kwargs) + return optimizer + + +def get_optimizer_ptd( + model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] +) -> OptimizerWrapper: + return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs) + + +def get_lr_scheduler( + parallel_backend: ParallelBackendEnum, + name: str, + optimizer: Union[torch.optim.Optimizer, OptimizerWrapper], + step_rules: Optional[str] = None, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + num_cycles: int = 1, + power: float = 1.0, + lr_init: float = 1e-3, + lr_end: float = 1e-7, + last_epoch: int = -1, +) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]: + name = name.lower() + if name == "constant": + scheduler_lambda_fn = get_constant_schedule() + elif name == "constant_with_warmup": + scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps) + elif name == "piecewise_constant": + scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules) + elif name == "linear": + scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps) + elif name == "cosine": + scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles) + elif name == "cosine_with_restarts": + scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup( + num_warmup_steps, num_training_steps, num_cycles + ) + elif name == "polynomial": + scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup( + num_warmup_steps, num_training_steps, lr_init, lr_end, power + ) + else: + raise ValueError(f"Unsupported scheduler: {name}") + + if parallel_backend == ParallelBackendEnum.ACCELERATE: + return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch) + elif parallel_backend == ParallelBackendEnum.PTD: + return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch) + + +def get_lr_scheduler_accelerate( + optimizer: torch.optim.Optimizer, + scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], + last_epoch: int = -1, +) -> torch.optim.lr_scheduler.LambdaLR: + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch) + return scheduler + + +def get_lr_scheduler_ptd( + optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1 +) -> SchedulerWrapper: + return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch) + + +# ============================== +# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py +# ============================== + + +def get_constant_schedule() -> Callable[[int], float]: + r""" + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + """ + + def lr_lambda(current_step: int): + return 1.0 + + return lr_lambda + + +def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]: + r""" + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return lr_lambda + + +def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]: + r""" + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + step_rules (`string`): + The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate + if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + steps and multiple 0.005 for the other steps. + """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + return rules_func + + +def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return lr_lambda + + +def get_cosine_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return lr_lambda + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int = 1, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return lr_lambda + + +def get_polynomial_decay_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + lr_init: float, + lr_end: float = 1e-7, + power: float = 1.0, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + """ + + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return lr_lambda + + +def _raise_errors_if_packages_not_available(name: str) -> None: + name_split = name.split("-") + if len(name_split) < 2: + return + package_name = name_split[1] + if package_name == "bnb": + if not is_bitsandbytes_available(): + raise ImportError( + f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer." + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py b/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..878bd31bbe550bab7dae31e98e98d8a30038a6cc --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py @@ -0,0 +1,22 @@ +from enum import Enum +from typing import Union + +from .accelerate import AccelerateParallelBackend +from .ptd import PytorchDTensorParallelBackend +from .utils import dist_max, dist_mean + + +ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend] + + +class ParallelBackendEnum(str, Enum): + ACCELERATE = "accelerate" + PTD = "ptd" + + +def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType: + if backend == ParallelBackendEnum.ACCELERATE: + return AccelerateParallelBackend + if backend == ParallelBackendEnum.PTD: + return PytorchDTensorParallelBackend + raise ValueError(f"Unknown parallel backend: {backend}") diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py b/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py new file mode 100644 index 0000000000000000000000000000000000000000..59c1b5e4155528c55a2a4c968a6420c1b621d570 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py @@ -0,0 +1,383 @@ +import datetime +import os +import pathlib +import shutil +import time +from typing import Any, Callable, Dict, Optional + +import torch +from diffusers.utils import is_accelerate_available + +from finetrainers.logging import get_logger +from finetrainers.utils import get_device_info + +from .base import BaseCheckpointer, BaseParallelBackend + + +if not is_accelerate_available(): + raise ImportError( + "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend." + ) + +from accelerate import Accelerator +from accelerate.data_loader import DataLoader +from accelerate.utils import ( + DataLoaderConfiguration, + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) + + +logger = get_logger() +_device_type, _device_module = get_device_info() + + +class AccelerateParallelBackend(BaseParallelBackend): + def __init__( + self, + world_size: int, + pp_degree: int = 1, + dp_degree: int = 1, + dp_shards: int = -1, + cp_degree: int = 1, + tp_degree: int = 1, + backend: str = "nccl", + timeout: int = 180, + logging_dir: Optional[str] = None, + output_dir: Optional[str] = None, + gradient_accumulation_steps: Optional[int] = None, + ) -> None: + super().__init__() + + self._world_size = world_size + self._pp_degree = pp_degree + self._dp_degree = dp_degree + self._dp_shards = dp_shards + self._cp_degree = cp_degree + self._tp_degree = tp_degree + self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None + self._logging_dir = ( + self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None + ) + self._backend = backend + self._timeout = timeout + self._gradient_accumulation_steps = gradient_accumulation_steps + + if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1: + raise ValueError( + "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment." + ) + if dp_degree != world_size: + raise ValueError("Data parallel degree must be equal to world size.") + + self._accelerator = None + if world_size == 1: + # Needs special handling for single GPU training + project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir) + dataloader_config = DataLoaderConfiguration( + split_batches=False, dispatch_batches=False, use_stateful_dataloader=True + ) + init_process_group_kwargs = InitProcessGroupKwargs( + backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout) + ) + self._accelerator = Accelerator( + project_config=project_config, + dataloader_config=dataloader_config, + gradient_accumulation_steps=gradient_accumulation_steps, + log_with=None, + kwargs_handlers=[init_process_group_kwargs], + ) + if torch.backends.mps.is_available(): + self._accelerator.native_amp = False + + self._mesh: torch.distributed.DeviceMesh = None + + def enable_determinism(self, seed: int) -> None: + set_seed(seed) + + def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module: + project_config = None + ddp_kwargs = None + init_process_group_kwargs = None + if self._accelerator is None: + project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + dataloader_config = DataLoaderConfiguration( + split_batches=False, dispatch_batches=False, use_stateful_dataloader=True + ) + init_process_group_kwargs = InitProcessGroupKwargs( + backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout) + ) + self._accelerator, model = apply_ddp( + model, + project_config, + ddp_kwargs, + init_process_group_kwargs, + dataloader_config, + self._gradient_accumulation_steps, + accelerator=self._accelerator, + ) + logger.debug("Applied AccelerateParallel::apply_ddp to model.") + return model + + def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module: + return self._accelerator.prepare_model(model) + + def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: + logger.debug("AccelerateParallelBackend::prepare_dataset completed!") + return dataset + + def prepare_dataloader( + self, + dataset: torch.utils.data.IterableDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + ) -> DataLoader: + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory + ) + dataloader = self._accelerator.prepare_data_loader(dataloader) + logger.debug("AccelerateParallelBackend::prepare_dataloader completed!") + return dataloader + + def prepare_optimizer(self, optimizer, lr_scheduler): + optimizer = self._accelerator.prepare_optimizer(optimizer) + lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler) + return optimizer, lr_scheduler + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + def _get_mesh(): + if name is None: + return self._mesh + try: + return self._mesh[name] + except (KeyError, RuntimeError): + return self._mesh + + if self._mesh is not None: + return _get_mesh() + + mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)] + mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] + names = [x[0] for x in mesh_list] + degrees = [x[1] for x in mesh_list] + mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) + + dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] + + if self.data_replication_enabled: + dp_mesh_names.append("dp_replicate") + dp_cp_mesh_names.append("dp_replicate") + if self.data_sharding_enabled: + dp_mesh_names.append("dp_shard") + dp_cp_mesh_names.append("dp_shard") + dp_shard_cp_mesh_names.append("dp_shard") + if self.context_parallel_enabled: + dp_cp_mesh_names.append("cp") + dp_shard_cp_mesh_names.append("cp") + + if len(dp_mesh_names) > 0: + mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") + if len(dp_cp_mesh_names) > 0: + mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") + if len(dp_shard_cp_mesh_names) > 0: + mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") + + logger.debug(f"Device mesh: {mesh}") + self._mesh = mesh + return _get_mesh() + + def get_checkpointer(self, *args, **kwargs): + return AccelerateCheckpointer(self._accelerator, *args, **kwargs) + + @property + def world_size(self): + return self._accelerator.num_processes + + @property + def rank(self): + return self._accelerator.process_index + + @property + def local_rank(self): + return self._accelerator.local_process_index + + @property + def is_main_process(self): + r"""Returns `True` if the current process is the main process on the master node.""" + return self._accelerator.is_main_process + + @property + def is_local_main_process(self): + r"""Returns `True` if the current process is the main process on local node.""" + return self._accelerator.is_local_main_process + + @property + def device(self): + return self._accelerator.device + + def wait_for_everyone(self): + self._accelerator.wait_for_everyone() + + def destroy(self): + if self.is_main_process and self.tracker is not None: + self.tracker.finish() + self._accelerator.end_training() + + @property + def pipeline_parallel_enabled(self): + return self._pp_degree > 1 + + @property + def data_parallel_enabled(self): + return self._dp_degree > 1 or self._dp_shards > 1 + + @property + def data_replication_enabled(self): + return self._dp_degree > 1 + + @property + def data_sharding_enabled(self): + return self._dp_shards > 1 + + @property + def context_parallel_enabled(self): + return self._cp_degree > 1 + + @property + def tensor_parallel_enabled(self): + return self._tp_degree > 1 + + +class AccelerateCheckpointer(BaseCheckpointer): + def __init__( + self, + accelerator: Accelerator, + states: Dict[str, Any], + checkpointing_steps: int, + checkpointing_limit: int, + output_dir: str, + enable: bool = True, + _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, + _prefix: str = "finetrainers_step", + *args, + **kwargs, + ) -> None: + self.accelerator = accelerator + self.states = states + + self.checkpointing_steps = checkpointing_steps + self.checkpointing_limit = checkpointing_limit + self.output_dir = pathlib.Path(output_dir) + self.enable = enable + self._callback_fn = _callback_fn + self._prefix = _prefix + + def save_model_hook(models, weights, output_dir: str) -> None: + if not self.accelerator.is_main_process: + return + + # TODO(aryan): this is a temporary assertion since we only support training transformer at the moment. + # Remove it when adding support for training text encoders/vae and more. + assert len(models) == 1 + + _callback_fn(weights[0]) + torch.save(self.states, os.path.join(output_dir, "states.pt")) + + def load_model_hook(models, input_dir) -> None: + self.states = torch.load(os.path.join(input_dir, "states.pt")) + + self.accelerator.register_save_state_pre_hook(save_model_hook) + self.accelerator.register_load_state_pre_hook(load_model_hook) + + logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'") + + def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str: + if not self._should_checkpoint(step, force): + return None + + checkpoint_dir = self._get_checkpoint_dir(step) + begin_time = time.monotonic() + self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True) + end_time = time.monotonic() + logger.info( + f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}" + ) + self._purge_stale_checkpoints() + + return checkpoint_dir.as_posix() + + def load(self, step: int = -1) -> bool: + if not self.enable: + return False + if not self.output_dir.exists(): + return False + if step != -1 and not self._get_checkpoint_dir(step).exists(): + return False + + if step == -1: + latest_checkpoint_dir = self._find_latest_checkpoint_dir() + if latest_checkpoint_dir is None: + return False + step = int(latest_checkpoint_dir.name.split("_")[-1]) + + checkpoint_dir = self._get_checkpoint_dir(step) + logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}") + + begin_time = time.monotonic() + self.accelerator.load_state(checkpoint_dir.as_posix()) + end_time = time.monotonic() + logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.") + + return True + + def _should_checkpoint(self, step: int, force: bool) -> bool: + if not self.enable: + return False + if not force: + if step % self.checkpointing_steps != 0: + return False + return True + + def _get_checkpoint_dir(self, step: int) -> pathlib.Path: + return self.output_dir / f"{self._prefix}_{step}" + + def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]: + checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1])) + return checkpoints[-1] if len(checkpoints) > 0 else None + + def _purge_stale_checkpoints(self) -> None: + if self.checkpointing_limit is None or self.checkpointing_limit <= 0: + return + checkpoints = sorted( + self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True + ) + for checkpoint in checkpoints[self.checkpointing_limit :]: + logger.info(f"Deleting stale checkpoint: {checkpoint}") + shutil.rmtree(checkpoint, ignore_errors=True) + + +def apply_ddp( + model: torch.nn.Module, + project_config: Optional[ProjectConfiguration] = None, + ddp_kwargs: Optional[DistributedDataParallelKwargs] = None, + init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None, + dataloader_config: Optional[DataLoaderConfiguration] = None, + gradient_accumulation_steps: Optional[int] = None, + accelerator: Optional[Accelerator] = None, +) -> torch.nn.Module: + if accelerator is None: + accelerator = Accelerator( + project_config=project_config, + dataloader_config=dataloader_config, + gradient_accumulation_steps=gradient_accumulation_steps, + log_with=None, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + if torch.backends.mps.is_available(): + accelerator.native_amp = False + accelerator.prepare_model(model) + return accelerator, model diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/base.py b/docs/finetrainers-src-codebase/finetrainers/parallel/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ab04aeb71ed82db34f4154e7ba23b51ce2737579 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/base.py @@ -0,0 +1,145 @@ +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional + +import torch + +from finetrainers.trackers import DummyTracker, TrackerType, initialize_trackers + + +class BaseParallelBackend: + r""" + Base class that contains properties and methods that should be implemented by different parallel backends. + """ + + def __init__(self): + self.tracker = None + + def enable_determinism(self, seed: int) -> None: + raise NotImplementedError("Method `enable_determinism` must be implemented by subclass.") + + def apply_ddp(self, *args, **kwargs) -> torch.nn.Module: + raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.") + + def apply_fsdp2(self, *args, **kwargs) -> torch.nn.Module: + raise NotImplementedError("Method `apply_fsdp2` must be implemented by subclass.") + + def apply_context_parallel(self, *args, **kwargs) -> torch.nn.Module: + raise NotImplementedError("Method `apply_context_parallel` must be implemented by subclass.") + + def prepare_model(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_model` must be implemented by subclass.") + + def prepare_dataset(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.") + + def prepare_dataloader(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.") + + def prepare_optimizer(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.") + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + raise NotImplementedError("Method `get_mesh` must be implemented by subclass.") + + def get_checkpointer(self, *args, **kwargs) -> None: + raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.") + + def initialize_trackers( + self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str + ) -> TrackerType: + if self.is_main_process: + self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir) + else: + self.tracker = DummyTracker() + + def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.is_main_process: + self.tracker.log(metrics, step) + + def wait_for_everyone(self): + raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.") + + @contextmanager + def main_process_first(self): + raise NotImplementedError("Method `main_process_first` must be implemented by subclass.") + + def destroy(self): + raise NotImplementedError("Method `destroy` must be implemented by subclass.") + + @property + def world_size(self): + raise NotImplementedError("Method `world_size` must be implemented by subclass.") + + @property + def rank(self): + raise NotImplementedError("Method `rank` must be implemented by subclass.") + + @property + def local_rank(self): + raise NotImplementedError("Method `local_rank` must be implemented by subclass.") + + @property + def is_main_process(self): + raise NotImplementedError("Method `is_main_process` must be implemented by subclass.") + + @property + def is_local_main_process(self): + raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.") + + @property + def device(self): + raise NotImplementedError("Method `device` must be implemented by subclass.") + + @property + def pipeline_parallel_enabled(self): + raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.") + + @property + def data_parallel_enabled(self): + raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.") + + @property + def data_replication_enabled(self): + raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.") + + @property + def data_sharding_enabled(self): + raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.") + + @property + def context_parallel_enabled(self): + raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.") + + @property + def tensor_parallel_enabled(self): + raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.") + + +class BaseCheckpointer: + r""" + Base class that contains properties and methods that should be implemented by different parallel backends. + """ + + def __init__( + self, + dataloader: torch.utils.data.DataLoader, + model_parts: List[torch.nn.Module], + optimizers: Any, + schedulers: Any, + states: Dict[str, Any], + checkpointing_steps: int, + checkpointing_limit: int, + output_dir: str, + enable: bool = True, + _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, + _prefix: str = "finetrainers_step", + *args, + **kwargs, + ) -> None: + raise NotImplementedError("Method `__init__` must be implemented by subclass.") + + def save(self, step: int, force: bool, *, _device: Optional[torch.device] = None, _is_main_process: bool) -> str: + raise NotImplementedError("Method `save` must be implemented by subclass.") + + def load(self, step: int = -1) -> bool: + raise NotImplementedError("Method `load` must be implemented by subclass.") diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py b/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9f54d66ec1941ffc44d6239b305cc397ce61d4 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py @@ -0,0 +1,7 @@ +from .base import BaseParallelBackend + + +class DeepspeedParallelBackend(BaseParallelBackend): + def __init__(self): + # TODO(aryan) + raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.") diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py b/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py new file mode 100644 index 0000000000000000000000000000000000000000..2a95b1a95781913a4b682b2a57159416d8d7443b --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py @@ -0,0 +1,709 @@ +import datetime +import functools +import os +import pathlib +import shutil +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import datasets.distributed +import torch +import torch.distributed._functional_collectives +import torch.distributed.checkpoint +import torch.distributed.checkpoint.stateful +from diffusers.hooks import HookRegistry, ModelHook +from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard +from torch.distributed._composable.replicate import replicate +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + set_model_state_dict, +) +from torch.distributed.tensor import DTensor, Shard + +from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry +from finetrainers.data import DPDataLoader +from finetrainers.logging import get_logger +from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module +from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES + +from .base import BaseCheckpointer, BaseParallelBackend + + +if TYPE_CHECKING: + from finetrainers import optimizer + + +_device_type, _device_module = get_device_info() +logger = get_logger() + + +class PytorchDTensorParallelBackend(BaseParallelBackend): + def __init__( + self, + world_size: int, + pp_degree: int = 1, + dp_degree: int = 1, + dp_shards: int = -1, + cp_degree: int = 1, + tp_degree: int = 1, + backend: str = "nccl", + timeout: int = 180, + logging_dir: Optional[str] = None, + output_dir: Optional[str] = None, + gradient_accumulation_steps: Optional[int] = None, + ) -> None: + super().__init__() + + self._world_size = world_size + self._pp_degree = pp_degree + self._dp_degree = dp_degree + self._dp_shards = dp_shards + self._cp_degree = cp_degree + self._tp_degree = tp_degree + self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None + self._logging_dir = ( + self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None + ) + self._backend = backend + self._timeout = timeout + + for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]: + if degree < 1: + raise ValueError(f"Parallel degree must be at least 1, got {degree}.") + + if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size: + raise ValueError( + f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards." + ) + + torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)) + _device_module.set_device(self.local_rank) + + logger.info( + f"Initialized parallel state with:\n" + f" - World size: {world_size}\n" + f" - Pipeline parallel degree: {pp_degree}\n" + f" - Data parallel degree: {dp_degree}\n" + f" - Context parallel degree: {cp_degree}\n" + f" - Tensor parallel degree: {tp_degree}\n" + f" - Data parallel shards: {dp_shards}\n" + ) + + self._mesh: torch.distributed.DeviceMesh = None + + def enable_determinism(self, seed): + world_mesh = self.get_mesh() + enable_determinism(seed, world_mesh) + + def apply_ddp( + self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None + ) -> torch.nn.Module: + if device_mesh is None: + device_mesh = self.get_mesh() + apply_ddp(model, device_mesh) + logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.") + return model + + def apply_fsdp2( + self, + model: torch.nn.Module, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + output_dtype: torch.dtype, + pp_enabled: bool = False, + cpu_offload: bool = False, + device_mesh: Optional[torch.distributed.DeviceMesh] = None, + ) -> torch.nn.Module: + if device_mesh is None: + device_mesh = self.get_mesh() + apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload) + logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.") + return model + + def apply_context_parallel( + self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None + ) -> torch.nn.Module: + if device_mesh is None: + device_mesh = self.get_mesh() + apply_context_parallel(model, device_mesh) + logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.") + return model + + def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module: + return model + + def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: + if self._dp_degree == 1: + return dataset + dp_mesh = self.get_mesh()["dp_replicate"] + dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size) + logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!") + return dataset + + def prepare_dataloader( + self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool + ) -> DPDataLoader: + if self._dp_degree == 1: + dp_local_rank = 0 + else: + dp_mesh = self.get_mesh()["dp_replicate"] + dp_local_rank = dp_mesh.get_local_rank() + dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers) + logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!") + return dataloader + + def prepare_optimizer(self, optimizer, lr_scheduler): + logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!") + return optimizer, lr_scheduler + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + def _get_mesh(): + if name is None: + return self._mesh + try: + return self._mesh[name] + except (KeyError, RuntimeError): + if self._mesh.ndim == 0: + return None + return self._mesh + + if self._mesh is not None: + return _get_mesh() + + mesh_list = [ + ("pp", self._pp_degree), + ("dp_replicate", self._dp_degree), + ("dp_shard", self._dp_shards), + ("cp", self._cp_degree), + ("tp", self._tp_degree), + ] + mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] + names = [x[0] for x in mesh_list] + degrees = [x[1] for x in mesh_list] + mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) + + dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] + + if self.data_replication_enabled: + dp_mesh_names.append("dp_replicate") + dp_cp_mesh_names.append("dp_replicate") + if self.data_sharding_enabled: + dp_mesh_names.append("dp_shard") + dp_cp_mesh_names.append("dp_shard") + dp_shard_cp_mesh_names.append("dp_shard") + if self.context_parallel_enabled: + dp_cp_mesh_names.append("cp") + dp_shard_cp_mesh_names.append("cp") + + if len(dp_mesh_names) > 0: + mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") + if len(dp_cp_mesh_names) > 0: + mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") + if len(dp_shard_cp_mesh_names) > 0: + mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") + + logger.debug(f"Device mesh: {mesh}") + self._mesh = mesh + return _get_mesh() + + def get_checkpointer(self, *args, **kwargs): + return PTDCheckpointer(*args, **kwargs) + + @property + def world_size(self): + return torch.distributed.get_world_size() + + @property + def rank(self): + return torch.distributed.get_rank() + + @property + def local_rank(self): + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def is_main_process(self): + r"""Returns `True` if the current process is the main process on the master node.""" + return self.rank == 0 + + @property + def is_local_main_process(self): + r"""Returns `True` if the current process is the main process on local node.""" + return self.local_rank == 0 + + @property + def device(self): + return torch.device(_device_type, self.local_rank) + + def wait_for_everyone(self): + return torch.distributed.barrier() + + # @contextmanager + # def main_process_first(self): + # if self.is_main_process: + # yield + # self.wait_for_everyone() + # else: + # self.wait_for_everyone() + # yield + + def destroy(self): + if self.is_main_process and self.tracker is not None: + self.tracker.finish() + return torch.distributed.destroy_process_group() + + @property + def pipeline_parallel_enabled(self): + return self._pp_degree > 1 + + @property + def data_parallel_enabled(self): + return self._dp_degree > 1 or self._dp_shards > 1 + + @property + def data_replication_enabled(self): + return self._dp_degree > 1 + + @property + def data_sharding_enabled(self): + return self._dp_shards > 1 + + @property + def context_parallel_enabled(self): + return self._cp_degree > 1 + + @property + def tensor_parallel_enabled(self): + return self._tp_degree > 1 + + +class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None: + self.model = [model] if isinstance(model, torch.nn.Module) else model + + def state_dict(self) -> Dict[str, Any]: + return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + func = functools.partial( + set_model_state_dict, + model_state_dict=state_dict, + options=StateDictOptions(strict=False), + ) + list(map(func, self.model)) + + +class PTDCheckpointer(BaseCheckpointer): + def __init__( + self, + dataloader: torch.utils.data.DataLoader, + model_parts: List[torch.nn.Module], + optimizers: "optimizer.OptimizerWrapper", + schedulers: "optimizer.SchedulerWrapper", + states: Dict[str, Any], + checkpointing_steps: int, + checkpointing_limit: int, + output_dir: str, + enable: bool = True, + _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, + _prefix: str = "finetrainers_step", + ) -> None: + self.states = states + self.states.update( + { + "model": ModelWrapper(model_parts), + "optimizer": optimizers, + "dataloader": dataloader, + } + ) + self.states.update(schedulers.get_lr_scheduler_state()) + + self.checkpointing_steps = checkpointing_steps + self.checkpointing_limit = checkpointing_limit + self.output_dir = pathlib.Path(output_dir) + self.enable = enable + self._callback_fn = _callback_fn + self._prefix = _prefix + + logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'") + + def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str: + if not self._should_checkpoint(step, force): + return None + + checkpoint_dir = self._get_checkpoint_dir(step) + begin_time = time.monotonic() + torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) + end_time = time.monotonic() + logger.info( + f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}" + ) + self._purge_stale_checkpoints() + + state_dicts = [ + gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process) + for model in self.states["model"].model + ] + if self._callback_fn is not None: + list(map(self._callback_fn, state_dicts)) + + return checkpoint_dir.as_posix() + + def load(self, step: int = -1) -> bool: + if not self.enable: + return False + if not self.output_dir.exists(): + return False + if step != -1 and not self._get_checkpoint_dir(step).exists(): + return False + + if step == -1: + latest_checkpoint_dir = self._find_latest_checkpoint_dir() + if latest_checkpoint_dir is None: + return False + step = int(latest_checkpoint_dir.name.split("_")[-1]) + + checkpoint_dir = self._get_checkpoint_dir(step) + logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}") + + # For step 0, optimizers/schedulers are not available as they are created during training after first step + states = {"model": self.states["model"]} if step == 0 else self.states + + # See bug: https://github.com/pytorch/pytorch/pull/138575 + original_stateful_states = { + k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful) + } + begin_time = time.monotonic() + torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix()) + end_time = time.monotonic() + logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.") + + # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load() + states.update(original_stateful_states) + + return True + + def _should_checkpoint(self, step: int, force: bool) -> bool: + if not self.enable: + return False + if not force: + if step % self.checkpointing_steps != 0: + return False + return True + + def _get_checkpoint_dir(self, step: int) -> pathlib.Path: + return self.output_dir / f"{self._prefix}_{step}" + + def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]: + checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1])) + return checkpoints[-1] if len(checkpoints) > 0 else None + + def _purge_stale_checkpoints(self) -> None: + if self.checkpointing_limit is None or self.checkpointing_limit <= 0: + return + checkpoints = sorted( + self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True + ) + for checkpoint in checkpoints[self.checkpointing_limit :]: + logger.info(f"Deleting stale checkpoint: {checkpoint}") + shutil.rmtree(checkpoint, ignore_errors=True) + + +def gather_state_dict_on_cpu_rank0( + model, device: Optional[torch.device] = None, *, is_main_process: bool +) -> Dict[str, Any]: + cpu_state_dict = {} + sharded_sd = model.state_dict() + for param_name, param in sharded_sd.items(): + if param.is_cpu: + # Move back to device if offloaded to CPU + param = param.to(device) + if hasattr(param, "_local_tensor"): + # Gather DTensor + param = param.full_tensor() + if is_main_process: + cpu_state_dict[param_name] = param.cpu() + torch.distributed.barrier() + return cpu_state_dict + + +# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict +# def dcp_to_torch_save( +# dcp_checkpoint_dir: Union[str, os.PathLike], +# torch_save_path: Union[str, os.PathLike], +# callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, +# ): +# """ +# Given a directory containing a DCP checkpoint, this function will convert it into a +# Torch save file. + +# Args: +# dcp_checkpoint_dir: Directory containing the DCP checkpoint. +# torch_save_path: Filename to store the converted Torch save file. +# callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict. + +# .. warning:: +# To avoid OOM, it's recommended to only run this function on a single rank. +# """ +# state_dict = {} +# _load_state_dict( +# state_dict, +# storage_reader=FileSystemReader(dcp_checkpoint_dir), +# planner=_EmptyStateDictLoadPlanner(), +# no_dist=True, +# ) +# if callback_fn is not None: +# state_dict = callback_fn(state_dict) +# torch.save(state_dict, torch_save_path) + + +def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None: + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + +def apply_fsdp2( + model: torch.nn.Module, + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + output_dtype: torch.dtype, + pp_enabled: bool = False, + cpu_offload: bool = False, +) -> None: + """Apply FSDP2 on a model.""" + mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True) + + def apply_fully_shard(blocks): + for layer_index, block in enumerate(blocks): + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = layer_index < len(blocks) - 1 + fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward) + + for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: + blocks = getattr(model, transformer_block_name, None) + if blocks is not None: + apply_fully_shard(blocks) + + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_context_parallel( + model: torch.nn.Module, + mesh: torch.distributed.device_mesh.DeviceMesh, + plan: Optional[Dict[str, ContextParallelModelPlan]] = None, +) -> None: + """Apply context parallel on a model.""" + logger.debug(f"Applying context parallel with CP mesh: {mesh}") + model_cls = unwrap_module(model).__class__ + + if plan is None: + plan = TransformerRegistry.get(model_cls).cp_plan + + for module_id, cp_model_plan in plan.items(): + module = get_submodule_by_name(model, module_id) + if not isinstance(module, list): + module = [module] + logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules") + for m in module: + registry = HookRegistry.check_if_exists_or_initialize(m) + if isinstance(cp_model_plan, list): + # Metadata can only be a list when it is a list of CPOutput + assert all(isinstance(x, CPOutput) for x in cp_model_plan) + hook = ContextParallelGatherHook(cp_model_plan, mesh) + hook_name = f"cp_output---{module_id}" + else: + hook = ContextParallelSplitHook(cp_model_plan, mesh) + hook_name = f"cp_input---{module_id}" + registry.register_hook(hook, hook_name) + + +class ContextParallelSplitHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None: + super().__init__() + self.metadata = metadata + self.mesh = mesh + + def pre_forward(self, module, *args, **kwargs): + args_list = list(args) + + for param_identifier, cpm in self.metadata.items(): + name = param_identifier.name + index = param_identifier.index + + if isinstance(cpm, CPInput) and cpm.split_output: + continue + + # Maybe the parameter was passed as a keyword argument + is_kwarg = True + input_val = kwargs.get(name, None) + + # If not, maybe it was passed as a positional argument + if input_val is None and index is not None: + if index < len(args_list): # Ensure index is within bounds + input_val = args_list[index] + is_kwarg = False + else: + logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.") + continue # Skip if index is invalid + + # Either the input_val is truly None, or argument is passed as normal argument + # but user forgot to specify the index when registering metadata + if input_val is None: + continue + + # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard + # the output instead of input for a particular layer by setting split_output=True + if torch.is_tensor(input_val): + input_val = self._prepare_cp_input(input_val, cpm) + + elif isinstance(input_val, (list, tuple)): + if len(input_val) != len(cpm): + raise ValueError( + f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}." + ) + sharded_input_val = [] + for i, x in enumerate(input_val): + if torch.is_tensor(x) and not cpm[i].split_output: + x = self._prepare_cp_input(x, cpm[i]) + sharded_input_val.append(x) + input_val = sharded_input_val + + else: + raise ValueError(f"Unsupported input type: {type(input_val)}") + + if is_kwarg: + kwargs[name] = input_val + elif index is not None and index < len(args_list): + args_list[index] = input_val + + return tuple(args_list), kwargs + + def post_forward(self, module, output): + is_tensor = torch.is_tensor(output) + is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output) + if not is_tensor and not is_tensor_list: + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + output = [output] if is_tensor else list(output) + for param_identifier, cpm in self.metadata.items(): + if not isinstance(cpm, CPInput) or not cpm.split_output: + continue + index = param_identifier.index + if index >= len(output): + raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") + current_output = output[index] + current_output = self._prepare_cp_input(current_output, cpm) + output[index] = current_output + return output[0] if is_tensor else tuple(output) + + def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor: + if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: + raise ValueError( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + ) + return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh) + + +class ContextParallelGatherHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None: + super().__init__() + self.metadata = metadata + self.mesh = mesh + + def post_forward(self, module, output): + is_tensor = torch.is_tensor(output) + if is_tensor: + output = [output] + output = list(output) + assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}." + for i, cpm in enumerate(self.metadata): + if cpm is None: + continue + output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh) + return output[0] if is_tensor else tuple(output) + + +class _ContextParallelSharder: + @classmethod + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses") + + @classmethod + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses") + + +class _EquipartitionSharder(_ContextParallelSharder): + """ + Shards the input tensor along the specified dimension into cp_mesh's world size chunks. + Essentially, rank_i gets the i-th chunk. + + This sharding strategy should only be used when performing full attention. Otherwise, it will + have performance penalty. If using causal attention, please use _CausalSharder instead. + """ + + @classmethod + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + assert tensor.size()[dim] % mesh.size() == 0 + return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()] + + @classmethod + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + tensor = tensor.contiguous() + # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim + result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor() + return result + + +# TODO(aryan): this class is untested +class _CausalSharder(_ContextParallelSharder): + """ + Shards the input tensor along the specified dimension into 2x cp_mesh's world size chunks. + Essentially, rank_i gets the i-th chunk and (2 * cp_world_size - 1 - i)-th chunk. + + This sharding strategy improves the performance for causal attention, as it allows + equal distribution of computation across all ranks. + + Causal attention mask: + ``` + 1 0 0 0 <--- Group 0 + 1 1 0 0 <--- Group 1 + 1 1 1 0 <--- Group 1 + 1 1 1 1 <--- Group 0 + ``` + """ + + @classmethod + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + world_size = mesh.size() + rank = mesh.get_local_rank() + assert tensor.size()[dim] % (2 * world_size) == 0 + chunks = tensor.chunk(2 * world_size, dim=dim) + i, j = rank, 2 * world_size - 1 - rank + return torch.cat((chunks[i], chunks[j]), dim=dim) + + @classmethod + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + tensor = tensor.contiguous() + world_size = mesh.size() + # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim + all_tensors = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor() + sliced_tensors = [st for t in all_tensors for st in t.chunk(2, dim=dim)] + ordered_tensors = list(sliced_tensors) + for i, t in enumerate(sliced_tensors): + if i % 2 == 0: + ordered_tensors[i // 2] = t + else: + ordered_tensors[world_size * 2 - (i // 2) - 1] = t + return torch.cat(ordered_tensors, dim=dim) diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py b/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a13ef10bd679d4443bea447eaba90a883b763c7e --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py @@ -0,0 +1,19 @@ +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor + + +def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + if isinstance(x, torch.distributed.tensor.DTensor): + # functional collectives do not support DTensor inputs + x = x.full_tensor() + assert x.numel() == 1 # required by `.item()` + return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() + + +def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) + + +def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh) diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c81ba8da8f19b818ea21979a4ec237f9ee56aeb9 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +import torch + +from .dependencies.diffusers.peft import load_lora_weights + + +if TYPE_CHECKING: + from finetrainers.args import BaseArgsType + from finetrainers.parallel import ParallelBackendType + + +def perform_patches_for_training(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None: + # To avoid circular imports + from finetrainers.config import ModelType, TrainingType + + from .dependencies.diffusers import patch + + # Modeling patches + patch_scaled_dot_product_attention() + + patch.patch_diffusers_rms_norm_forward() + + # LTX Video patches + if args.model_name == ModelType.LTX_VIDEO: + from .models.ltx_video import patch + + patch.patch_transformer_forward() + if parallel_backend.tensor_parallel_enabled: + patch.patch_apply_rotary_emb_for_tp_compatibility() + + # Wan patches + if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules: + from .models.wan import patch + + patch.patch_time_text_image_embedding_forward() + + # LoRA patches + if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0: + from .dependencies.peft import patch + + patch.patch_peft_move_adapter_to_device_of_base_layer() + + +def perform_patches_for_inference(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None: + # To avoid circular imports + from .dependencies.diffusers import patch + + # Modeling patches + patch_scaled_dot_product_attention() + + patch.patch_diffusers_rms_norm_forward() + + +def patch_scaled_dot_product_attention(): + from finetrainers.models.attention_dispatch import attention_dispatch + + torch.nn.functional.scaled_dot_product_attention = attention_dispatch diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py new file mode 100644 index 0000000000000000000000000000000000000000..baa45910659f6a79ad8d133cf76671482284b44a --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py @@ -0,0 +1,36 @@ +from contextlib import contextmanager +from typing import List, Union + +import torch +from diffusers.hooks import HookRegistry, ModelHook + + +_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK" + + +class ControlChannelConcatenateHook(ModelHook): + def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]): + self.input_names = input_names + self.inputs = inputs + self.dims = dims + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims): + original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name] + control_tensor = torch.cat([original_tensor, input_tensor], dim=dim) + if isinstance(input_name, int): + args[input_name] = control_tensor + else: + kwargs[input_name] = control_tensor + return args, kwargs + + +@contextmanager +def control_channel_concat( + module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int] +): + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = ControlChannelConcatenateHook(input_names, inputs, dims) + registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK) + yield + registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False) diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0c7952574b034039a0082caec50d4253a343ab --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py @@ -0,0 +1,6 @@ +def patch_diffusers_rms_norm_forward() -> None: + import diffusers.models.normalization + + from .rms_norm import _patched_rms_norm_forward + + diffusers.models.normalization.RMSNorm.forward = _patched_rms_norm_forward diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py new file mode 100644 index 0000000000000000000000000000000000000000..f625323b548e159717598f5b5990d6626f5fc3b0 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py @@ -0,0 +1,61 @@ +import json +from pathlib import Path +from typing import Optional + +import safetensors.torch +from diffusers import DiffusionPipeline +from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA +from huggingface_hub import repo_exists, snapshot_download +from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + +from finetrainers.logging import get_logger +from finetrainers.utils import find_files + + +logger = get_logger() + + +def load_lora_weights( + pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs +) -> None: + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + + is_local_file_path = Path(pretrained_model_name_or_path).is_dir() + if not is_local_file_path: + does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model") + if not does_repo_exist: + raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.") + else: + pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model") + + prefix = "transformer" + state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path) + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + + file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1) + if len(file_list) == 0: + raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.") + if len(file_list) > 1: + logger.warning( + f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}." + ) + with safetensors.torch.safe_open(file_list[0], framework="pt") as f: + metadata = f.metadata() + metadata = json.loads(metadata["lora_config"]) + + transformer = pipeline.transformer + if adapter_name is None: + adapter_name = "default" + + lora_config = LoraConfig(**metadata) + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + result = set_peft_model_state_dict( + transformer, + state_dict, + adapter_name=adapter_name, + ignore_mismatched_sizes=False, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + logger.debug( + f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. Result: {result}" + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e9a4ef14590665a44104f9cdf1651b26fe81b0 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +from diffusers.utils import is_torch_npu_available, is_torch_version + + +def _patched_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if is_torch_npu_available(): + import torch_npu + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0] + if self.bias is not None: + hidden_states = hidden_states + self.bias + elif is_torch_version(">=", "2.4"): + ### ===== ======= + input_dtype = hidden_states.dtype + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = nn.functional.rms_norm( + hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps + ) + if self.bias is not None: + hidden_states = hidden_states + self.bias + hidden_states = hidden_states.to(input_dtype) + ### ===== ===== + else: + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + if self.bias is not None: + hidden_states = hidden_states + self.bias + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..4de4b1a965fa6c33ebc9acad81fb1dddc1ba8de2 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py @@ -0,0 +1,25 @@ +import functools + +from peft.tuners.tuners_utils import BaseTunerLayer + +from finetrainers.patches.utils import DisableTensorToDtype + + +def patch_peft_move_adapter_to_device_of_base_layer() -> None: + _perform_patch_move_adapter_to_device_of_base_layer() + + +def _perform_patch_move_adapter_to_device_of_base_layer() -> None: + BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer( + BaseTunerLayer._move_adapter_to_device_of_base_layer + ) + + +def _patched_move_adapter_to_device_of_base_layer(func) -> None: + # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor. + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with DisableTensorToDtype(): + return func(self, *args, **kwargs) + + return wrapper diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8caa803f0716280ff066d6e7865746344fb8e9 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, Optional, Tuple + +import diffusers +import torch +from diffusers import LTXVideoTransformer3DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils.import_utils import is_torch_version + + +def patch_transformer_forward() -> None: + _perform_ltx_transformer_forward_patch() + + +def patch_apply_rotary_emb_for_tp_compatibility() -> None: + _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() + + +def _perform_ltx_transformer_forward_patch() -> None: + LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward + + +def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None: + def apply_rotary_emb(x, freqs): + cos, sin = freqs + # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ======== + # The change is made due to unsupported DTensor operation aten.ops.unbind + # FIXME: Once aten.ops.unbind support lands, this will no longer be required + # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] + x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2] + # ================================================================== + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb + + +def _patched_LTXVideoTransformer3D_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + return_dict: bool = True, + *args, + **kwargs, +) -> torch.Tensor: + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # ===== This is modified compared to Diffusers ===== + # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep + if timestep.ndim == 1: + timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1) + # ================================================== + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + # ===== This is modified compared to Diffusers ===== + # temb = temb.view(batch_size, -1, temb.size(-1)) + # embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + # ================================================== + # This is done to make it possible to use per-token timestep embedding + temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1)) + # ================================================== + + hidden_states = self.proj_in(hidden_states) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c44ae42637fb9c6fc0a9803930f1728a92b693 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py @@ -0,0 +1,33 @@ +from typing import Optional + +import diffusers +import torch + + +def patch_time_text_image_embedding_forward() -> None: + _patch_time_text_image_embedding_forward() + + +def _patch_time_text_image_embedding_forward() -> None: + diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = ( + _patched_WanTimeTextImageEmbedding_forward + ) + + +def _patched_WanTimeTextImageEmbedding_forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, +): + # Some code has been removed compared to original implementation in Diffusers + # Also, timestep is typed as that of encoder_hidden_states + timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/utils.py b/docs/finetrainers-src-codebase/finetrainers/patches/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7f4726cc8183a461310570762ee95b5c4e6187 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/patches/utils.py @@ -0,0 +1,18 @@ +import torch + + +class DisableTensorToDtype: + def __enter__(self): + self.original_to = torch.Tensor.to + + def modified_to(tensor, *args, **kwargs): + # remove dtype from args if present + args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] + if "dtype" in kwargs: + kwargs.pop("dtype") + return self.original_to(tensor, *args, **kwargs) + + torch.Tensor.to = modified_to + + def __exit__(self, *args, **kwargs): + torch.Tensor.to = self.original_to diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py b/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82a99170fa50b01e2767f21591db37e8e3046883 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, List, Optional + +from .base import ProcessorMixin +from .canny import CannyProcessor +from .clip import CLIPPooledProcessor +from .glm import CogView4GLMProcessor +from .llama import LlamaProcessor +from .t5 import T5Processor +from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor + + +class CopyProcessor(ProcessorMixin): + r"""Processor that copies the input data unconditionally to the output.""" + + def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None): + super().__init__() + + self.output_names = output_names + self.input_names = input_names + assert len(output_names) == 1 + + def forward(self, input: Any) -> Any: + return {self.output_names[0]: input} diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/base.py b/docs/finetrainers-src-codebase/finetrainers/processors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8989fd70f359268620a16d1cca885983eed02d --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/base.py @@ -0,0 +1,24 @@ +import inspect +from typing import Any, Dict, List + + +class ProcessorMixin: + def __init__(self) -> None: + self._forward_parameter_names = inspect.signature(self.forward).parameters.keys() + self.output_names: List[str] = None + self.input_names: Dict[str, Any] = None + + def __call__(self, *args, **kwargs) -> Any: + shallow_copy_kwargs = dict(kwargs.items()) + if self.input_names is not None: + for k, v in self.input_names.items(): + if k in shallow_copy_kwargs: + shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k) + acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names} + output = self.forward(*args, **acceptable_kwargs) + if "__drop__" in output: + output.pop("__drop__") + return output + + def forward(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.") diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/canny.py b/docs/finetrainers-src-codebase/finetrainers/processors/canny.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf95e8c753e7fb539e8b0fde15788445fafd8d --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/canny.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch + +from ..utils.import_utils import is_kornia_available +from .base import ProcessorMixin + + +if is_kornia_available(): + import kornia + + +class CannyProcessor(ProcessorMixin): + r""" + Processor for obtaining the Canny edge detection of an image. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the Canny edge detection of + the input image. + """ + + def __init__( + self, + output_names: List[str] = None, + input_names: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + + self.output_names = output_names + self.input_names = input_names + self.device = device + assert len(output_names) == 1 + + def forward(self, input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]) -> torch.Tensor: + r""" + Obtain the Canny edge detection of the input image. + + Args: + input (`torch.Tensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`): + The input tensor, image or list of images for which the Canny edge detection should be obtained. + If a tensor, must be a 3D (CHW) or 4D (BCHW) or 5D (BTCHW) tensor. The input tensor should have + values in the range [0, 1]. + + Returns: + torch.Tensor: + The Canny edge detection of the input image. The output has the same shape as the input tensor. If + the input is an image, the output is a 3D tensor. If the input is a list of images, the output is a 5D + tensor. The output tensor has values in the range [0, 1]. + """ + if isinstance(input, PIL.Image.Image): + input = kornia.utils.image.image_to_tensor(np.array(input)).unsqueeze(0) / 255.0 + input = input.to(self.device) + output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1).squeeze(0) + elif isinstance(input, list): + input = kornia.utils.image.image_list_to_tensor([np.array(img) for img in input]) / 255.0 + output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) + else: + ndim = input.ndim + assert ndim in [3, 4, 5] + + batch_size = 1 if ndim == 3 else input.size(0) + + if ndim == 3: + input = input.unsqueeze(0) # [C, H, W] -> [1, C, H, W] + elif ndim == 5: + input = input.flatten(0, 1) # [B, F, C, H, W] -> [B*F, C, H, W] + + output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1) + output = output[0] if ndim == 3 else output.unflatten(0, (batch_size, -1)) if ndim == 5 else output + + # TODO(aryan): think about how one can pass parameters to the underlying function from + # a UI perspective. It's important to think about ProcessorMixin in terms of a Graph-based + # data processing pipeline. + return {self.output_names[0]: output} diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/clip.py b/docs/finetrainers-src-codebase/finetrainers/processors/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..e58282b69fb4845f079b103122a732acf7348d14 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/clip.py @@ -0,0 +1,63 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast + +from .base import ProcessorMixin + + +class CLIPPooledProcessor(ProcessorMixin): + r""" + Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None: + super().__init__() + + self.output_names = output_names + self.input_names = input_names + + assert len(output_names) == 1 + + def forward( + self, + tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], + text_encoder: CLIPTextModel, + caption: Union[str, List[str]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`LlamaModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + """ + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + text_inputs = tokenizer( + caption, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return {self.output_names[0]: prompt_embeds} diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/glm.py b/docs/finetrainers-src-codebase/finetrainers/processors/glm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf742130bb7da8808710ec562c85d9c64a535cb6 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/glm.py @@ -0,0 +1,74 @@ +from typing import List, Tuple, Union + +import torch +from transformers import AutoTokenizer, GlmModel + +from .base import ProcessorMixin + + +class CogView4GLMProcessor(ProcessorMixin): + r""" + Processor for the GLM family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + This processor is specific to CogView4 but can be used with any other model. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + + self.output_names = output_names + + assert len(self.output_names) == 1 + + def forward( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + caption: Union[str, List[str]], + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`AutoTokenizer`): + The tokenizer used to tokenize the input text. + text_encoder (`GlmModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + """ + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + text_inputs = tokenizer( + caption, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + current_length = text_input_ids.size(1) + pad_length = 16 - current_length % 16 + if pad_length > 0: + pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return {self.output_names[0]: prompt_embeds} diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/llama.py b/docs/finetrainers-src-codebase/finetrainers/processors/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..749e5f313541b92317279669faf915edeb9129c4 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/llama.py @@ -0,0 +1,118 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast + +from .base import ProcessorMixin + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +class LlamaProcessor(ProcessorMixin): + r""" + Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str] = None): + super().__init__() + + self.output_names = output_names + + assert len(output_names) == 2 + + def forward( + self, + tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast], + text_encoder: LlamaModel, + caption: Union[str, List[str]], + max_sequence_length: int, + prompt_template: Optional[Dict[str, Any]] = None, + num_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`LlamaModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + prompt_template (`Optional[Dict[str, Any]]`): + The prompt template to be used to encode the input text. + """ + if prompt_template is None: + prompt_template = DEFAULT_PROMPT_TEMPLATE + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + batch_size = len(caption) + caption = [prompt_template["template"].format(c) for c in caption] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = tokenizer( + caption, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + prompt_embeds = text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-(num_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return { + self.output_names[0]: prompt_embeds, + self.output_names[1]: prompt_attention_mask, + } diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/t5.py b/docs/finetrainers-src-codebase/finetrainers/processors/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..006aed2c18376c3ff1509bd6fadd57e48ea39350 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/t5.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast + +from .base import ProcessorMixin + + +class T5Processor(ProcessorMixin): + r""" + Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__( + self, + output_names: List[str], + input_names: Optional[Dict[str, Any]] = None, + *, + use_attention_mask: bool = False, + ): + super().__init__() + + self.output_names = output_names + self.input_names = input_names + self.use_attention_mask = use_attention_mask + + if input_names is not None: + assert len(input_names) <= 4 + assert len(self.output_names) == 2 + + def forward( + self, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + text_encoder: T5EncoderModel, + caption: Union[str, List[str]], + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`T5EncoderModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + """ + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + batch_size = len(caption) + text_inputs = tokenizer( + caption, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + te_mask = None + if self.use_attention_mask: + te_mask = prompt_attention_mask + + prompt_embeds = text_encoder(text_input_ids.to(device), te_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return { + self.output_names[0]: prompt_embeds, + self.output_names[1]: prompt_attention_mask, + } diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/text.py b/docs/finetrainers-src-codebase/finetrainers/processors/text.py new file mode 100644 index 0000000000000000000000000000000000000000..884284725004e285a817424ef4561b3aefeb466a --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/processors/text.py @@ -0,0 +1,23 @@ +from typing import List, Union + +import torch + +import finetrainers.functional as FF + +from .base import ProcessorMixin + + +class CaptionTextDropoutProcessor(ProcessorMixin): + def __init__(self, dropout_p: float = 0.0) -> None: + self.dropout_p = dropout_p + + def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]: + return FF.dropout_caption(caption, self.dropout_p) + + +class CaptionEmbeddingDropoutProcessor(ProcessorMixin): + def __init__(self, dropout_p: float = 0.0) -> None: + self.dropout_p = dropout_p + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + return FF.dropout_embeddings_to_zero(embedding, self.dropout_p) diff --git a/docs/finetrainers-src-codebase/finetrainers/state.py b/docs/finetrainers-src-codebase/finetrainers/state.py new file mode 100644 index 0000000000000000000000000000000000000000..0a44b6d6df74139b5ee405cc90288ec58abda3bd --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/state.py @@ -0,0 +1,66 @@ +import io +from dataclasses import dataclass, field +from typing import Any, Dict, List + +import torch +import torch.distributed.checkpoint.stateful + +from .parallel import ParallelBackendType +from .utils import get_device_info + + +_device_type, _ = get_device_info() + + +@dataclass +class TrainState(torch.distributed.checkpoint.stateful.Stateful): + step: int = 0 + observed_data_samples: int = 0 + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. + global_avg_losses_bytes = io.BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = io.BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = io.BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32), + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + state_dict["global_avg_losses"].seek(0) + state_dict["global_max_losses"].seek(0) + state_dict["log_steps"].seek(0) + + self.step = state_dict["step"].item() + self.observed_data_samples = state_dict["observed_data_samples"].item() + self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False) + self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) + + +@dataclass +class State: + # Parallel state + parallel_backend: ParallelBackendType = None + + # Training state + train_state: TrainState = None + num_trainable_parameters: int = 0 + generator: torch.Generator = None + + # Hub state + repo_id: str = None + + # Artifacts state + output_dir: str = None diff --git a/docs/finetrainers-src-codebase/finetrainers/trackers.py b/docs/finetrainers-src-codebase/finetrainers/trackers.py new file mode 100644 index 0000000000000000000000000000000000000000..68a53c5adc5934b8a1a802a1e48d2e5c5323b240 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trackers.py @@ -0,0 +1,145 @@ +import contextlib +import copy +import pathlib +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from .logging import get_logger +from .utils import Timer, TimerDevice + + +logger = get_logger() + + +class BaseTracker: + r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging.""" + + def __init__(self): + self._timed_metrics = {} + + @contextlib.contextmanager + def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False): + r"""Context manager to track time for a specific operation.""" + timer = Timer(name, device, device_sync) + timer.start() + yield timer + timer.end() + elapsed_time = timer.elapsed_time + if name in self._timed_metrics: + # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet + self._timed_metrics[name] += elapsed_time + else: + self._timed_metrics[name] = elapsed_time + + def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + def finish(self) -> None: + pass + + +class DummyTracker(BaseTracker): + def __init__(self): + super().__init__() + + def log(self, *args, **kwargs): + pass + + def finish(self) -> None: + pass + + +class WandbTracker(BaseTracker): + r"""Logger implementation for Weights & Biases.""" + + def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None: + super().__init__() + + import wandb + + self.wandb = wandb + + # WandB does not create a directory if it does not exist and instead starts using the system temp directory. + pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) + + self.run = wandb.init(project=experiment_name, dir=log_dir, config=config) + logger.info("WandB logging enabled") + + def log(self, metrics: Dict[str, Any], step: int) -> None: + metrics = {**self._timed_metrics, **metrics} + self.run.log(metrics, step=step) + self._timed_metrics = {} + + def finish(self) -> None: + self.run.finish() + + +class SequentialTracker(BaseTracker): + r"""Sequential tracker that logs to multiple trackers in sequence.""" + + def __init__(self, trackers: List[BaseTracker]) -> None: + super().__init__() + self.trackers = trackers + + @contextlib.contextmanager + def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False): + r"""Context manager to track time for a specific operation.""" + timer = Timer(name, device, device_sync) + timer.start() + yield timer + timer.end() + elapsed_time = timer.elapsed_time + if name in self._timed_metrics: + # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet + self._timed_metrics[name] += elapsed_time + else: + self._timed_metrics[name] = elapsed_time + for tracker in self.trackers: + tracker._timed_metrics = copy.deepcopy(self._timed_metrics) + + def log(self, metrics: Dict[str, Any], step: int) -> None: + for tracker in self.trackers: + tracker.log(metrics, step) + self._timed_metrics = {} + + def finish(self) -> None: + for tracker in self.trackers: + tracker.finish() + + +class Trackers(str, Enum): + r"""Enum for supported trackers.""" + + NONE = "none" + WANDB = "wandb" + + +_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()] + + +def initialize_trackers( + trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str +) -> Union[BaseTracker, SequentialTracker]: + r"""Initialize loggers based on the provided configuration.""" + + logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}") + + if len(trackers) == 0: + return BaseTracker() + + if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)): + raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}") + + tracker_instances = [] + for tracker_name in set(trackers): + if tracker_name == Trackers.NONE: + tracker = BaseTracker() + elif tracker_name == Trackers.WANDB: + tracker = WandbTracker(experiment_name, log_dir, config) + tracker_instances.append(tracker) + + tracker = SequentialTracker(tracker_instances) + return tracker + + +TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker] diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30a243509e26f71cc7deab8cd3aa03f6fa779e98 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py @@ -0,0 +1,2 @@ +from .control_trainer import ControlTrainer +from .sft_trainer import SFTTrainer diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/base.py b/docs/finetrainers-src-codebase/finetrainers/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..445fc89ee43283992b2d2f4263fd12ef9c0d2d46 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/base.py @@ -0,0 +1,188 @@ +import contextlib +import functools +import os +from typing import Callable, List, Tuple + +import torch +import torch.backends +from diffusers.hooks import HookRegistry, ModelHook + +from finetrainers import logging, parallel, patches +from finetrainers.args import BaseArgsType +from finetrainers.logging import get_logger +from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry +from finetrainers.state import State + + +logger = get_logger() + +_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook" + + +class Trainer: + def __init__(self, args: BaseArgsType): + self.args = args + + self.state = State() + + self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training) + self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference) + + self._init_distributed() + self._init_config_options() + + # Perform any patches that might be necessary for training to work as expected + patches.perform_patches_for_training(self.args, self.state.parallel_backend) + + @contextlib.contextmanager + def attention_provider_ctx(self, training: bool = True): + name_providers_active = ( + self._module_name_providers_training if training else self._module_name_providers_inference + ) + name_providers_dict = dict(name_providers_active) + default_provider = _AttentionProviderRegistry._active_provider + + all_registered_module_names = [ + attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module) + ] + for module_name in all_registered_module_names: + if module_name in name_providers_dict: + continue + name_providers_dict[module_name] = default_provider + + module_providers_dict = {} + for module_name, provider in name_providers_dict.items(): + module = getattr(self, module_name, None) + if module is not None: + module_providers_dict[module] = (module_name, provider) + + # We don't want to immediately unset the attention provider to default after forward because if the + # model is being trained, the backward pass must be invoked with the same attention provider + # So, we lazily switch attention providers only when the forward pass of a new module is called + def callback(m: torch.nn.Module): + module_name, provider = module_providers_dict[m] + # HACK: for CP on transformer. Need to support other modules too and improve overall experience for external usage + if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled: + if not _AttentionProviderRegistry.supports_context_parallel(provider): + raise ValueError( + f"Attention provider {provider} does not support context parallel. Please use a different provider." + ) + _AttentionProviderRegistry._set_context_parallel( + mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather" + ) + _AttentionProviderRegistry._active_provider = provider + + # HACK: for VAE + if "vae" in name_providers_dict: + _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"]) + + for module in module_providers_dict.keys(): + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = LatestActiveModuleHook(callback) + registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK) + + yield + + _AttentionProviderRegistry._active_provider = default_provider + _AttentionProviderRegistry._set_context_parallel(reset=True) + for module in module_providers_dict.keys(): + registry: HookRegistry = module._diffusers_hook + registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK) + + def _init_distributed(self) -> None: + world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) + + # TODO(aryan): handle other backends + backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend) + self.state.parallel_backend = backend_cls( + world_size=world_size, + pp_degree=self.args.pp_degree, + dp_degree=self.args.dp_degree, + dp_shards=self.args.dp_shards, + cp_degree=self.args.cp_degree, + tp_degree=self.args.tp_degree, + backend="nccl", + timeout=self.args.init_timeout, + logging_dir=self.args.logging_dir, + output_dir=self.args.output_dir, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + ) + + if self.args.seed is not None: + self.state.parallel_backend.enable_determinism(self.args.seed) + + def _init_logging(self) -> None: + logging._set_parallel_backend(self.state.parallel_backend) + logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process) + logger.info("Initialized FineTrainers") + + def _init_trackers(self) -> None: + # TODO(aryan): handle multiple trackers + trackers = [self.args.report_to] + experiment_name = self.args.tracker_name or "finetrainers-experiment" + self.state.parallel_backend.initialize_trackers( + trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir + ) + + def _init_config_options(self) -> None: + # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision(self.args.float32_matmul_precision) + + @property + def tracker(self): + return self.state.parallel_backend.tracker + + +class LatestActiveModuleHook(ModelHook): + def __init__(self, callback: Callable[[torch.nn.Module], None] = None): + super().__init__() + self.callback = callback + + def pre_forward(self, module, *args, **kwargs): + self.callback(module) + return args, kwargs + + +def _parse_attention_providers(attn_providers: List[str] = None) -> List[Tuple[str, AttentionProvider]]: + parsed_providers = [] + if attn_providers: + for provider_str in attn_providers: + parts = provider_str.split(":") + if len(parts) != 2: + raise ValueError( + f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'." + ) + parts[1] = AttentionProvider(parts[1]) + parsed_providers.append(tuple(parts)) + return parsed_providers + + +# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked +def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider): + if hasattr(module, "_finetrainers_wrapped_methods"): + return + + def create_wrapper(old_method): + @functools.wraps(old_method) + def wrapper(*args, **kwargs): + _AttentionProviderRegistry._set_context_parallel(reset=True) # HACK: needs improvement + old_provider = _AttentionProviderRegistry._active_provider + _AttentionProviderRegistry._active_provider = provider + output = old_method(*args, **kwargs) + _AttentionProviderRegistry._active_provider = old_provider + return output + + return wrapper + + methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"] + finetrainers_wrapped_methods = [] + for method_name in methods: + if not hasattr(module, method_name): + continue + method = getattr(module, method_name) + wrapper = create_wrapper(method) + setattr(module, method_name, wrapper) + finetrainers_wrapped_methods.append(method_name) + module._finetrainers_wrapped_methods = finetrainers_wrapped_methods diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b72fc82a2c73cfbbd8e95aaca9a1f127d15774 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py @@ -0,0 +1,2 @@ +from .config import ControlFullRankConfig, ControlLowRankConfig +from .trainer import ControlTrainer diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..14cfe715749fd74a622abd5afe5131efc0db130f --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py @@ -0,0 +1,185 @@ +import argparse +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from finetrainers.utils import ArgsConfigMixin + + +if TYPE_CHECKING: + from finetrainers.args import BaseArgs + + +class ControlType(str, Enum): + r""" + Enum class for the control types. + """ + + CANNY = "canny" + CUSTOM = "custom" + NONE = "none" + + +class FrameConditioningType(str, Enum): + r""" + Enum class for the frame conditioning types. + """ + + INDEX = "index" + PREFIX = "prefix" + RANDOM = "random" + FIRST_AND_LAST = "first_and_last" + FULL = "full" + + +class ControlLowRankConfig(ArgsConfigMixin): + r""" + Configuration class for SFT channel-concatenated Control low rank training. + + Args: + control_type (`str`, defaults to `"canny"`): + Control type for the low rank approximation matrices. Can be "canny", "custom". + rank (int, defaults to `64`): + Rank of the low rank approximation matrix. + lora_alpha (int, defaults to `64`): + The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. + target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`): + Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings. + train_qk_norm (`bool`, defaults to `False`): + Whether to train the QK normalization layers. + frame_conditioning_type (`str`, defaults to `"full"`): + Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full". + frame_conditioning_index (int, defaults to `0`): + Index of the frame conditioning. Only used if `frame_conditioning_type` is "index". + frame_conditioning_concatenate_mask (`bool`, defaults to `False`): + Whether to concatenate the frame mask with the latents across channel dim. + """ + + control_type: str = ControlType.CANNY + rank: int = 64 + lora_alpha: int = 64 + target_modules: Union[str, List[str]] = ( + "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" + ) + train_qk_norm: bool = False + + # Specific to video models + frame_conditioning_type: str = FrameConditioningType.FULL + frame_conditioning_index: int = 0 + frame_conditioning_concatenate_mask: bool = False + + def add_args(self, parser: argparse.ArgumentParser): + parser.add_argument( + "--control_type", + type=str, + default=ControlType.CANNY.value, + choices=[x.value for x in ControlType.__members__.values()], + ) + parser.add_argument("--rank", type=int, default=64) + parser.add_argument("--lora_alpha", type=int, default=64) + parser.add_argument( + "--target_modules", + type=str, + nargs="+", + default=[ + "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" + ], + ) + parser.add_argument("--train_qk_norm", action="store_true") + parser.add_argument( + "--frame_conditioning_type", + type=str, + default=FrameConditioningType.INDEX.value, + choices=[x.value for x in FrameConditioningType.__members__.values()], + ) + parser.add_argument("--frame_conditioning_index", type=int, default=0) + parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true") + + def validate_args(self, args: "BaseArgs"): + assert self.rank > 0, "Rank must be a positive integer." + assert self.lora_alpha > 0, "lora_alpha must be a positive integer." + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.control_type = argparse_args.control_type + mapped_args.rank = argparse_args.rank + mapped_args.lora_alpha = argparse_args.lora_alpha + mapped_args.target_modules = ( + argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules + ) + mapped_args.train_qk_norm = argparse_args.train_qk_norm + mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type + mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index + mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask + + def to_dict(self) -> Dict[str, Any]: + return { + "control_type": self.control_type, + "rank": self.rank, + "lora_alpha": self.lora_alpha, + "target_modules": self.target_modules, + "train_qk_norm": self.train_qk_norm, + "frame_conditioning_type": self.frame_conditioning_type, + "frame_conditioning_index": self.frame_conditioning_index, + "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask, + } + + +class ControlFullRankConfig(ArgsConfigMixin): + r""" + Configuration class for SFT channel-concatenated Control full rank training. + + Args: + control_type (`str`, defaults to `"canny"`): + Control type for the low rank approximation matrices. Can be "canny", "custom". + train_qk_norm (`bool`, defaults to `False`): + Whether to train the QK normalization layers. + frame_conditioning_type (`str`, defaults to `"index"`): + Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full". + frame_conditioning_index (int, defaults to `0`): + Index of the frame conditioning. Only used if `frame_conditioning_type` is "index". + frame_conditioning_concatenate_mask (`bool`, defaults to `False`): + Whether to concatenate the frame mask with the latents across channel dim. + """ + + control_type: str = ControlType.CANNY + train_qk_norm: bool = False + + # Specific to video models + frame_conditioning_type: str = FrameConditioningType.INDEX + frame_conditioning_index: int = 0 + frame_conditioning_concatenate_mask: bool = False + + def add_args(self, parser: argparse.ArgumentParser): + parser.add_argument( + "--control_type", + type=str, + default=ControlType.CANNY.value, + choices=[x.value for x in ControlType.__members__.values()], + ) + parser.add_argument("--train_qk_norm", action="store_true") + parser.add_argument( + "--frame_conditioning_type", + type=str, + default=FrameConditioningType.INDEX.value, + choices=[x.value for x in FrameConditioningType.__members__.values()], + ) + parser.add_argument("--frame_conditioning_index", type=int, default=0) + parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true") + + def validate_args(self, args: "BaseArgs"): + pass + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.control_type = argparse_args.control_type + mapped_args.train_qk_norm = argparse_args.train_qk_norm + mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type + mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index + mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask + + def to_dict(self) -> Dict[str, Any]: + return { + "control_type": self.control_type, + "train_qk_norm": self.train_qk_norm, + "frame_conditioning_type": self.frame_conditioning_type, + "frame_conditioning_index": self.frame_conditioning_index, + "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask, + } diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py new file mode 100644 index 0000000000000000000000000000000000000000..5c91fec06a62e22148415860c48c20e6ae2605d8 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py @@ -0,0 +1,268 @@ +import random +from typing import Any, Dict, Optional + +import torch +import torch.distributed.checkpoint.stateful +from diffusers.video_processor import VideoProcessor + +import finetrainers.functional as FF +from finetrainers.logging import get_logger +from finetrainers.processors import CannyProcessor, CopyProcessor + +from .config import ControlType, FrameConditioningType + + +logger = get_logger() + + +class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__( + self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None + ): + super().__init__() + + self.dataset = dataset + self.control_type = control_type + + self.control_processors = [] + if control_type == ControlType.CANNY: + self.control_processors.append( + CannyProcessor( + output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device + ) + ) + elif control_type == ControlType.NONE: + self.control_processors.append( + CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"}) + ) + + logger.info("Initialized IterableControlDataset") + + def __iter__(self): + logger.info("Starting IterableControlDataset") + for data in iter(self.dataset): + control_augmented_data = self._run_control_processors(data) + yield control_augmented_data + + def load_state_dict(self, state_dict): + self.dataset.load_state_dict(state_dict) + + def state_dict(self): + return self.dataset.state_dict() + + def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]: + if "control_image" in data: + if "image" in data: + data["control_image"] = FF.resize_to_nearest_bucket_image( + data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic" + ) + if "video" in data: + batch_size, num_frames, num_channels, height, width = data["video"].shape + data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( + data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" + ) + if _first_frame_only: + msg = ( + "The number of frames in the control video is less than the minimum bucket size " + "specified. The first frame is being used as a single frame video. This " + "message is logged at the first occurence and for every 128th occurence " + "after that." + ) + logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) + data["control_video"] = data["control_video"][0] + return data + + if "control_video" in data: + if "image" in data: + data["control_image"] = FF.resize_to_nearest_bucket_image( + data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic" + ) + if "video" in data: + batch_size, num_frames, num_channels, height, width = data["video"].shape + data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( + data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" + ) + if _first_frame_only: + msg = ( + "The number of frames in the control video is less than the minimum bucket size " + "specified. The first frame is being used as a single frame video. This " + "message is logged at the first occurence and for every 128th occurence " + "after that." + ) + logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) + data["control_video"] = data["control_video"][0] + return data + + if self.control_type == ControlType.CUSTOM: + return data + + shallow_copy_data = dict(data.items()) + is_image_control = "image" in shallow_copy_data + is_video_control = "video" in shallow_copy_data + if (is_image_control + is_video_control) != 1: + raise ValueError("Exactly one of 'image' or 'video' should be present in the data.") + for processor in self.control_processors: + result = processor(**shallow_copy_data) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(shallow_copy_data.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + shallow_copy_data.update(result) + if "control_output" in shallow_copy_data: + # Normalize to [-1, 1] range + control_output = shallow_copy_data.pop("control_output") + # TODO(aryan): need to specify a dim for normalize here across channels + control_output = FF.normalize(control_output, min=-1.0, max=1.0) + key = "control_image" if is_image_control else "control_video" + shallow_copy_data[key] = control_output + return shallow_copy_data + + +class ValidationControlDataset(torch.utils.data.IterableDataset): + def __init__( + self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None + ): + super().__init__() + + self.dataset = dataset + self.control_type = control_type + self.device = device + self._video_processor = VideoProcessor() + + self.control_processors = [] + if control_type == ControlType.CANNY: + self.control_processors.append( + CannyProcessor(["control_output"], input_names={"image": "input", "video": "input"}, device=device) + ) + elif control_type == ControlType.NONE: + self.control_processors.append( + CopyProcessor(["control_output"], input_names={"image": "input", "video": "input"}) + ) + + logger.info("Initialized ValidationControlDataset") + + def __iter__(self): + logger.info("Starting ValidationControlDataset") + for data in iter(self.dataset): + control_augmented_data = self._run_control_processors(data) + yield control_augmented_data + + def load_state_dict(self, state_dict): + self.dataset.load_state_dict(state_dict) + + def state_dict(self): + return self.dataset.state_dict() + + def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.control_type == ControlType.CUSTOM: + return data + # These are already expected to be tensors + if "control_image" in data or "control_video" in data: + return data + shallow_copy_data = dict(data.items()) + is_image_control = "image" in shallow_copy_data + is_video_control = "video" in shallow_copy_data + if (is_image_control + is_video_control) != 1: + raise ValueError("Exactly one of 'image' or 'video' should be present in the data.") + for processor in self.control_processors: + result = processor(**shallow_copy_data) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(shallow_copy_data.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + shallow_copy_data.update(result) + if "control_output" in shallow_copy_data: + # Normalize to [-1, 1] range + control_output = shallow_copy_data.pop("control_output") + if torch.is_tensor(control_output): + # TODO(aryan): need to specify a dim for normalize here across channels + control_output = FF.normalize(control_output, min=-1.0, max=1.0) + ndim = control_output.ndim + assert 3 <= ndim <= 5, "Control output should be at least ndim=3 and less than or equal to ndim=5" + if ndim == 5: + control_output = self._video_processor.postprocess_video(control_output, output_type="pil") + else: + if ndim == 3: + control_output = control_output.unsqueeze(0) + control_output = self._video_processor.postprocess(control_output, output_type="pil")[0] + key = "control_image" if is_image_control else "control_video" + shallow_copy_data[key] = control_output + return shallow_copy_data + + +# TODO(aryan): write a test for this function +def apply_frame_conditioning_on_latents( + latents: torch.Tensor, + expected_num_frames: int, + channel_dim: int, + frame_dim: int, + frame_conditioning_type: FrameConditioningType, + frame_conditioning_index: Optional[int] = None, + concatenate_mask: bool = False, +) -> torch.Tensor: + num_frames = latents.size(frame_dim) + mask = torch.zeros_like(latents) + + if frame_conditioning_type == FrameConditioningType.INDEX: + frame_index = min(frame_conditioning_index, num_frames - 1) + indexing = [slice(None)] * latents.ndim + indexing[frame_dim] = frame_index + mask[tuple(indexing)] = 1 + latents = latents * mask + + elif frame_conditioning_type == FrameConditioningType.PREFIX: + frame_index = random.randint(1, num_frames) + indexing = [slice(None)] * latents.ndim + indexing[frame_dim] = slice(0, frame_index) # Keep frames 0 to frame_index-1 + mask[tuple(indexing)] = 1 + latents = latents * mask + + elif frame_conditioning_type == FrameConditioningType.RANDOM: + # Zero or more random frames to keep + num_frames_to_keep = random.randint(1, num_frames) + frame_indices = random.sample(range(num_frames), num_frames_to_keep) + indexing = [slice(None)] * latents.ndim + indexing[frame_dim] = frame_indices + mask[tuple(indexing)] = 1 + latents = latents * mask + + elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST: + indexing = [slice(None)] * latents.ndim + indexing[frame_dim] = 0 + mask[tuple(indexing)] = 1 + indexing[frame_dim] = num_frames - 1 + mask[tuple(indexing)] = 1 + latents = latents * mask + + elif frame_conditioning_type == FrameConditioningType.FULL: + indexing = [slice(None)] * latents.ndim + indexing[frame_dim] = slice(0, num_frames) + mask[tuple(indexing)] = 1 + + if latents.size(frame_dim) >= expected_num_frames: + slicing = [slice(None)] * latents.ndim + slicing[frame_dim] = slice(expected_num_frames) + latents = latents[tuple(slicing)] + mask = mask[tuple(slicing)] + else: + pad_size = expected_num_frames - num_frames + pad_shape = list(latents.shape) + pad_shape[frame_dim] = pad_size + padding = latents.new_zeros(pad_shape) + latents = torch.cat([latents, padding], dim=frame_dim) + mask = torch.cat([mask, padding], dim=frame_dim) + + if concatenate_mask: + slicing = [slice(None)] * latents.ndim + slicing[channel_dim] = 0 + latents = torch.cat([latents, mask], dim=channel_dim) + + return latents diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..576e17a0c0298f48b3a413ebc144586e4ce9e590 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py @@ -0,0 +1,1021 @@ +import functools +import json +import os +import re +import time +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Union + +import datasets.distributed +import safetensors.torch +import torch +import wandb +from diffusers import DiffusionPipeline +from diffusers.hooks import apply_layerwise_casting +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict +from tqdm import tqdm + +from finetrainers import data, logging, models, optimizer, parallel, utils +from finetrainers.args import BaseArgsType +from finetrainers.config import TrainingType +from finetrainers.patches import load_lora_weights +from finetrainers.state import TrainState + +from ..base import Trainer +from .config import ControlFullRankConfig, ControlLowRankConfig +from .data import IterableControlDataset, ValidationControlDataset + + +ArgsType = Union[BaseArgsType, ControlFullRankConfig, ControlLowRankConfig] + +logger = logging.get_logger() + + +class ControlTrainer(Trainer): + def __init__(self, args: ArgsType, model_specification: models.ControlModelSpecification) -> None: + super().__init__(args) + + # Tokenizers + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + + # Text encoders + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + + # Denoisers + self.transformer = None + self.unet = None + + # Autoencoders + self.vae = None + + # Scheduler + self.scheduler = None + + # Optimizer & LR scheduler + self.optimizer = None + self.lr_scheduler = None + + # Checkpoint manager + self.checkpointer = None + + self.model_specification = model_specification + self._are_condition_models_loaded = False + + model_specification._trainer_init( + args.frame_conditioning_type, args.frame_conditioning_index, args.frame_conditioning_concatenate_mask + ) + + def run(self) -> None: + try: + self._prepare_models() + self._prepare_trainable_parameters() + self._prepare_for_training() + self._prepare_dataset() + self._prepare_checkpointing() + self._train() + # trainer._evaluate() + except Exception as e: + logger.error(f"Error during training: {e}") + self.state.parallel_backend.destroy() + raise e + + def _prepare_models(self) -> None: + logger.info("Initializing models") + + # TODO(aryan): allow multiple control conditions instead of just one if there's a use case for it + new_in_features = self.model_specification._original_control_layer_in_features * 2 + diffusion_components = self.model_specification.load_diffusion_models(new_in_features) + self._set_components(diffusion_components) + + if self.state.parallel_backend.pipeline_parallel_enabled: + raise NotImplementedError( + "Pipeline parallelism is not supported yet. This will be supported in the future." + ) + + def _prepare_trainable_parameters(self) -> None: + logger.info("Initializing trainable parameters") + + parallel_backend = self.state.parallel_backend + model_spec = self.model_specification + + if self.args.training_type == TrainingType.CONTROL_FULL_FINETUNE: + logger.info("Finetuning transformer with no additional parameters") + utils.set_requires_grad([self.transformer], True) + else: + logger.info("Finetuning transformer with PEFT parameters") + utils.set_requires_grad([self.transformer], False) + + # Layerwise upcasting must be applied before adding the LoRA adapter. + # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on + # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. + if ( + self.args.training_type == TrainingType.CONTROL_LORA + and "transformer" in self.args.layerwise_upcasting_modules + ): + apply_layerwise_casting( + self.transformer, + storage_dtype=self.args.layerwise_upcasting_storage_dtype, + compute_dtype=self.args.transformer_dtype, + skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, + non_blocking=True, + ) + + transformer_lora_config = None + if self.args.training_type == TrainingType.CONTROL_LORA: + transformer_lora_config = LoraConfig( + r=self.args.rank, + lora_alpha=self.args.lora_alpha, + init_lora_weights=True, + target_modules=self._get_lora_target_modules(), + rank_pattern={ + model_spec.control_injection_layer_name: model_spec._original_control_layer_out_features + }, + alpha_pattern={ + model_spec.control_injection_layer_name: model_spec._original_control_layer_out_features + }, + ) + self.transformer.add_adapter(transformer_lora_config) + + if self.args.train_qk_norm: + qk_norm_identifiers = model_spec._qk_norm_identifiers + qk_norm_module_names, qk_norm_modules = [], [] + + for name, module in self.transformer.named_modules(): + regex_match = any(re.search(identifier, name) is not None for identifier in qk_norm_identifiers) + is_parameteric = len(list(module.parameters())) > 0 + if regex_match and is_parameteric: + qk_norm_module_names.append(name) + qk_norm_modules.append(module) + + if len(qk_norm_modules) > 0: + logger.info(f"Training QK norms for modules: {qk_norm_module_names}") + utils.set_requires_grad(qk_norm_modules, True) + else: + logger.warning(f"No QK norm modules found with identifiers: {qk_norm_identifiers}") + + # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all + # parameters to be of the same dtype. + if parallel_backend.data_sharding_enabled: + self.transformer.to(dtype=self.args.transformer_dtype) + else: + if self.args.training_type == TrainingType.CONTROL_LORA: + cast_training_params([self.transformer], dtype=torch.float32) + + def _prepare_for_training(self) -> None: + # 1. Apply parallelism + parallel_backend = self.state.parallel_backend + model_specification = self.model_specification + + if parallel_backend.context_parallel_enabled: + parallel_backend.apply_context_parallel(self.transformer, parallel_backend.get_mesh()["cp"]) + + if parallel_backend.tensor_parallel_enabled: + # TODO(aryan): handle fp8 from TorchAO here + model_specification.apply_tensor_parallel( + backend=parallel.ParallelBackendEnum.PTD, + device_mesh=parallel_backend.get_mesh()["tp"], + transformer=self.transformer, + ) + + # Enable gradient checkpointing + if self.args.gradient_checkpointing: + # TODO(aryan): support other checkpointing types + utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full") + + # Apply torch.compile + self._maybe_torch_compile() + + # Enable DDP, FSDP or HSDP + if parallel_backend.data_sharding_enabled: + # TODO(aryan): remove this when supported + if self.args.parallel_backend == "accelerate": + raise NotImplementedError("Data sharding is not supported with Accelerate yet.") + + dp_method = "HSDP" if parallel_backend.data_replication_enabled else "FSDP" + logger.info(f"Applying {dp_method} on the model") + + if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled: + dp_mesh_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_names = ("dp_shard_cp",) + + parallel_backend.apply_fsdp2( + model=self.transformer, + param_dtype=self.args.transformer_dtype, + reduce_dtype=torch.float32, + output_dtype=None, + pp_enabled=parallel_backend.pipeline_parallel_enabled, + cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later + device_mesh=parallel_backend.get_mesh()[dp_mesh_names], + ) + elif parallel_backend.data_replication_enabled: + if parallel_backend.get_mesh().ndim > 1: + raise ValueError("DDP not supported for > 1D parallelism") + parallel_backend.apply_ddp(self.transformer, parallel_backend.get_mesh()) + else: + parallel_backend.prepare_model(self.transformer) + + self._move_components_to_device() + + # 2. Prepare optimizer and lr scheduler + # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module. + # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer + # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99) + # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require + # gradients. TODO(aryan): look into it in the future. + model_parts = [self.transformer] + self.state.num_trainable_parameters = sum( + p.numel() for m in model_parts for p in m.parameters() if p.requires_grad + ) + + # Setup distributed optimizer and lr scheduler + logger.info("Initializing optimizer and lr scheduler") + self.state.train_state = TrainState() + self.optimizer = optimizer.get_optimizer( + parallel_backend=self.args.parallel_backend, + name=self.args.optimizer, + model_parts=model_parts, + learning_rate=self.args.lr, + beta1=self.args.beta1, + beta2=self.args.beta2, + beta3=self.args.beta3, + epsilon=self.args.epsilon, + weight_decay=self.args.weight_decay, + fused=False, + ) + self.lr_scheduler = optimizer.get_lr_scheduler( + parallel_backend=self.args.parallel_backend, + name=self.args.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=self.args.lr_warmup_steps, + num_training_steps=self.args.train_steps, + # TODO(aryan): handle last_epoch + ) + self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler) + + # 3. Initialize trackers, directories and repositories + self._init_logging() + self._init_trackers() + self._init_directories_and_repositories() + + def _prepare_dataset(self) -> None: + logger.info("Initializing dataset and dataloader") + + with open(self.args.dataset_config, "r") as file: + dataset_configs = json.load(file)["datasets"] + logger.info(f"Training configured to use {len(dataset_configs)} datasets") + + datasets = [] + for config in dataset_configs: + data_root = config.pop("data_root", None) + dataset_file = config.pop("dataset_file", None) + dataset_type = config.pop("dataset_type") + caption_options = config.pop("caption_options", {}) + + if data_root is not None and dataset_file is not None: + raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.") + + dataset_name_or_root = data_root or dataset_file + dataset = data.initialize_dataset( + dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options + ) + + if not dataset._precomputable_once and self.args.precomputation_once: + raise ValueError( + f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once." + ) + + logger.info(f"Initialized dataset: {dataset_name_or_root}") + dataset = self.state.parallel_backend.prepare_dataset(dataset) + dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config) + datasets.append(dataset) + + dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True) + dataset = IterableControlDataset(dataset, self.args.control_type, self.state.parallel_backend.device) + dataloader = self.state.parallel_backend.prepare_dataloader( + dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory + ) + + self.dataset = dataset + self.dataloader = dataloader + + def _prepare_checkpointing(self) -> None: + parallel_backend = self.state.parallel_backend + + def save_model_hook(state_dict: Dict[str, Any]) -> None: + state_dict = utils.get_unwrapped_model_state_dict(state_dict) + if parallel_backend.is_main_process: + if self.args.training_type == TrainingType.CONTROL_LORA: + state_dict = get_peft_model_state_dict(self.transformer, state_dict) + qk_norm_state_dict = None + if self.args.train_qk_norm: + qk_norm_state_dict = { + name: parameter + for name, parameter in state_dict.items() + if any( + re.search(identifier, name) is not None + for identifier in self.model_specification._qk_norm_identifiers + ) + and parameter.numel() > 0 + } + if len(qk_norm_state_dict) == 0: + qk_norm_state_dict = None + # fmt: off + metadata = { + "r": self.args.rank, + "lora_alpha": self.args.lora_alpha, + "init_lora_weights": True, + "target_modules": self._get_lora_target_modules(), + "rank_pattern": {self.model_specification.control_injection_layer_name: self.model_specification._original_control_layer_out_features}, + "alpha_pattern": {self.model_specification.control_injection_layer_name: self.model_specification._original_control_layer_out_features}, + } + metadata = {"lora_config": json.dumps(metadata, indent=4)} + # fmt: on + self.model_specification._save_lora_weights( + os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}"), + state_dict, + qk_norm_state_dict, + self.scheduler, + metadata, + ) + elif self.args.training_type == TrainingType.CONTROL_FULL_FINETUNE: + self.model_specification._save_model( + os.path.join(self.args.output_dir, "model_weights", f"{self.state.train_state.step:06d}"), + self.transformer, + state_dict, + self.scheduler, + ) + parallel_backend.wait_for_everyone() + + enable_state_checkpointing = self.args.checkpointing_steps > 0 + self.checkpointer = parallel_backend.get_checkpointer( + dataloader=self.dataloader, + model_parts=[self.transformer], + optimizers=self.optimizer, + schedulers=self.lr_scheduler, + states={"train_state": self.state.train_state}, + checkpointing_steps=self.args.checkpointing_steps, + checkpointing_limit=self.args.checkpointing_limit, + output_dir=self.args.output_dir, + enable=enable_state_checkpointing, + _callback_fn=save_model_hook, + ) + + resume_from_checkpoint = self.args.resume_from_checkpoint + if resume_from_checkpoint == "latest": + resume_from_checkpoint = -1 + if resume_from_checkpoint is not None: + self.checkpointer.load(resume_from_checkpoint) + + def _train(self) -> None: + logger.info("Starting training") + + parallel_backend = self.state.parallel_backend + train_state = self.state.train_state + device = parallel_backend.device + dtype = self.args.transformer_dtype + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") + + global_batch_size = self.args.batch_size * parallel_backend._dp_degree + info = { + "trainable parameters": self.state.num_trainable_parameters, + "train steps": self.args.train_steps, + "per-replica batch size": self.args.batch_size, + "global batch size": global_batch_size, + "gradient accumulation steps": self.args.gradient_accumulation_steps, + } + logger.info(f"Training configuration: {json.dumps(info, indent=4)}") + + progress_bar = tqdm( + range(0, self.args.train_steps), + initial=train_state.step, + desc="Training steps", + disable=not parallel_backend.is_local_main_process, + ) + + generator = torch.Generator(device=device) + if self.args.seed is not None: + generator = generator.manual_seed(self.args.seed) + self.state.generator = generator + + scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler) + scheduler_sigmas = ( + scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None + ) + scheduler_alphas = utils.get_scheduler_alphas(self.scheduler) + scheduler_alphas = ( + scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None + ) + # timesteps_buffer = [] + + self.transformer.train() + data_iterator = iter(self.dataloader) + + compute_posterior = False if self.args.enable_precomputation else (not self.args.precomputation_once) + preprocessor = data.initialize_preprocessor( + rank=parallel_backend.rank, + world_size=parallel_backend.world_size, + num_items=self.args.precomputation_items if self.args.enable_precomputation else 1, + processor_fn={ + "condition": self.model_specification.prepare_conditions, + "latent": functools.partial( + self.model_specification.prepare_latents, compute_posterior=compute_posterior + ), + }, + save_dir=self.args.precomputation_dir, + enable_precomputation=self.args.enable_precomputation, + enable_reuse=self.args.precomputation_reuse, + ) + condition_iterator: Iterable[Dict[str, Any]] = None + latent_iterator: Iterable[Dict[str, Any]] = None + sampler = data.ResolutionSampler( + batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys + ) + requires_gradient_step = True + accumulated_loss = 0.0 + + while ( + train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples + ): + # 1. Load & preprocess data if required + if preprocessor.requires_data: + condition_iterator, latent_iterator = self._prepare_data(preprocessor, data_iterator) + + # 2. Prepare batch + with self.tracker.timed("timing/batch_preparation"): + try: + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + sampler.consume(condition_item, latent_item) + except StopIteration: + if requires_gradient_step: + self.optimizer.step() + self.lr_scheduler.step() + requires_gradient_step = False + logger.info("Data exhausted. Exiting training loop.") + break + + if sampler.is_ready: + condition_batch, latent_batch = sampler.get_batch() + condition_model_conditions = self.model_specification.collate_conditions(condition_batch) + latent_model_conditions = self.model_specification.collate_latents(latent_batch) + else: + continue + + train_state.step += 1 + train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree + + logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})") + + latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype) + condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype) + latent_model_conditions = utils.make_contiguous(latent_model_conditions) + condition_model_conditions = utils.make_contiguous(condition_model_conditions) + + # 3. Forward pass + sigmas = utils.prepare_sigmas( + scheduler=self.scheduler, + sigmas=scheduler_sigmas, + batch_size=self.args.batch_size, + num_train_timesteps=self.scheduler.config.num_train_timesteps, + flow_weighting_scheme=self.args.flow_weighting_scheme, + flow_logit_mean=self.args.flow_logit_mean, + flow_logit_std=self.args.flow_logit_std, + flow_mode_scale=self.args.flow_mode_scale, + device=device, + generator=self.state.generator, + ) + sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim) + + # NOTE: for planned refactor, make sure that forward and backward pass run under the context. + # If only forward runs under context, backward will most likely fail when using activation checkpointing + with self.attention_provider_ctx(training=True): + with self.tracker.timed("timing/forward"): + pred, target, sigmas = self.model_specification.forward( + transformer=self.transformer, + scheduler=self.scheduler, + condition_model_conditions=condition_model_conditions, + latent_model_conditions=latent_model_conditions, + sigmas=sigmas, + compute_posterior=compute_posterior, + ) + + timesteps = (sigmas * 1000.0).long() + weights = utils.prepare_loss_weights( + scheduler=self.scheduler, + alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, + sigmas=sigmas, + flow_weighting_scheme=self.args.flow_weighting_scheme, + ) + weights = utils.expand_tensor_dims(weights, pred.ndim) + + # 4. Compute loss & backward pass + with self.tracker.timed("timing/backward"): + loss = weights.float() * (pred.float() - target.float()).pow(2) + # Average loss across all but batch dimension (for per-batch debugging in case needed) + loss = loss.mean(list(range(1, loss.ndim))) + # Average loss across batch dimension + loss = loss.mean() + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + loss.backward() + + accumulated_loss += loss.detach().item() + requires_gradient_step = True + + # 5. Clip gradients + model_parts = [self.transformer] + grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + self.args.max_grad_norm, + foreach=True, + pp_mesh=parallel_backend.get_mesh()["pp"] if parallel_backend.pipeline_parallel_enabled else None, + ) + + # 6. Step optimizer & log metrics + logs = {} + + if train_state.step % self.args.gradient_accumulation_steps == 0: + # TODO(aryan): revisit no_sync() for FSDP + with self.tracker.timed("timing/optimizer_step"): + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + if grad_norm is not None: + grad_norm = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item() + if ( + parallel_backend.data_replication_enabled + or parallel_backend.data_sharding_enabled + or parallel_backend.context_parallel_enabled + ): + dp_cp_mesh = parallel_backend.get_mesh()["dp_cp"] + if grad_norm is not None: + grad_norm = parallel.dist_mean(torch.tensor([grad_norm], device=device), dp_cp_mesh) + global_avg_loss, global_max_loss = ( + parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + ) + else: + global_avg_loss = global_max_loss = accumulated_loss + + logs["train/global_avg_loss"] = global_avg_loss + logs["train/global_max_loss"] = global_max_loss + if grad_norm is not None: + logs["train/grad_norm"] = grad_norm + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + accumulated_loss = 0.0 + requires_gradient_step = False + + progress_bar.update(1) + progress_bar.set_postfix(logs) + + # timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()]) + + if train_state.step % self.args.logging_steps == 0: + # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts + # TODO(aryan): causes NCCL hang for some reason. look into later + # logs.update(self.lr_scheduler.get_last_lr()) + + # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"]) + # logs["timesteps"] = wandb.plot.scatter( + # timesteps_table, "step", "timesteps", title="Timesteps distribution" + # ) + # timesteps_buffer = [] + + logs["train/observed_data_samples"] = train_state.observed_data_samples + + parallel_backend.log(logs, step=train_state.step) + train_state.log_steps.append(train_state.step) + + # 7. Save checkpoint if required + with self.tracker.timed("timing/checkpoint"): + self.checkpointer.save( + step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process + ) + + # 8. Perform validation if required + if train_state.step % self.args.validation_steps == 0: + self._validate(step=train_state.step, final_validation=False) + + # 9. Final checkpoint, validation & cleanup + self.checkpointer.save( + train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process + ) + parallel_backend.wait_for_everyone() + self._validate(step=train_state.step, final_validation=True) + + self._delete_components() + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + + # 10. Upload artifacts to hub + if parallel_backend.is_main_process and self.args.push_to_hub: + upload_folder( + repo_id=self.state.repo_id, + folder_path=self.args.output_dir, + ignore_patterns=[f"{self.checkpointer._prefix}_*"], + ) + + parallel_backend.destroy() + + def _validate(self, step: int, final_validation: bool = False) -> None: + if self.args.validation_dataset_file is None: + return + + logger.info("Starting validation") + + # 1. Load validation dataset + parallel_backend = self.state.parallel_backend + dataset = data.ValidationDataset(self.args.validation_dataset_file) + + # Hack to make accelerate work. TODO(aryan): refactor + if parallel_backend._dp_degree > 1: + dp_mesh = parallel_backend.get_mesh()["dp"] + dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size) + else: + dp_mesh = None + dp_local_rank, dp_world_size = parallel_backend.local_rank, 1 + + dataset = ValidationControlDataset(dataset, self.args.control_type, parallel_backend.device) + validation_dataloader = data.DPDataLoader( + dp_local_rank, + dataset, + batch_size=1, + num_workers=self.args.dataloader_num_workers, + collate_fn=lambda items: items, + ) + data_iterator = iter(validation_dataloader) + main_process_prompts_to_filenames = {} # Used to save model card + all_processes_artifacts = [] # Used to gather artifacts from all processes + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + + seed = self.args.seed if self.args.seed is not None else 0 + generator = torch.Generator(device=parallel_backend.device).manual_seed(seed) + pipeline = self._init_pipeline(final_validation=final_validation) + + # 2. Run validation + # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we + # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset + # size is not divisible by dp_shards. + self.transformer.eval() + while True: + validation_data = next(data_iterator, None) + if validation_data is None: + break + + validation_data = validation_data[0] + with self.attention_provider_ctx(training=False): + validation_artifacts = self.model_specification.validation( + pipeline=pipeline, generator=generator, **validation_data + ) + + if dp_local_rank != 0: + continue + + PROMPT = validation_data["prompt"] + IMAGE = validation_data.get("image", None) + VIDEO = validation_data.get("video", None) + CONTROL_IMAGE = validation_data.get("control_image", None) + CONTROL_VIDEO = validation_data.get("control_video", None) + EXPORT_FPS = validation_data.get("export_fps", 30) + + # 2.1. If there are any initial images or videos, they will be logged to keep track of them as + # conditioning for generation. + prompt_filename = utils.string_to_filename(PROMPT)[:25] + artifacts = { + "input_image": data.ImageArtifact(value=IMAGE), + "input_video": data.VideoArtifact(value=VIDEO), + "control_image": data.ImageArtifact(value=CONTROL_IMAGE), + "control_video": data.VideoArtifact(value=CONTROL_VIDEO), + } + + # 2.2. Track the artifacts generated from validation + for i, validation_artifact in enumerate(validation_artifacts): + if validation_artifact.value is None: + continue + artifacts.update({f"artifact_{i}": validation_artifact}) + + # 2.3. Save the artifacts to the output directory and create appropriate logging objects + # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited. + for index, (key, artifact) in enumerate(list(artifacts.items())): + assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact)) + if artifact.value is None: + continue + + time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension + filename = "validation-" if not final_validation else "final-" + filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}" + + if parallel_backend.is_main_process and ext in ["mp4", "jpg", "jpeg", "png"]: + main_process_prompts_to_filenames[PROMPT] = filename + + caption = PROMPT + if key == "control_image": + filename = f"control_image-{filename}" + caption = f"[control] {caption}" + elif key == "control_video": + filename = f"control_video-{filename}" + caption = f"[control] {caption}" + + output_filename = os.path.join(self.args.output_dir, filename) + + if isinstance(artifact, data.ImageArtifact): + artifact.value.save(output_filename) + all_processes_artifacts.append(wandb.Image(output_filename, caption=caption)) + elif isinstance(artifact, data.VideoArtifact): + export_to_video(artifact.value, output_filename, fps=EXPORT_FPS) + all_processes_artifacts.append(wandb.Video(output_filename, caption=caption)) + + # 3. Cleanup & log artifacts + parallel_backend.wait_for_everyone() + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + + # Remove all hooks that might have been added during pipeline initialization to the models + pipeline.remove_all_hooks() + del pipeline + module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"] + if self.args.enable_precomputation: + self._delete_components(module_names) + torch.cuda.reset_peak_memory_stats(parallel_backend.device) + + # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts. + all_artifacts = [None] * dp_world_size + if dp_world_size > 1: + torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts) + else: + all_artifacts = [all_processes_artifacts] + all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts] + + if parallel_backend.is_main_process: + tracker_key = "final" if final_validation else "validation" + artifact_log_dict = {} + + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + if len(image_artifacts) > 0: + artifact_log_dict["images"] = image_artifacts + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + if len(video_artifacts) > 0: + artifact_log_dict["videos"] = video_artifacts + parallel_backend.log({tracker_key: artifact_log_dict}, step=step) + + if self.args.push_to_hub and final_validation: + video_filenames = list(main_process_prompts_to_filenames.values()) + prompts = list(main_process_prompts_to_filenames.keys()) + utils.save_model_card( + args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts + ) + + parallel_backend.wait_for_everyone() + if not final_validation: + self._move_components_to_device() + self.transformer.train() + + def _evaluate(self) -> None: + raise NotImplementedError("Evaluation has not been implemented yet.") + + def _init_directories_and_repositories(self) -> None: + if self.state.parallel_backend.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + self.state.output_dir = Path(self.args.output_dir) + + if self.args.push_to_hub: + repo_id = self.args.hub_model_id or Path(self.args.output_dir).name + self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id + + def _move_components_to_device( + self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None + ) -> None: + if device is None: + device = self.state.parallel_backend.device + if components is None: + components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae] + components = utils.get_non_null_items(components) + components = list(filter(lambda x: hasattr(x, "to"), components)) + for component in components: + component.to(device) + + def _set_components(self, components: Dict[str, Any]) -> None: + for component_name in self._all_component_names: + existing_component = getattr(self, component_name, None) + new_component = components.get(component_name, existing_component) + setattr(self, component_name, new_component) + + def _delete_components(self, component_names: Optional[List[str]] = None) -> None: + if component_names is None: + component_names = self._all_component_names + for component_name in component_names: + setattr(self, component_name, None) + utils.free_memory() + utils.synchronize_device() + + def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline: + parallel_backend = self.state.parallel_backend + module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"] + + if not final_validation: + module_names.remove("transformer") + pipeline = self.model_specification.load_pipeline( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + text_encoder_3=self.text_encoder_3, + # TODO(aryan): handle unwrapping for compiled modules + # transformer=utils.unwrap_model(accelerator, self.transformer), + transformer=self.transformer, + vae=self.vae, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=True, + ) + else: + self._delete_components() + + # TODO(aryan): allow multiple control conditions instead of just one if there's a use case for it + new_in_features = self.model_specification._original_control_layer_in_features * 2 + if self.args.frame_conditioning_concatenate_mask: + new_in_features += 1 + transformer = self.model_specification.load_diffusion_models(new_in_features)["transformer"] + + pipeline = self.model_specification.load_pipeline( + transformer=transformer, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=False, + device=parallel_backend.device, + ) + + # Load the LoRA weights if performing LoRA finetuning + if self.args.training_type == TrainingType.CONTROL_LORA: + load_lora_weights( + pipeline, os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}") + ) + norm_state_dict_path = os.path.join( + self.args.output_dir, + "lora_weights", + f"{self.state.train_state.step:06d}", + "norm_state_dict.safetensors", + ) + if self.args.train_qk_norm and norm_state_dict_path.exists(): + norm_state_dict = safetensors.torch.load_file(norm_state_dict_path, parallel_backend.device) + self.transformer.load_state_dict(norm_state_dict) + + components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names} + self._set_components(components) + if not self.args.enable_model_cpu_offload: + self._move_components_to_device(list(components.values())) + self._maybe_torch_compile() + return pipeline + + def _prepare_data( + self, + preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor], + data_iterator, + ): + if not self.args.enable_precomputation: + if not self._are_condition_models_loaded: + logger.info( + "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs." + ) + condition_components = self.model_specification.load_condition_models() + latent_components = self.model_specification.load_latent_models() + all_components = {**condition_components, **latent_components} + self._set_components(all_components) + self._move_components_to_device(list(all_components.values())) + utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) + self._maybe_torch_compile() + else: + condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))} + latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))} + + condition_iterator = preprocessor.consume( + "condition", + components=condition_components, + data_iterator=data_iterator, + generator=self.state.generator, + cache_samples=True, + ) + latent_iterator = preprocessor.consume( + "latent", + components=latent_components, + data_iterator=data_iterator, + generator=self.state.generator, + use_cached_samples=True, + drop_samples=True, + ) + + self._are_condition_models_loaded = True + else: + logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.") + + parallel_backend = self.state.parallel_backend + if parallel_backend.world_size == 1: + self._move_components_to_device([self.transformer], "cpu") + utils.free_memory() + utils.synchronize_device() + torch.cuda.reset_peak_memory_stats(parallel_backend.device) + + consume_fn = preprocessor.consume_once if self.args.precomputation_once else preprocessor.consume + + # Prepare condition iterators + condition_components, component_names, component_modules = {}, [], [] + if not self.args.precomputation_reuse: + condition_components = self.model_specification.load_condition_models() + component_names = list(condition_components.keys()) + component_modules = list(condition_components.values()) + self._set_components(condition_components) + self._move_components_to_device(component_modules) + self._maybe_torch_compile() + condition_iterator = consume_fn( + "condition", + components=condition_components, + data_iterator=data_iterator, + generator=self.state.generator, + cache_samples=True, + ) + self._delete_components(component_names) + del condition_components, component_names, component_modules + + # Prepare latent iterators + latent_components, component_names, component_modules = {}, [], [] + if not self.args.precomputation_reuse: + latent_components = self.model_specification.load_latent_models() + utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) + component_names = list(latent_components.keys()) + component_modules = list(latent_components.values()) + self._set_components(latent_components) + self._move_components_to_device(component_modules) + self._maybe_torch_compile() + latent_iterator = consume_fn( + "latent", + components=latent_components, + data_iterator=data_iterator, + generator=self.state.generator, + use_cached_samples=True, + drop_samples=True, + ) + self._delete_components(component_names) + del latent_components, component_names, component_modules + + if parallel_backend.world_size == 1: + self._move_components_to_device([self.transformer]) + + return condition_iterator, latent_iterator + + def _maybe_torch_compile(self): + for model_name, compile_scope in zip(self.args.compile_modules, self.args.compile_scopes): + model = getattr(self, model_name, None) + if model is not None: + logger.info(f"Applying torch.compile to '{model_name}' with scope '{compile_scope}'.") + compiled_model = utils.apply_compile(model, compile_scope) + setattr(self, model_name, compiled_model) + + def _get_training_info(self) -> Dict[str, Any]: + info = self.args.to_dict() + + # Removing flow matching arguments when not using flow-matching objective + diffusion_args = info.get("diffusion_arguments", {}) + scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else "" + if scheduler_name != "FlowMatchEulerDiscreteScheduler": + filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} + else: + filtered_diffusion_args = diffusion_args + + info.update({"diffusion_arguments": filtered_diffusion_args}) + return info + + def _get_lora_target_modules(self): + target_modules = self.args.target_modules + if isinstance(target_modules, list): + target_modules = list(target_modules) # Make a copy to avoid modifying args + target_modules.append(f"^{self.model_specification.control_injection_layer_name}$") + if isinstance(target_modules, str): + target_modules = f"(^{self.model_specification.control_injection_layer_name}$)|({target_modules})" + return target_modules + + # fmt: off + _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"] + _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"] + _latent_component_names = ["vae"] + _diffusion_component_names = ["transformer", "unet", "scheduler"] + # fmt: on diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726049bd5a9059c0df70efeacc76ac9f3423315a --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py @@ -0,0 +1,2 @@ +from .config import SFTFullRankConfig, SFTLowRankConfig +from .trainer import SFTTrainer diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c70c6503373ed7e6aaf8c2b60fc4ba0a0f0f81a6 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py @@ -0,0 +1,65 @@ +import argparse +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from finetrainers.utils import ArgsConfigMixin + + +if TYPE_CHECKING: + from finetrainers.args import BaseArgs + + +class SFTLowRankConfig(ArgsConfigMixin): + r""" + Configuration class for SFT low rank training. + + Args: + rank (int): + Rank of the low rank approximation matrix. + lora_alpha (int): + The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. + target_modules (`str` or `List[str]`): + Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings. + """ + + rank: int = 64 + lora_alpha: int = 64 + target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" + + def add_args(self, parser: argparse.ArgumentParser): + parser.add_argument("--rank", type=int, default=64) + parser.add_argument("--lora_alpha", type=int, default=64) + parser.add_argument( + "--target_modules", + type=str, + nargs="+", + default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"], + ) + + def validate_args(self, args: "BaseArgs"): + assert self.rank > 0, "Rank must be a positive integer." + assert self.lora_alpha > 0, "lora_alpha must be a positive integer." + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.rank = argparse_args.rank + mapped_args.lora_alpha = argparse_args.lora_alpha + mapped_args.target_modules = ( + argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules + ) + + def to_dict(self) -> Dict[str, Any]: + return {"rank": self.rank, "lora_alpha": self.lora_alpha, "target_modules": self.target_modules} + + +class SFTFullRankConfig(ArgsConfigMixin): + r""" + Configuration class for SFT full rank training. + """ + + def add_args(self, parser: argparse.ArgumentParser): + pass + + def validate_args(self, args: "BaseArgs"): + pass + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + pass diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..789545963e533b8ee159a528bece5185f13634cc --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py @@ -0,0 +1,946 @@ +import functools +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Union + +import datasets.distributed +import torch +import wandb +from diffusers import DiffusionPipeline +from diffusers.hooks import apply_layerwise_casting +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict +from tqdm import tqdm + +from finetrainers import data, logging, models, optimizer, parallel, utils +from finetrainers.args import BaseArgsType +from finetrainers.config import TrainingType +from finetrainers.state import TrainState + +from ..base import Trainer +from .config import SFTFullRankConfig, SFTLowRankConfig + + +ArgsType = Union[BaseArgsType, SFTFullRankConfig, SFTLowRankConfig] + +logger = logging.get_logger() + + +class SFTTrainer(Trainer): + def __init__(self, args: ArgsType, model_specification: models.ModelSpecification) -> None: + super().__init__(args) + + # Tokenizers + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + + # Text encoders + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + + # Image encoders + self.image_encoder = None + self.image_processor = None + + # Denoisers + self.transformer = None + self.unet = None + + # Autoencoders + self.vae = None + + # Scheduler + self.scheduler = None + + # Optimizer & LR scheduler + self.optimizer = None + self.lr_scheduler = None + + # Checkpoint manager + self.checkpointer = None + + self.model_specification = model_specification + self._are_condition_models_loaded = False + + def run(self) -> None: + try: + self._prepare_models() + self._prepare_trainable_parameters() + self._prepare_for_training() + self._prepare_dataset() + self._prepare_checkpointing() + self._train() + # trainer._evaluate() + except Exception as e: + logger.error(f"Error during training: {e}") + self.state.parallel_backend.destroy() + raise e + + def _prepare_models(self) -> None: + logger.info("Initializing models") + + diffusion_components = self.model_specification.load_diffusion_models() + self._set_components(diffusion_components) + + if self.state.parallel_backend.pipeline_parallel_enabled: + raise NotImplementedError( + "Pipeline parallelism is not supported yet. This will be supported in the future." + ) + + def _prepare_trainable_parameters(self) -> None: + logger.info("Initializing trainable parameters") + + parallel_backend = self.state.parallel_backend + + if self.args.training_type == TrainingType.FULL_FINETUNE: + logger.info("Finetuning transformer with no additional parameters") + utils.set_requires_grad([self.transformer], True) + else: + logger.info("Finetuning transformer with PEFT parameters") + utils.set_requires_grad([self.transformer], False) + + # Layerwise upcasting must be applied before adding the LoRA adapter. + # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on + # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. + if self.args.training_type == TrainingType.LORA and "transformer" in self.args.layerwise_upcasting_modules: + apply_layerwise_casting( + self.transformer, + storage_dtype=self.args.layerwise_upcasting_storage_dtype, + compute_dtype=self.args.transformer_dtype, + skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, + non_blocking=True, + ) + + transformer_lora_config = None + if self.args.training_type == TrainingType.LORA: + transformer_lora_config = LoraConfig( + r=self.args.rank, + lora_alpha=self.args.lora_alpha, + init_lora_weights=True, + target_modules=self.args.target_modules, + ) + self.transformer.add_adapter(transformer_lora_config) + + # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all + # parameters to be of the same dtype. + if parallel_backend.data_sharding_enabled: + self.transformer.to(dtype=self.args.transformer_dtype) + else: + if self.args.training_type == TrainingType.LORA: + cast_training_params([self.transformer], dtype=torch.float32) + + def _prepare_for_training(self) -> None: + # 1. Apply parallelism + parallel_backend = self.state.parallel_backend + model_specification = self.model_specification + + if parallel_backend.context_parallel_enabled: + parallel_backend.apply_context_parallel(self.transformer, parallel_backend.get_mesh()["cp"]) + + if parallel_backend.tensor_parallel_enabled: + # TODO(aryan): handle fp8 from TorchAO here + model_specification.apply_tensor_parallel( + backend=parallel.ParallelBackendEnum.PTD, + device_mesh=parallel_backend.get_mesh()["tp"], + transformer=self.transformer, + ) + + # Enable gradient checkpointing + if self.args.gradient_checkpointing: + # TODO(aryan): support other checkpointing types + utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full") + + # Apply torch.compile + self._maybe_torch_compile() + + # Enable DDP, FSDP or HSDP + if parallel_backend.data_sharding_enabled: + # TODO(aryan): remove this when supported + if self.args.parallel_backend == "accelerate": + raise NotImplementedError("Data sharding is not supported with Accelerate yet.") + + dp_method = "HSDP" if parallel_backend.data_replication_enabled else "FSDP" + logger.info(f"Applying {dp_method} on the model") + + if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled: + dp_mesh_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_names = ("dp_shard_cp",) + + parallel_backend.apply_fsdp2( + model=self.transformer, + param_dtype=self.args.transformer_dtype, + reduce_dtype=torch.float32, + output_dtype=None, + pp_enabled=parallel_backend.pipeline_parallel_enabled, + cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later + device_mesh=parallel_backend.get_mesh()[dp_mesh_names], + ) + elif parallel_backend.data_replication_enabled: + if parallel_backend.get_mesh().ndim > 1: + raise ValueError("DDP not supported for > 1D parallelism") + logger.info("Applying DDP to the model") + parallel_backend.apply_ddp(self.transformer, parallel_backend.get_mesh()) + else: + parallel_backend.prepare_model(self.transformer) + + self._move_components_to_device() + + # 2. Prepare optimizer and lr scheduler + # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module. + # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer + # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99) + # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require + # gradients. TODO(aryan): look into it in the future. + model_parts = [self.transformer] + self.state.num_trainable_parameters = sum( + p.numel() for m in model_parts for p in m.parameters() if p.requires_grad + ) + + # Setup distributed optimizer and lr scheduler + logger.info("Initializing optimizer and lr scheduler") + self.state.train_state = TrainState() + self.optimizer = optimizer.get_optimizer( + parallel_backend=self.args.parallel_backend, + name=self.args.optimizer, + model_parts=model_parts, + learning_rate=self.args.lr, + beta1=self.args.beta1, + beta2=self.args.beta2, + beta3=self.args.beta3, + epsilon=self.args.epsilon, + weight_decay=self.args.weight_decay, + fused=False, + ) + self.lr_scheduler = optimizer.get_lr_scheduler( + parallel_backend=self.args.parallel_backend, + name=self.args.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=self.args.lr_warmup_steps, + num_training_steps=self.args.train_steps, + # TODO(aryan): handle last_epoch + ) + self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler) + + # 3. Initialize trackers, directories and repositories + self._init_logging() + self._init_trackers() + self._init_directories_and_repositories() + + def _prepare_dataset(self) -> None: + logger.info("Initializing dataset and dataloader") + + with open(self.args.dataset_config, "r") as file: + dataset_configs = json.load(file)["datasets"] + logger.info(f"Training configured to use {len(dataset_configs)} datasets") + + datasets = [] + for config in dataset_configs: + data_root = config.pop("data_root", None) + dataset_file = config.pop("dataset_file", None) + dataset_type = config.pop("dataset_type") + caption_options = config.pop("caption_options", {}) + + if data_root is not None and dataset_file is not None: + raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.") + + dataset_name_or_root = data_root or dataset_file + dataset = data.initialize_dataset( + dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options + ) + + if not dataset._precomputable_once and self.args.precomputation_once: + raise ValueError( + f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once." + ) + + logger.info(f"Initialized dataset: {dataset_name_or_root}") + dataset = self.state.parallel_backend.prepare_dataset(dataset) + dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config) + datasets.append(dataset) + + dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True) + dataloader = self.state.parallel_backend.prepare_dataloader( + dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory + ) + + self.dataset = dataset + self.dataloader = dataloader + + def _prepare_checkpointing(self) -> None: + parallel_backend = self.state.parallel_backend + + def save_model_hook(state_dict: Dict[str, Any]) -> None: + state_dict = utils.get_unwrapped_model_state_dict(state_dict) + if parallel_backend.is_main_process: + if self.args.training_type == TrainingType.LORA: + state_dict = get_peft_model_state_dict(self.transformer, state_dict) + # fmt: off + metadata = { + "r": self.args.rank, + "lora_alpha": self.args.lora_alpha, + "init_lora_weights": True, + "target_modules": self.args.target_modules, + } + metadata = {"lora_config": json.dumps(metadata, indent=4)} + # fmt: on + self.model_specification._save_lora_weights( + os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}"), + state_dict, + self.scheduler, + metadata, + ) + elif self.args.training_type == TrainingType.FULL_FINETUNE: + self.model_specification._save_model( + os.path.join(self.args.output_dir, "model_weights", f"{self.state.train_state.step:06d}"), + self.transformer, + state_dict, + self.scheduler, + ) + parallel_backend.wait_for_everyone() + + enable_state_checkpointing = self.args.checkpointing_steps > 0 + self.checkpointer = parallel_backend.get_checkpointer( + dataloader=self.dataloader, + model_parts=[self.transformer], + optimizers=self.optimizer, + schedulers=self.lr_scheduler, + states={"train_state": self.state.train_state}, + checkpointing_steps=self.args.checkpointing_steps, + checkpointing_limit=self.args.checkpointing_limit, + output_dir=self.args.output_dir, + enable=enable_state_checkpointing, + _callback_fn=save_model_hook, + ) + + resume_from_checkpoint = self.args.resume_from_checkpoint + if resume_from_checkpoint == "latest": + resume_from_checkpoint = -1 + if resume_from_checkpoint is not None: + self.checkpointer.load(resume_from_checkpoint) + + def _train(self) -> None: + logger.info("Starting training") + + parallel_backend = self.state.parallel_backend + train_state = self.state.train_state + device = parallel_backend.device + dtype = self.args.transformer_dtype + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") + + global_batch_size = self.args.batch_size * parallel_backend._dp_degree + info = { + "trainable parameters": self.state.num_trainable_parameters, + "train steps": self.args.train_steps, + "per-replica batch size": self.args.batch_size, + "global batch size": global_batch_size, + "gradient accumulation steps": self.args.gradient_accumulation_steps, + } + logger.info(f"Training configuration: {json.dumps(info, indent=4)}") + + progress_bar = tqdm( + range(0, self.args.train_steps), + initial=train_state.step, + desc="Training steps", + disable=not parallel_backend.is_local_main_process, + ) + + generator = torch.Generator(device=device) + if self.args.seed is not None: + generator = generator.manual_seed(self.args.seed) + self.state.generator = generator + + scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler) + scheduler_sigmas = ( + scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None + ) + scheduler_alphas = utils.get_scheduler_alphas(self.scheduler) + scheduler_alphas = ( + scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None + ) + # timesteps_buffer = [] + + self.transformer.train() + data_iterator = iter(self.dataloader) + + compute_posterior = False if self.args.enable_precomputation else (not self.args.precomputation_once) + preprocessor = data.initialize_preprocessor( + rank=parallel_backend.rank, + world_size=parallel_backend.world_size, + num_items=self.args.precomputation_items if self.args.enable_precomputation else 1, + processor_fn={ + "condition": self.model_specification.prepare_conditions, + "latent": functools.partial( + self.model_specification.prepare_latents, compute_posterior=compute_posterior + ), + }, + save_dir=self.args.precomputation_dir, + enable_precomputation=self.args.enable_precomputation, + enable_reuse=self.args.precomputation_reuse, + ) + condition_iterator: Iterable[Dict[str, Any]] = None + latent_iterator: Iterable[Dict[str, Any]] = None + sampler = data.ResolutionSampler( + batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys + ) + requires_gradient_step = True + accumulated_loss = 0.0 + + while ( + train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples + ): + # 1. Load & preprocess data if required + if preprocessor.requires_data: + condition_iterator, latent_iterator = self._prepare_data(preprocessor, data_iterator) + + # 2. Prepare batch + with self.tracker.timed("timing/batch_preparation"): + try: + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + sampler.consume(condition_item, latent_item) + except StopIteration: + if requires_gradient_step: + self.optimizer.step() + self.lr_scheduler.step() + requires_gradient_step = False + logger.info("Data exhausted. Exiting training loop.") + break + + if sampler.is_ready: + condition_batch, latent_batch = sampler.get_batch() + condition_model_conditions = self.model_specification.collate_conditions(condition_batch) + latent_model_conditions = self.model_specification.collate_latents(latent_batch) + else: + continue + + train_state.step += 1 + train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree + + logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})") + + latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype) + condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype) + latent_model_conditions = utils.make_contiguous(latent_model_conditions) + condition_model_conditions = utils.make_contiguous(condition_model_conditions) + + # 3. Forward pass + sigmas = utils.prepare_sigmas( + scheduler=self.scheduler, + sigmas=scheduler_sigmas, + batch_size=self.args.batch_size, + num_train_timesteps=self.scheduler.config.num_train_timesteps, + flow_weighting_scheme=self.args.flow_weighting_scheme, + flow_logit_mean=self.args.flow_logit_mean, + flow_logit_std=self.args.flow_logit_std, + flow_mode_scale=self.args.flow_mode_scale, + device=device, + generator=self.state.generator, + ) + sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim) + + # NOTE: for planned refactor, make sure that forward and backward pass run under the context. + # If only forward runs under context, backward will most likely fail when using activation checkpointing + with self.attention_provider_ctx(training=True): + with self.tracker.timed("timing/forward"): + pred, target, sigmas = self.model_specification.forward( + transformer=self.transformer, + scheduler=self.scheduler, + condition_model_conditions=condition_model_conditions, + latent_model_conditions=latent_model_conditions, + sigmas=sigmas, + compute_posterior=compute_posterior, + ) + + timesteps = (sigmas * 1000.0).long() + weights = utils.prepare_loss_weights( + scheduler=self.scheduler, + alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, + sigmas=sigmas, + flow_weighting_scheme=self.args.flow_weighting_scheme, + ) + weights = utils.expand_tensor_dims(weights, pred.ndim) + + # 4. Compute loss & backward pass + with self.tracker.timed("timing/backward"): + loss = weights.float() * (pred.float() - target.float()).pow(2) + # Average loss across all but batch dimension (for per-batch debugging in case needed) + loss = loss.mean(list(range(1, loss.ndim))) + # Average loss across batch dimension + loss = loss.mean() + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + loss.backward() + + accumulated_loss += loss.detach().item() + requires_gradient_step = True + + # 5. Clip gradients + model_parts = [self.transformer] + grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + self.args.max_grad_norm, + foreach=True, + pp_mesh=parallel_backend.get_mesh()["pp"] if parallel_backend.pipeline_parallel_enabled else None, + ) + + # 6. Step optimizer & log metrics + logs = {} + + if train_state.step % self.args.gradient_accumulation_steps == 0: + # TODO(aryan): revisit no_sync() for FSDP + with self.tracker.timed("timing/optimizer_step"): + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + if grad_norm is not None: + grad_norm = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item() + if ( + parallel_backend.data_replication_enabled + or parallel_backend.data_sharding_enabled + or parallel_backend.context_parallel_enabled + ): + dp_cp_mesh = parallel_backend.get_mesh()["dp_cp"] + if grad_norm is not None: + grad_norm = parallel.dist_mean(torch.tensor([grad_norm], device=device), dp_cp_mesh) + global_avg_loss, global_max_loss = ( + parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + ) + else: + global_avg_loss = global_max_loss = accumulated_loss + + logs["train/global_avg_loss"] = global_avg_loss + logs["train/global_max_loss"] = global_max_loss + if grad_norm is not None: + logs["train/grad_norm"] = grad_norm + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + accumulated_loss = 0.0 + requires_gradient_step = False + + progress_bar.update(1) + progress_bar.set_postfix(logs) + + # timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()]) + + if train_state.step % self.args.logging_steps == 0: + # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts + # TODO(aryan): causes NCCL hang for some reason. look into later + # logs.update(self.lr_scheduler.get_last_lr()) + + # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"]) + # logs["timesteps"] = wandb.plot.scatter( + # timesteps_table, "step", "timesteps", title="Timesteps distribution" + # ) + # timesteps_buffer = [] + + logs["train/observed_data_samples"] = train_state.observed_data_samples + + parallel_backend.log(logs, step=train_state.step) + train_state.log_steps.append(train_state.step) + + # 7. Save checkpoint if required + with self.tracker.timed("timing/checkpoint"): + self.checkpointer.save( + step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process + ) + + # 8. Perform validation if required + if train_state.step % self.args.validation_steps == 0: + self._validate(step=train_state.step, final_validation=False) + + # 9. Final checkpoint, validation & cleanup + self.checkpointer.save( + train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process + ) + parallel_backend.wait_for_everyone() + self._validate(step=train_state.step, final_validation=True) + + self._delete_components() + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + + # 10. Upload artifacts to hub + if parallel_backend.is_main_process and self.args.push_to_hub: + upload_folder( + repo_id=self.state.repo_id, + folder_path=self.args.output_dir, + ignore_patterns=[f"{self.checkpointer._prefix}_*"], + ) + + parallel_backend.destroy() + + def _validate(self, step: int, final_validation: bool = False) -> None: + if self.args.validation_dataset_file is None: + return + + logger.info("Starting validation") + + # 1. Load validation dataset + parallel_backend = self.state.parallel_backend + dataset = data.ValidationDataset(self.args.validation_dataset_file) + + # Hack to make accelerate work. TODO(aryan): refactor + if parallel_backend._dp_degree > 1: + dp_mesh = parallel_backend.get_mesh()["dp"] + dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size) + else: + dp_mesh = None + dp_local_rank, dp_world_size = parallel_backend.local_rank, 1 + + validation_dataloader = data.DPDataLoader( + dp_local_rank, + dataset, + batch_size=1, + num_workers=self.args.dataloader_num_workers, + collate_fn=lambda items: items, + ) + data_iterator = iter(validation_dataloader) + main_process_prompts_to_filenames = {} # Used to save model card + all_processes_artifacts = [] # Used to gather artifacts from all processes + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + + seed = self.args.seed if self.args.seed is not None else 0 + generator = torch.Generator(device=parallel_backend.device).manual_seed(seed) + pipeline = self._init_pipeline(final_validation=final_validation) + + # 2. Run validation + # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we + # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset + # size is not divisible by dp_shards. + self.transformer.eval() + while True: + validation_data = next(data_iterator, None) + if validation_data is None: + break + + validation_data = validation_data[0] + with self.attention_provider_ctx(training=False): + validation_artifacts = self.model_specification.validation( + pipeline=pipeline, generator=generator, **validation_data + ) + + if dp_local_rank != 0: + continue + + PROMPT = validation_data["prompt"] + IMAGE = validation_data.get("image", None) + VIDEO = validation_data.get("video", None) + EXPORT_FPS = validation_data.get("export_fps", 30) + + # 2.1. If there are any initial images or videos, they will be logged to keep track of them as + # conditioning for generation. + prompt_filename = utils.string_to_filename(PROMPT)[:25] + artifacts = { + "input_image": data.ImageArtifact(value=IMAGE), + "input_video": data.VideoArtifact(value=VIDEO), + } + + # 2.2. Track the artifacts generated from validation + for i, validation_artifact in enumerate(validation_artifacts): + if validation_artifact.value is None: + continue + artifacts.update({f"artifact_{i}": validation_artifact}) + + # 2.3. Save the artifacts to the output directory and create appropriate logging objects + # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited. + for index, (key, artifact) in enumerate(list(artifacts.items())): + assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact)) + if artifact.value is None: + continue + + time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension + filename = "validation-" if not final_validation else "final-" + filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}" + output_filename = os.path.join(self.args.output_dir, filename) + + if parallel_backend.is_main_process and ext in ["mp4", "jpg", "jpeg", "png"]: + main_process_prompts_to_filenames[PROMPT] = filename + + if isinstance(artifact, data.ImageArtifact): + artifact.value.save(output_filename) + all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT)) + elif isinstance(artifact, data.VideoArtifact): + export_to_video(artifact.value, output_filename, fps=EXPORT_FPS) + all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT)) + + # 3. Cleanup & log artifacts + parallel_backend.wait_for_everyone() + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + + # Remove all hooks that might have been added during pipeline initialization to the models + pipeline.remove_all_hooks() + del pipeline + module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "image_processor", "vae"] + if self.args.enable_precomputation: + self._delete_components(module_names) + torch.cuda.reset_peak_memory_stats(parallel_backend.device) + + # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts. + all_artifacts = [None] * dp_world_size + if dp_world_size > 1: + torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts) + else: + all_artifacts = [all_processes_artifacts] + all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts] + + if parallel_backend.is_main_process: + tracker_key = "final" if final_validation else "validation" + artifact_log_dict = {} + + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + if len(image_artifacts) > 0: + artifact_log_dict["images"] = image_artifacts + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + if len(video_artifacts) > 0: + artifact_log_dict["videos"] = video_artifacts + parallel_backend.log({tracker_key: artifact_log_dict}, step=step) + + if self.args.push_to_hub and final_validation: + video_filenames = list(main_process_prompts_to_filenames.values()) + prompts = list(main_process_prompts_to_filenames.keys()) + utils.save_model_card( + args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts + ) + + parallel_backend.wait_for_everyone() + if not final_validation: + self._move_components_to_device() + self.transformer.train() + + def _evaluate(self) -> None: + raise NotImplementedError("Evaluation has not been implemented yet.") + + def _init_directories_and_repositories(self) -> None: + if self.state.parallel_backend.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + self.state.output_dir = Path(self.args.output_dir) + + if self.args.push_to_hub: + repo_id = self.args.hub_model_id or Path(self.args.output_dir).name + self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id + + def _move_components_to_device( + self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None + ) -> None: + if device is None: + device = self.state.parallel_backend.device + if components is None: + components = [ + self.text_encoder, + self.text_encoder_2, + self.text_encoder_3, + self.image_encoder, + self.transformer, + self.vae, + ] + components = utils.get_non_null_items(components) + components = list(filter(lambda x: hasattr(x, "to"), components)) + for component in components: + component.to(device) + + def _set_components(self, components: Dict[str, Any]) -> None: + for component_name in self._all_component_names: + existing_component = getattr(self, component_name, None) + new_component = components.get(component_name, existing_component) + setattr(self, component_name, new_component) + + def _delete_components(self, component_names: Optional[List[str]] = None) -> None: + if component_names is None: + component_names = self._all_component_names + for component_name in component_names: + setattr(self, component_name, None) + utils.free_memory() + utils.synchronize_device() + + def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline: + module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "transformer", "vae"] + + if not final_validation: + module_names.remove("transformer") + pipeline = self.model_specification.load_pipeline( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + text_encoder_3=self.text_encoder_3, + image_encoder=self.image_encoder, + image_processor=self.image_processor, + # TODO(aryan): handle unwrapping for compiled modules + # transformer=utils.unwrap_model(accelerator, self.transformer), + transformer=self.transformer, + vae=self.vae, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=True, + ) + else: + self._delete_components() + + # Load the transformer weights from the final checkpoint if performing full-finetune + transformer = None + if self.args.training_type == TrainingType.FULL_FINETUNE: + transformer = self.model_specification.load_diffusion_models()["transformer"] + + pipeline = self.model_specification.load_pipeline( + transformer=transformer, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=False, + ) + + # Load the LoRA weights if performing LoRA finetuning + if self.args.training_type == TrainingType.LORA: + pipeline.load_lora_weights( + os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}") + ) + + components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names} + self._set_components(components) + if not self.args.enable_model_cpu_offload: + self._move_components_to_device(list(components.values())) + self._maybe_torch_compile() + return pipeline + + def _prepare_data( + self, + preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor], + data_iterator, + ): + if not self.args.enable_precomputation: + if not self._are_condition_models_loaded: + logger.info( + "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs." + ) + condition_components = self.model_specification.load_condition_models() + latent_components = self.model_specification.load_latent_models() + all_components = {**condition_components, **latent_components} + self._set_components(all_components) + self._move_components_to_device(list(all_components.values())) + utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) + self._maybe_torch_compile() + else: + condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))} + latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))} + + condition_iterator = preprocessor.consume( + "condition", + components=condition_components, + data_iterator=data_iterator, + generator=self.state.generator, + cache_samples=True, + ) + latent_iterator = preprocessor.consume( + "latent", + components=latent_components, + data_iterator=data_iterator, + generator=self.state.generator, + use_cached_samples=True, + drop_samples=True, + ) + + self._are_condition_models_loaded = True + else: + logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.") + + parallel_backend = self.state.parallel_backend + if parallel_backend.world_size == 1: + self._move_components_to_device([self.transformer], "cpu") + utils.free_memory() + utils.synchronize_device() + torch.cuda.reset_peak_memory_stats(parallel_backend.device) + + consume_fn = preprocessor.consume_once if self.args.precomputation_once else preprocessor.consume + + # Prepare condition iterators + condition_components, component_names, component_modules = {}, [], [] + if not self.args.precomputation_reuse: + condition_components = self.model_specification.load_condition_models() + component_names = list(condition_components.keys()) + component_modules = list(condition_components.values()) + self._set_components(condition_components) + self._move_components_to_device(component_modules) + self._maybe_torch_compile() + condition_iterator = consume_fn( + "condition", + components=condition_components, + data_iterator=data_iterator, + generator=self.state.generator, + cache_samples=True, + ) + self._delete_components(component_names) + del condition_components, component_names, component_modules + + # Prepare latent iterators + latent_components, component_names, component_modules = {}, [], [] + if not self.args.precomputation_reuse: + latent_components = self.model_specification.load_latent_models() + utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling) + component_names = list(latent_components.keys()) + component_modules = list(latent_components.values()) + self._set_components(latent_components) + self._move_components_to_device(component_modules) + self._maybe_torch_compile() + latent_iterator = consume_fn( + "latent", + components=latent_components, + data_iterator=data_iterator, + generator=self.state.generator, + use_cached_samples=True, + drop_samples=True, + ) + self._delete_components(component_names) + del latent_components, component_names, component_modules + + if parallel_backend.world_size == 1: + self._move_components_to_device([self.transformer]) + + return condition_iterator, latent_iterator + + def _maybe_torch_compile(self): + for model_name, compile_scope in zip(self.args.compile_modules, self.args.compile_scopes): + model = getattr(self, model_name, None) + if model is not None: + logger.info(f"Applying torch.compile to '{model_name}' with scope '{compile_scope}'.") + compiled_model = utils.apply_compile(model, compile_scope) + setattr(self, model_name, compiled_model) + + def _get_training_info(self) -> Dict[str, Any]: + info = self.args.to_dict() + + # Removing flow matching arguments when not using flow-matching objective + diffusion_args = info.get("diffusion_arguments", {}) + scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else "" + if scheduler_name != "FlowMatchEulerDiscreteScheduler": + filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} + else: + filtered_diffusion_args = diffusion_args + + info.update({"diffusion_arguments": filtered_diffusion_args}) + return info + + # fmt: off + _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "image_processor", "transformer", "unet", "vae", "scheduler"] + _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"] + _latent_component_names = ["image_encoder", "image_processor", "vae"] + _diffusion_component_names = ["transformer", "unet", "scheduler"] + # fmt: on diff --git a/docs/finetrainers-src-codebase/finetrainers/typing.py b/docs/finetrainers-src-codebase/finetrainers/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b3b339f252d8f47ef0ff67aa6c6733a2ccd7cf --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/typing.py @@ -0,0 +1,11 @@ +from typing import Union + +from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler +from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast + +from .data import ImageArtifact, VideoArtifact + + +ArtifactType = Union[ImageArtifact, VideoArtifact] +SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler] +TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast] diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py b/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56fd3b2819959cc968b338bc6e7bd7051b05c77b --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py @@ -0,0 +1,51 @@ +import inspect +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from .activation_checkpoint import apply_activation_checkpointing +from .args_config import ArgsConfigMixin +from .data import determine_batch_size, should_perform_precomputation +from .diffusion import ( + _enable_vae_memory_optimizations, + default_flow_shift, + get_scheduler_alphas, + get_scheduler_sigmas, + prepare_loss_weights, + prepare_sigmas, + prepare_target, + resolution_dependent_timestep_flow_shift, +) +from .file import delete_files, find_files, string_to_filename +from .hub import save_model_card +from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous +from .model import resolve_component_cls +from .serialization import safetensors_torch_save_function +from .timing import Timer, TimerDevice +from .torch import ( + align_device_and_dtype, + apply_compile, + clip_grad_norm_, + enable_determinism, + expand_tensor_dims, + get_device_info, + get_submodule_by_name, + get_unwrapped_model_state_dict, + is_compiled_module, + set_requires_grad, + synchronize_device, + unwrap_module, +) + + +def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]: + if method_name is not None: + obj = getattr(obj, method_name) + return {name for name, _ in inspect.signature(obj).parameters.items()} + + +def get_non_null_items( + x: Union[List[Any], Tuple[Any], Dict[str, Any]], +) -> Union[List[Any], Tuple[Any], Dict[str, Any]]: + if isinstance(x, dict): + return {k: v for k, v in x.items() if v is not None} + if isinstance(x, (list, tuple)): + return type(x)(v for v in x if v is not None) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/_common.py b/docs/finetrainers-src-codebase/finetrainers/utils/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a54abe26707181d6f4795e99f24fae34b911b2b9 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/_common.py @@ -0,0 +1,7 @@ +DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [ + "transformer_blocks", + "single_transformer_blocks", + "temporal_transformer_blocks", + "blocks", + "layers", +] diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py b/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4193a6cc027a771fe1fc2c3cb34595fbc336b2 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py @@ -0,0 +1,71 @@ +import collections +from enum import Enum + +import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper + +from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES + + +class CheckpointType(str, Enum): + FULL = "full" + OPS = "ops" + BLOCK_SKIP = "block_skip" + + +_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, +} + + +def apply_activation_checkpointing( + module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1 +) -> torch.nn.Module: + if checkpointing_type == CheckpointType.FULL: + module = _apply_activation_checkpointing_blocks(module) + elif checkpointing_type == CheckpointType.OPS: + module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS) + elif checkpointing_type == CheckpointType.BLOCK_SKIP: + module = _apply_activation_checkpointing_blocks(module, n_layer) + else: + raise ValueError( + f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {CheckpointType.__members__.keys()}" + ) + return module + + +def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module: + for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: + blocks: torch.nn.Module = getattr(module, transformer_block_name, None) + if blocks is None: + continue + for index, (layer_id, block) in enumerate(blocks.named_children()): + if n_layer is None or index % n_layer == 0: + block = checkpoint_wrapper(block, preserve_rng_state=False) + blocks.register_module(layer_id, block) + return module + + +def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0) + return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = collections.defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py b/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py new file mode 100644 index 0000000000000000000000000000000000000000..64a1ed0754116615cc885202da218da920ce52fb --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py @@ -0,0 +1,20 @@ +import argparse +from typing import TYPE_CHECKING, Any, Dict + + +if TYPE_CHECKING: + from finetrainers.args import BaseArgs + + +class ArgsConfigMixin: + def add_args(self, parser: argparse.ArgumentParser): + raise NotImplementedError("ArgsConfigMixin::add_args should be implemented by subclasses.") + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + raise NotImplementedError("ArgsConfigMixin::map_args should be implemented by subclasses.") + + def validate_args(self, args: "BaseArgs"): + raise NotImplementedError("ArgsConfigMixin::validate_args should be implemented by subclasses.") + + def to_dict(self) -> Dict[str, Any]: + return {} diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/data.py b/docs/finetrainers-src-codebase/finetrainers/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ecebdcf90b5d1ff719f2d5b18d5946bf47bde97a --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/data.py @@ -0,0 +1,51 @@ +from pathlib import Path +from typing import Any, Union + +import torch + +from finetrainers.constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME +from finetrainers.logging import get_logger + + +logger = get_logger() + + +def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool: + if isinstance(precomputation_dir, str): + precomputation_dir = Path(precomputation_dir) + conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + if conditions_dir.exists() and latents_dir.exists(): + num_files_conditions = len(list(conditions_dir.glob("*.pt"))) + num_files_latents = len(list(latents_dir.glob("*.pt"))) + if num_files_conditions != num_files_latents: + logger.warning( + f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})." + f"Cleaning up precomputed directories and re-running precomputation." + ) + # clean up precomputed directories + for file in conditions_dir.glob("*.pt"): + file.unlink() + for file in latents_dir.glob("*.pt"): + file.unlink() + return True + if num_files_conditions > 0: + logger.info(f"Found {num_files_conditions} precomputed conditions and latents.") + return False + logger.info("Precomputed data not found. Running precomputation.") + return True + + +def determine_batch_size(x: Any) -> int: + if isinstance(x, list): + return len(x) + if isinstance(x, torch.Tensor): + return x.size(0) + if isinstance(x, dict): + for key in x: + try: + return determine_batch_size(x[key]) + except ValueError: + pass + return 1 + raise ValueError("Could not determine batch size from input.") diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py b/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed3746c160b7aa1ea96fb382ccbece85db6ae42 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py @@ -0,0 +1,152 @@ +import math +from typing import Optional, Union + +import torch +from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.training_utils import compute_loss_weighting_for_sd3 + + +# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47 +def resolution_dependent_timestep_flow_shift( + latents: torch.Tensor, + sigmas: torch.Tensor, + base_image_seq_len: int = 256, + max_image_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> torch.Tensor: + image_or_video_sequence_length = 0 + if latents.ndim == 4: + image_or_video_sequence_length = latents.shape[2] * latents.shape[3] + elif latents.ndim == 5: + image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4] + else: + raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor") + + m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) + b = base_shift - m * base_image_seq_len + mu = m * image_or_video_sequence_length + b + sigmas = default_flow_shift(latents, sigmas, shift=mu) + return sigmas + + +def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor: + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + return sigmas + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + r""" + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device=device, generator=generator) + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device=device, generator=generator) + return u + + +def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return None + elif isinstance(scheduler, CogVideoXDDIMScheduler): + return scheduler.alphas_cumprod.clone() + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return scheduler.sigmas.clone() + elif isinstance(scheduler, CogVideoXDDIMScheduler): + return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps) + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def prepare_sigmas( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + sigmas: torch.Tensor, + batch_size: int, + num_train_timesteps: int, + flow_weighting_scheme: str = "none", + flow_logit_mean: float = 0.0, + flow_logit_std: float = 1.0, + flow_mode_scale: float = 1.29, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + weights = compute_density_for_timestep_sampling( + weighting_scheme=flow_weighting_scheme, + batch_size=batch_size, + logit_mean=flow_logit_mean, + logit_std=flow_logit_std, + mode_scale=flow_mode_scale, + device=device, + generator=generator, + ) + indices = (weights * num_train_timesteps).long() + elif isinstance(scheduler, CogVideoXDDIMScheduler): + # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes. + weights = torch.rand(size=(batch_size,), device=device, generator=generator) + indices = (weights * num_train_timesteps).long() + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + return sigmas[indices] + + +def prepare_loss_weights( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + alphas: Optional[torch.Tensor] = None, + sigmas: Optional[torch.Tensor] = None, + flow_weighting_scheme: str = "none", +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme) + elif isinstance(scheduler, CogVideoXDDIMScheduler): + # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas). + # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results. + return 1 / (1 - alphas) + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + +def prepare_target( + scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler], + noise: torch.Tensor, + latents: torch.Tensor, +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + target = noise - latents + elif isinstance(scheduler, CogVideoXDDIMScheduler): + target = latents + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + return target + + +def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False): + if hasattr(vae, "enable_slicing") and enable_slicing: + vae.enable_slicing() + if hasattr(vae, "enable_tiling") and enable_tiling: + vae.enable_tiling() diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/file.py b/docs/finetrainers-src-codebase/finetrainers/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8308b32c5be0427e7c1b8fd3d1f978aad002a1 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/file.py @@ -0,0 +1,51 @@ +import pathlib +import shutil +from pathlib import Path +from typing import List, Union + +from finetrainers.logging import get_logger + + +logger = get_logger() + + +def find_files(root: str, pattern: str, depth: int = 0) -> List[str]: + root_path = pathlib.Path(root) + result_files = [] + + def within_depth(path: pathlib.Path) -> bool: + return len(path.relative_to(root_path).parts) <= depth + + if depth == 0: + result_files.extend([str(file) for file in root_path.glob(pattern)]) + else: + for file in root_path.rglob(pattern): + if not file.is_file() or not within_depth(file.parent): + continue + result_files.append(str(file)) + + return result_files + + +def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: + if not isinstance(dirs, list): + dirs = [dirs] + dirs = [Path(d) if isinstance(d, str) else d for d in dirs] + logger.debug(f"Deleting files: {dirs}") + for dir in dirs: + if not dir.exists(): + continue + shutil.rmtree(dir, ignore_errors=True) + + +def string_to_filename(s: str) -> str: + return ( + s.replace(" ", "-") + .replace("/", "-") + .replace(":", "-") + .replace(".", "-") + .replace(",", "-") + .replace(";", "-") + .replace("!", "-") + .replace("?", "-") + ) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/hub.py b/docs/finetrainers-src-codebase/finetrainers/utils/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1a16eb42cbb1f2848376440817a3e1680ce61c --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/hub.py @@ -0,0 +1,77 @@ +import os +from typing import List, Union + +import numpy as np +import wandb +from diffusers.utils import export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from PIL import Image + + +def save_model_card( + args, + repo_id: str, + videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]], + validation_prompts: List[str], + fps: int = 30, +) -> None: + widget_dict = [] + output_dir = str(args.output_dir) + if videos is not None and len(videos) > 0: + for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)): + if not isinstance(video, str): + export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"}, + } + ) + + model_description = f""" +# LoRA Finetune + + + +## Model description + +This is a lora finetune of model: `{args.pretrained_model_name_or_path}`. + +The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +TODO +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. +""" + if wandb.run.url: + model_description += f""" +Find out the wandb run URL and training configurations [here]({wandb.run.url}). +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + base_model=args.pretrained_model_name_or_path, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(args.output_dir, "README.md")) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py b/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ac2470313ca116ec1dd6ec88cfd36fadd28b0b --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py @@ -0,0 +1,129 @@ +import importlib +import importlib.util +import operator as op +from typing import Union + +import importlib_metadata +from packaging.version import Version, parse + +from finetrainers.logging import get_logger + + +logger = get_logger() + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + + +# This function was copied from: https://github.com/huggingface/diffusers/blob/5873377a660dac60a6bd86ef9b4fdfc385305977/src/diffusers/utils/import_utils.py#L57 +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Compares a library version to some requirement using a given operation. + + Args: + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") +_datasets_available, _datasets_version = _is_package_available("datasets") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") +_kornia_available, _kornia_version = _is_package_available("kornia") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_torch_available, _torch_version = _is_package_available("torch") +_xformers_available, _xformers_version = _is_package_available("xformers") + + +def is_bitsandbytes_available(): + return _bitsandbytes_available + + +def is_datasets_available(): + return _datasets_available + + +def is_flash_attn_available(): + return _flash_attn_available + + +def is_kornia_available(): + return _kornia_available + + +def is_sageattention_available(): + return _sageattention_available + + +def is_torch_available(): + return _torch_available + + +def is_xformers_available(): + return _xformers_available + + +def is_bitsandbytes_version(operation: str, version: str): + if not _bitsandbytes_available: + return False + return compare_versions(parse(_bitsandbytes_version), operation, version) + + +def is_datasets_version(operation: str, version: str): + if not _datasets_available: + return False + return compare_versions(parse(_datasets_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + +def is_kornia_version(operation: str, version: str): + if not _kornia_available: + return False + return compare_versions(parse(_kornia_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_torch_version(operation: str, version: str): + if not _torch_available: + return False + return compare_versions(parse(_torch_version), operation, version) + + +def is_xformers_version(operation: str, version: str): + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/memory.py b/docs/finetrainers-src-codebase/finetrainers/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..99f5239f32b02ce78e8f38d0212ac58378fc2772 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/memory.py @@ -0,0 +1,59 @@ +import gc +from typing import Any, Dict, Union + +import torch + +from finetrainers.logging import get_logger + + +logger = get_logger() + + +def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: + memory_allocated = None + memory_reserved = None + max_memory_allocated = None + max_memory_reserved = None + + if torch.cuda.is_available(): + device = torch.cuda.current_device() + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + elif torch.backends.mps.is_available(): + memory_allocated = torch.mps.current_allocated_memory() + + else: + logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.") + + return { + "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), + "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), + "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), + "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), + } + + +def bytes_to_gigabytes(x: int) -> float: + if x is not None: + return x / 1024**3 + + +def free_memory() -> None: + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + # TODO(aryan): handle non-cuda devices + + +def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if isinstance(x, torch.Tensor): + return x.contiguous() + elif isinstance(x, dict): + return {k: make_contiguous(v) for k, v in x.items()} + else: + return x diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/model.py b/docs/finetrainers-src-codebase/finetrainers/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4427f97d25ed44b2d9832cf456b082f65d66c2a8 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/model.py @@ -0,0 +1,32 @@ +import importlib +import json +import os +from typing import Optional + +from huggingface_hub import hf_hub_download + + +def resolve_component_cls( + pretrained_model_name_or_path: str, + component_name: str, + filename: str = "model_index.json", + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path): + index_path = os.path.join(pretrained_model_name_or_path, filename) + else: + index_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir + ) + + with open(index_path, "r") as f: + model_index_dict = json.load(f) + + if component_name not in model_index_dict: + raise ValueError(f"No {component_name} found in the model index dict.") + + cls_config = model_index_dict[component_name] + library = importlib.import_module(cls_config[0]) + return getattr(library, cls_config[1]) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py b/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..d15b53ae28253cce99d13924885f5f9af7f1ff20 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py @@ -0,0 +1,10 @@ +from typing import Any, Dict, Optional + +import safetensors.torch + + +def safetensors_torch_save_function(weights: Dict[str, Any], filename: str, metadata: Optional[Dict[str, str]] = None): + if metadata is None: + metadata = {} + metadata["format"] = "pt" + safetensors.torch.save_file(weights, filename, metadata) diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/timing.py b/docs/finetrainers-src-codebase/finetrainers/utils/timing.py new file mode 100644 index 0000000000000000000000000000000000000000..99faf75770ae24fc7ad4dd863add4ea698b7968b --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/timing.py @@ -0,0 +1,108 @@ +import time +from dataclasses import dataclass +from enum import Enum + +import torch + +from finetrainers.constants import FINETRAINERS_ENABLE_TIMING +from finetrainers.logging import get_logger + + +logger = get_logger() + + +class TimerDevice(str, Enum): + CPU = "cpu" + CUDA = "cuda" + + +@dataclass +class TimerData: + name: str + device: TimerDevice + start_time: float = 0.0 + end_time: float = 0.0 + + +class Timer: + def __init__(self, name: str, device: TimerDevice, device_sync: bool = False): + self.data = TimerData(name=name, device=device) + + self._device_sync = device_sync + self._start_event = None + self._end_event = None + self._active = False + self._enabled = FINETRAINERS_ENABLE_TIMING + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end() + return False + + def start(self): + if self._active: + logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.") + return + self._active = True + if not self._enabled: + return + if self.data.device == TimerDevice.CUDA and torch.cuda.is_available(): + self._start_cuda() + else: + self._start_cpu() + if not self.data.device == TimerDevice.CPU: + logger.warning( + f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU." + ) + + def end(self): + if not self._active: + logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.") + return + self._active = False + if not self._enabled: + return + if self.data.device == TimerDevice.CUDA and torch.cuda.is_available(): + self._end_cuda() + else: + self._end_cpu() + if not self.data.device == TimerDevice.CPU: + logger.warning( + f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU." + ) + + @property + def elapsed_time(self) -> float: + if self._active: + if self.data.device == TimerDevice.CUDA and torch.cuda.is_available(): + premature_end_event = torch.cuda.Event(enable_timing=True) + premature_end_event.record() + premature_end_event.synchronize() + return self._start_event.elapsed_time(premature_end_event) / 1000.0 + else: + return time.time() - self.data.start_time + else: + if self.data.device == TimerDevice.CUDA and torch.cuda.is_available(): + return self._start_event.elapsed_time(self._end_event) / 1000.0 + else: + return self.data.end_time - self.data.start_time + + def _start_cpu(self): + self.data.start_time = time.time() + + def _start_cuda(self): + torch.cuda.synchronize() + self._start_event = torch.cuda.Event(enable_timing=True) + self._end_event = torch.cuda.Event(enable_timing=True) + self._start_event.record() + + def _end_cpu(self): + self.data.end_time = time.time() + + def _end_cuda(self): + if self._device_sync: + torch.cuda.synchronize() + self._end_event.record() diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/torch.py b/docs/finetrainers-src-codebase/finetrainers/utils/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf064db8c923bd0e345651d3463bcb9caa3baa6 --- /dev/null +++ b/docs/finetrainers-src-codebase/finetrainers/utils/torch.py @@ -0,0 +1,395 @@ +import math +import os +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.backends +import torch.distributed as dist +import torch.distributed.tensor + +from finetrainers.logging import get_logger + + +logger = get_logger() + +_STRING_TO_DTYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()} + +_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False + + +def align_device_and_dtype( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + if isinstance(x, torch.Tensor): + if device is not None: + x = x.to(device) + if dtype is not None: + x = x.to(dtype) + elif isinstance(x, dict): + if device is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + if dtype is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + return x + + +def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module: + r"""Apply torch.compile to a model or its submodules if not already compiled.""" + if getattr(model, "_torch_compiled", False): + return model # Already compiled + + if compile_scope == "full": + model = torch.compile(model) + setattr(model, "_torch_compiled", True) + elif compile_scope == "regional": + if isinstance(model, torch.nn.ModuleList): + for name, module in model.named_children(): + if not getattr(module, "_torch_compiled", False): + compiled_module = torch.compile(module) + setattr(compiled_module, "_torch_compiled", True) + model.register_module(name, compiled_module) + else: + for name, module in model.named_children(): + apply_compile(module, compile_scope) + else: + raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.") + + return model + + +def _clip_grad_norm_while_handling_failing_dtensor_cases( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, +) -> Optional[torch.Tensor]: + global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES + + if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES: + try: + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh) + except NotImplementedError as e: + if "DTensor does not support cross-mesh operation" in str(e): + # https://github.com/pytorch/pytorch/issues/134212 + logger.warning( + "DTensor does not support cross-mesh operation. If you haven't fully tensor-parallelized your " + "model, while combining other parallelisms such as FSDP, it could be the reason for this error. " + "Gradient clipping will be skipped and gradient norm will not be logged." + ) + except Exception as e: + logger.warning( + f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and gradient " + f"norm will not be logged." + ) + _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True + return None + + +# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362 +@torch.no_grad() +def clip_grad_norm_( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, +) -> torch.Tensor: + r""" + Clip the gradient norm of parameters. + + Gradient norm clipping requires computing the gradient norm over the entire model. + `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. + We need to manually reduce the gradient norm across PP stages. + See https://github.com/pytorch/torchtitan/issues/596 for details. + + Args: + parameters (`torch.Tensor` or `List[torch.Tensor]`): + Tensors that will have gradients normalized. + max_norm (`float`): + Maximum norm of the gradients after clipping. + norm_type (`float`, defaults to `2.0`): + Type of p-norm to use. Can be `inf` for infinity norm. + error_if_nonfinite (`bool`, defaults to `False`): + If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`. + foreach (`bool`, defaults to `None`): + Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors + and silently fall back to the slow implementation for other device types. + pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`): + Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages. + + Returns: + `torch.Tensor`: + Total norm of the gradients + """ + grads = [p.grad for p in parameters if p.grad is not None] + + # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm` + # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + + # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`. + # We can simply reduce the DTensor to get the total norm in this tensor's process group + # and then convert it to a local tensor. + # It has two purposes: + # 1. to make sure the total norm is computed correctly when PP is used (see below) + # 2. to return a reduced total_norm tensor whose .item() would return the correct value + if isinstance(total_norm, torch.distributed.tensor.DTensor): + # Will reach here if any non-PP parallelism is used. + # If only using PP, total_norm will be a local tensor. + total_norm = total_norm.full_tensor() + + if pp_mesh is not None: + if math.isinf(norm_type): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= norm_type + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / norm_type + + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +def enable_determinism( + seed: int, + world_mesh: Optional[torch.distributed.DeviceMesh] = None, + deterministic: bool = False, +) -> None: + r""" + For all ranks within the same DTensor SPMD group, the same seed will be set. + For PP groups, different seeds will be set. + """ + if deterministic: + logger.info("Deterministic algorithms are enabled (expect performance degradation).") + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + if not world_mesh: + if seed is not None: + torch.manual_seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed % 2**32) + logger.debug(f"Single-process job using seed: {seed}") + return + + # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, + # and choose a unique seed for each rank on the PP mesh. + if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names: + pp_mesh = world_mesh["pp"] + seed += pp_mesh.get_local_rank() + seed %= 2**64 + + info = { + "pp_rank": pp_mesh.get_local_rank(), + "global_rank": torch.distributed.distributed_c10d.get_rank(), + "seed": seed, + } + logger.debug(f"Enabling determinism: {info}") + spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names)) + spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None + else: + spmd_mesh = world_mesh + info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed} + logger.debug(f"Enabling determinism: {info}") + + # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency + torch.manual_seed(seed) + # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1] + os.environ["PYTHONHASHSEED"] = str(seed % 2**32) + + # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. + # IF PP is also used, this seed is unique per PP rank. + if spmd_mesh and spmd_mesh.get_coordinate() is not None: + torch.distributed.tensor._random.manual_seed(seed, spmd_mesh) + + +def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor: + assert len(tensor.shape) <= ndim + return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape))) + + +def get_device_info(): + from torch._utils import _get_available_device_type, _get_device_module + + device_type = _get_available_device_type() + if device_type is None: + device_type = "cuda" + device_module = _get_device_module(device_type) + return device_type, device_module + + +def get_dtype_from_string(dtype: str): + return _STRING_TO_DTYPE[dtype] + + +def get_string_from_dtype(dtype: torch.dtype): + return _DTYPE_TO_STRING[dtype] + + +def get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + assert name.count("*") <= 1, "Wildcard '*' can only be used once in the name" + return _find_submodule_by_name(model, name) + + +def get_unwrapped_model_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + # Remove _orig_mod occurrences from the state dict keys + return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()} + + +def is_compiled_module(module) -> bool: + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None: + if isinstance(models, torch.nn.Module): + models = [models] + for model in models: + if model is not None: + model.requires_grad_(value) + + +def synchronize_device() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif torch.backends.mps.is_available(): + torch.mps.synchronize() + + +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + +def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: + if name == "": + return model + first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "") + if first_atom == "*": + # Wildcard '*' can only be used once in the name + assert isinstance(model, torch.nn.ModuleList), "Wildcard '*' can only be used with ModuleList" + submodules = [] + for submodule in model: + subsubmodules = _find_submodule_by_name(submodule, remaining_name) + if not isinstance(subsubmodules, list): + subsubmodules = [subsubmodules] + submodules.extend(subsubmodules) + return submodules + else: + if hasattr(model, first_atom): + submodule = getattr(model, first_atom) + return _find_submodule_by_name(submodule, remaining_name) + else: + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") + + +# TODO(aryan): remove everything below this after next torch release +def _get_total_norm( + tensors: Union[torch.Tensor, List[torch.Tensor]], + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + tensors = list(tensors) + norm_type = float(norm_type) + if len(tensors) == 0: + return torch.tensor(0.0) + first_device = tensors[0].device + grouped_tensors: dict[tuple[torch.device, torch.dtype], tuple[list[list[torch.Tensor]], list[int]]] = ( + _group_tensors_by_device_and_dtype( + [tensors] # type: ignore[list-item] + ) + ) # type: ignore[assignment] + + norms: List[torch.Tensor] = [] + for (device, _), ([device_tensors], _) in grouped_tensors.items(): + if (foreach is None and _has_foreach_support(device_tensors, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_tensors, norm_type)) + elif foreach: + raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors]) + + total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + return total_norm + + +@torch.no_grad() +def _clip_grads_with_norm_( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + total_norm: torch.Tensor, + foreach: Optional[bool] = None, +) -> None: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + if len(grads) == 0: + return + grouped_grads: dict[Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]]] = ( + _group_tensors_by_device_and_dtype([grads]) + ) # type: ignore[assignment] + + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + +def _get_foreach_kernels_supported_devices() -> list[str]: + r"""Return the device type list that supports foreach kernels.""" + return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] + + +@torch.no_grad() +def _group_tensors_by_device_and_dtype( + tensorlistlist: List[List[Optional[torch.Tensor]]], + with_indices: bool = False, +) -> dict[tuple[torch.device, torch.dtype], tuple[List[List[Optional[torch.Tensor]]], List[int]]]: + return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) + + +def _device_has_foreach_support(device: torch.device) -> bool: + return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() + + +def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool: + return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors) diff --git a/docs/finetrainers-src-codebase/pyproject.toml b/docs/finetrainers-src-codebase/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..79d64d3b5d8c741d94ebea0519d1f518dcfdf473 --- /dev/null +++ b/docs/finetrainers-src-codebase/pyproject.toml @@ -0,0 +1,28 @@ +[tool.ruff] +line-length = 119 + +[tool.ruff.lint] +# Never enforce `E501` (line length violations). +ignore = ["C901", "E501", "E741", "F402", "F823"] +select = ["C", "E", "F", "I", "W"] + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = [] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/docs/finetrainers-src-codebase/requirements.txt b/docs/finetrainers-src-codebase/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e5b32bd2364397588002c81b44d9226ed2c9ec9a --- /dev/null +++ b/docs/finetrainers-src-codebase/requirements.txt @@ -0,0 +1,20 @@ +accelerate +bitsandbytes +datasets>=3.3.2 +diffusers>=0.32.1 +transformers>=4.45.2 +huggingface_hub +hf_transfer>=0.1.8 +peft>=0.13.0 +decord>=0.6.0 +wandb +pandas +torch>=2.5.1 +torchvision>=0.20.1 +torchdata>=0.10.1 +torchao>=0.7.0 +sentencepiece>=0.2.0 +imageio-ffmpeg>=0.5.1 +numpy>=1.26.4 +kornia>=0.7.3 +ruff==0.9.10 diff --git a/docs/finetrainers-src-codebase/setup.py b/docs/finetrainers-src-codebase/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4cad38e51b1df8178a13f363c4caafdb5917dc --- /dev/null +++ b/docs/finetrainers-src-codebase/setup.py @@ -0,0 +1,46 @@ +from setuptools import find_packages, setup + + +with open("README.md", "r", encoding="utf-8") as file: + long_description = file.read() + +with open("requirements.txt", "r", encoding="utf-8") as file: + requirements = [line for line in file.read().splitlines() if len(line) > 0] + +setup( + name="finetrainers", + version="0.2.0.dev0", + description="Finetrainers is a work-in-progress library to support (accessible) training of diffusion models", + long_description=long_description, + long_description_content_type="text/markdown", + author="Aryan V S", + author_email="contact.aryanvs@gmail.com", + url="https://github.com/a-r-r-o-w/finetrainers", + python_requires=">=3.8.0", + license="Apache-2.0", + packages=find_packages(), + install_requires=requirements, + extras_require={"dev": ["pytest==8.3.2", "ruff==0.1.5"]}, + classifiers=[ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Operating System :: Microsoft :: Windows", + "Operating System :: Unix", + "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) + +# Steps to publish: +# 1. Update version in setup.py +# 2. python setup.py sdist bdist_wheel +# 3. Check if everything works with testpypi: +# twine upload --repository testpypi dist/* +# 4. Upload to pypi: +# twine upload dist/* diff --git a/docs/finetrainers-src-codebase/tests/README.md b/docs/finetrainers-src-codebase/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ef26d6044a2e628a98baf1e5f22a162a3e3d7d1 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/README.md @@ -0,0 +1,53 @@ +# Running tests + +TODO(aryan): everything here needs to be improved. + +## `trainer/` fast tests + +- For SFT tests: `test_sft_trainer.py` +- For Control tests: `test_control_trainer.py` + +Accelerate: + +``` +# world_size=1 tests +accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___Accelerate" +accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___Accelerate" + +# world_size=2 tests +accelerate launch --config_file accelerate_configs/uncompiled_2.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___Accelerate" +``` + +PTD: + +``` +# world_size=1 tests +torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_2 and ___PTD" + +# world_size=2 tests +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_2___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_2 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_2 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___tp_degree_2___batch_size_2 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___cp_degree_2___batch_size_1 and ___PTD" + +# world_size=4 tests +torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___dp_shards_2___batch_size_1 and ___PTD" +torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___cp_degree_2___batch_size_1 and ___PTD" +``` + +## CP tests + +PTD: + +``` +# world_size=2 tests +torchrun --nnodes 1 --nproc_per_node 2 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP2Test + +# world_size=4 tests +torchrun --nnodes 1 --nproc_per_node 4 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP4Test +``` diff --git a/docs/finetrainers-src-codebase/tests/__init__.py b/docs/finetrainers-src-codebase/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/_test_dataset_old.py b/docs/finetrainers-src-codebase/tests/_test_dataset_old.py new file mode 100644 index 0000000000000000000000000000000000000000..740a9c91e710b1b9cfaebf74b43071cd837acdb9 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/_test_dataset_old.py @@ -0,0 +1,104 @@ +# Run: python3 tests/test_dataset.py + +import sys + + +def test_video_dataset(): + from cogvideox.dataset import VideoDataset + + dataset_dirs = VideoDataset( + data_root="assets/tests/", + caption_column="prompts.txt", + video_column="videos.txt", + max_num_frames=49, + id_token=None, + random_flip=None, + ) + dataset_csv = VideoDataset( + data_root="assets/tests/", + dataset_file="assets/tests/metadata.csv", + caption_column="caption", + video_column="video", + max_num_frames=49, + id_token=None, + random_flip=None, + ) + + assert len(dataset_dirs) == 1 + assert len(dataset_csv) == 1 + assert dataset_dirs[0]["video"].shape == (49, 3, 480, 720) + assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all() + + print(dataset_dirs[0]["video"].shape) + + +def test_video_dataset_with_resizing(): + from cogvideox.dataset import VideoDatasetWithResizing + + dataset_dirs = VideoDatasetWithResizing( + data_root="assets/tests/", + caption_column="prompts.txt", + video_column="videos.txt", + max_num_frames=49, + id_token=None, + random_flip=None, + ) + dataset_csv = VideoDatasetWithResizing( + data_root="assets/tests/", + dataset_file="assets/tests/metadata.csv", + caption_column="caption", + video_column="video", + max_num_frames=49, + id_token=None, + random_flip=None, + ) + + assert len(dataset_dirs) == 1 + assert len(dataset_csv) == 1 + assert dataset_dirs[0]["video"].shape == (48, 3, 480, 720) # Changes due to T2V frame bucket sampling + assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all() + + print(dataset_dirs[0]["video"].shape) + + +def test_video_dataset_with_bucket_sampler(): + import torch + from cogvideox.dataset import BucketSampler, VideoDatasetWithResizing + from torch.utils.data import DataLoader + + dataset_dirs = VideoDatasetWithResizing( + data_root="assets/tests/", + caption_column="prompts_multi.txt", + video_column="videos_multi.txt", + max_num_frames=49, + id_token=None, + random_flip=None, + ) + sampler = BucketSampler(dataset_dirs, batch_size=8) + + def collate_fn(data): + captions = [x["prompt"] for x in data[0]] + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos) + return captions, videos + + dataloader = DataLoader(dataset_dirs, batch_size=1, sampler=sampler, collate_fn=collate_fn) + first = False + + for captions, videos in dataloader: + if not first: + assert len(captions) == 8 and isinstance(captions[0], str) + assert videos.shape == (8, 48, 3, 480, 720) + first = True + else: + assert len(captions) == 8 and isinstance(captions[0], str) + assert videos.shape == (8, 48, 3, 256, 360) + break + + +if __name__ == "__main__": + sys.path.append("./training") + + test_video_dataset() + test_video_dataset_with_resizing() + test_video_dataset_with_bucket_sampler() diff --git a/docs/finetrainers-src-codebase/tests/data/__init__.py b/docs/finetrainers-src-codebase/tests/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/data/test_dataset.py b/docs/finetrainers-src-codebase/tests/data/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..66186aba4718320767476b46486e27716062c640 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/data/test_dataset.py @@ -0,0 +1,355 @@ +import pathlib +import tempfile +import unittest + +import torch +from PIL import Image + +from finetrainers.data import ( + ImageCaptionFilePairDataset, + ImageFileCaptionFileListDataset, + ImageFolderDataset, + ValidationDataset, + VideoCaptionFilePairDataset, + VideoFileCaptionFileListDataset, + VideoFolderDataset, + VideoWebDataset, + initialize_dataset, +) +from finetrainers.utils import find_files + +from .utils import create_dummy_directory_structure + + +class DatasetTesterMixin: + num_data_files = None + directory_structure = None + caption = "A cat ruling the world" + metadata_extension = None + + def setUp(self): + if self.num_data_files is None: + raise ValueError("num_data_files is not defined") + if self.directory_structure is None: + raise ValueError("dataset_structure is not defined") + + self.tmpdir = tempfile.TemporaryDirectory() + create_dummy_directory_structure( + self.directory_structure, self.tmpdir, self.num_data_files, self.caption, self.metadata_extension + ) + + def tearDown(self): + self.tmpdir.cleanup() + + +class ImageDatasetTesterMixin(DatasetTesterMixin): + metadata_extension = "jpg" + + +class VideoDatasetTesterMixin(DatasetTesterMixin): + metadata_extension = "mp4" + + +class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "0.jpg", + "1.jpg", + "2.jpg", + "0.txt", + "1.txt", + "2.txt", + ] + + def setUp(self): + super().setUp() + self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(self.num_data_files): + item = next(iterator) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["image"])) + self.assertEqual(item["image"].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False) + self.assertIsInstance(dataset, ImageCaptionFilePairDataset) + + +class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "prompts.txt", + "images.txt", + "images/", + "images/0.jpg", + "images/1.jpg", + "images/2.jpg", + ] + + def setUp(self): + super().setUp() + self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for i in range(3): + item = next(iterator) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["image"])) + self.assertEqual(item["image"].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False) + self.assertIsInstance(dataset, ImageFileCaptionFileListDataset) + + +class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "metadata.csv", + "0.jpg", + "1.jpg", + "2.jpg", + ] + + def setUp(self): + super().setUp() + self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(3): + item = next(iterator) + self.assertIn("caption", item) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["image"])) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False) + self.assertIsInstance(dataset, ImageFolderDataset) + + +class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "metadata.jsonl", + "0.jpg", + "1.jpg", + "2.jpg", + ] + + def setUp(self): + super().setUp() + self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(3): + item = next(iterator) + self.assertIn("caption", item) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["image"])) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False) + self.assertIsInstance(dataset, ImageFolderDataset) + + +class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "0.mp4", + "1.mp4", + "2.mp4", + "0.txt", + "1.txt", + "2.txt", + ] + + def setUp(self): + super().setUp() + self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(self.num_data_files): + item = next(iterator) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["video"])) + self.assertEqual(len(item["video"]), 4) + self.assertEqual(item["video"][0].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False) + self.assertIsInstance(dataset, VideoCaptionFilePairDataset) + + +class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "prompts.txt", + "videos.txt", + "videos/", + "videos/0.mp4", + "videos/1.mp4", + "videos/2.mp4", + ] + + def setUp(self): + super().setUp() + self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(3): + item = next(iterator) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["video"])) + self.assertEqual(len(item["video"]), 4) + self.assertEqual(item["video"][0].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False) + self.assertIsInstance(dataset, VideoFileCaptionFileListDataset) + + +class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "metadata.csv", + "0.mp4", + "1.mp4", + "2.mp4", + ] + + def setUp(self): + super().setUp() + self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(3): + item = next(iterator) + self.assertIn("caption", item) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["video"])) + self.assertEqual(len(item["video"]), 4) + self.assertEqual(item["video"][0].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False) + self.assertIsInstance(dataset, VideoFolderDataset) + + +class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase): + num_data_files = 3 + directory_structure = [ + "metadata.jsonl", + "0.mp4", + "1.mp4", + "2.mp4", + ] + + def setUp(self): + super().setUp() + self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False) + + def test_getitem(self): + iterator = iter(self.dataset) + for _ in range(3): + item = next(iterator) + self.assertIn("caption", item) + self.assertEqual(item["caption"], self.caption) + self.assertTrue(torch.is_tensor(item["video"])) + self.assertEqual(len(item["video"]), 4) + self.assertEqual(item["video"][0].shape, (3, 64, 64)) + + def test_initialize_dataset(self): + dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False) + self.assertIsInstance(dataset, VideoFolderDataset) + + +class ImageWebDatasetFastTests(unittest.TestCase): + # TODO(aryan): setup a dummy dataset + pass + + +class VideoWebDatasetFastTests(unittest.TestCase): + def setUp(self): + self.num_data_files = 15 + self.dataset = VideoWebDataset("finetrainers/dummy-squish-wds", infinite=False) + + def test_getitem(self): + for index, item in enumerate(self.dataset): + if index > 2: + break + self.assertIn("caption", item) + self.assertIn("video", item) + self.assertTrue(torch.is_tensor(item["video"])) + self.assertEqual(len(item["video"]), 121) + self.assertEqual(item["video"][0].shape, (3, 720, 1280)) + + def test_initialize_dataset(self): + dataset = initialize_dataset("finetrainers/dummy-squish-wds", "video", infinite=False) + self.assertIsInstance(dataset, VideoWebDataset) + + +class DatasetUtilsFastTests(unittest.TestCase): + def test_find_files_depth_0(self): + with tempfile.TemporaryDirectory() as tmpdir: + file1 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) + file2 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) + file3 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False) + + files = find_files(tmpdir, "*.txt") + self.assertEqual(len(files), 3) + self.assertIn(file1.name, files) + self.assertIn(file2.name, files) + self.assertIn(file3.name, files) + + def test_find_files_depth_n(self): + with tempfile.TemporaryDirectory() as tmpdir: + dir1 = tempfile.TemporaryDirectory(dir=tmpdir) + dir2 = tempfile.TemporaryDirectory(dir=dir1.name) + file1 = tempfile.NamedTemporaryFile(dir=dir1.name, suffix=".txt", delete=False) + file2 = tempfile.NamedTemporaryFile(dir=dir2.name, suffix=".txt", delete=False) + + files = find_files(tmpdir, "*.txt", depth=1) + self.assertEqual(len(files), 1) + self.assertIn(file1.name, files) + self.assertNotIn(file2.name, files) + + files = find_files(tmpdir, "*.txt", depth=2) + self.assertEqual(len(files), 2) + self.assertIn(file1.name, files) + self.assertIn(file2.name, files) + self.assertNotIn(dir1.name, files) + self.assertNotIn(dir2.name, files) + + +class ValidationDatasetFastTests(unittest.TestCase): + def setUp(self): + num_data_files = 3 + + self.tmpdir = tempfile.TemporaryDirectory() + metadata_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv" + + with open(metadata_filename, "w") as f: + f.write("caption,image_path,video_path\n") + for i in range(num_data_files): + Image.new("RGB", (64, 64)).save((pathlib.Path(self.tmpdir.name) / f"{i}.jpg").as_posix()) + f.write(f"test caption,{self.tmpdir.name}/{i}.jpg,\n") + + self.dataset = ValidationDataset(metadata_filename.as_posix()) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_getitem(self): + for i, data in enumerate(self.dataset): + self.assertEqual(data["image_path"], f"{self.tmpdir.name}/{i}.jpg") + self.assertIsInstance(data["image"], Image.Image) + self.assertEqual(data["image"].size, (64, 64)) diff --git a/docs/finetrainers-src-codebase/tests/data/test_precomputation.py b/docs/finetrainers-src-codebase/tests/data/test_precomputation.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1879f3c82b2e900ffa229d50b2cfefd506b698 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/data/test_precomputation.py @@ -0,0 +1,212 @@ +import os +import tempfile +import unittest + +from finetrainers.data import ( + InMemoryDistributedDataPreprocessor, + PrecomputedDistributedDataPreprocessor, + VideoCaptionFilePairDataset, + initialize_preprocessor, + wrap_iterable_dataset_for_preprocessing, +) +from finetrainers.data.precomputation import PRECOMPUTED_DATA_DIR +from finetrainers.utils import find_files + +from .utils import create_dummy_directory_structure + + +class PreprocessorFastTests(unittest.TestCase): + def setUp(self): + self.rank = 0 + self.world_size = 1 + self.num_items = 3 + self.processor_fn = { + "latent": self._latent_processor_fn, + "condition": self._condition_processor_fn, + } + self.save_dir = tempfile.TemporaryDirectory() + + directory_structure = [ + "0.mp4", + "1.mp4", + "2.mp4", + "0.txt", + "1.txt", + "2.txt", + ] + create_dummy_directory_structure( + directory_structure, self.save_dir, self.num_items, "a cat ruling the world", "mp4" + ) + + dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True) + dataset = wrap_iterable_dataset_for_preprocessing( + dataset, + dataset_type="video", + config={ + "video_resolution_buckets": [[2, 32, 32]], + "reshape_mode": "bicubic", + }, + ) + self.dataset = dataset + + def tearDown(self): + self.save_dir.cleanup() + + @staticmethod + def _latent_processor_fn(**data): + video = data["video"] + video = video[:, :, :16, :16] + data["video"] = video + return data + + @staticmethod + def _condition_processor_fn(**data): + caption = data["caption"] + caption = caption + " surrounded by mystical aura" + data["caption"] = caption + return data + + def test_initialize_preprocessor(self): + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=False, + ) + self.assertIsInstance(preprocessor, InMemoryDistributedDataPreprocessor) + + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=True, + ) + self.assertIsInstance(preprocessor, PrecomputedDistributedDataPreprocessor) + + def test_in_memory_preprocessor_consume(self): + data_iterator = iter(self.dataset) + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=False, + ) + + condition_iterator = preprocessor.consume( + "condition", components={}, data_iterator=data_iterator, cache_samples=True + ) + latent_iterator = preprocessor.consume( + "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True + ) + + self.assertFalse(preprocessor.requires_data) + for _ in range(self.num_items): + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + self.assertIn("caption", condition_item) + self.assertIn("video", latent_item) + self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura") + self.assertEqual(latent_item["video"].shape[-2:], (16, 16)) + self.assertTrue(preprocessor.requires_data) + + def test_in_memory_preprocessor_consume_once(self): + data_iterator = iter(self.dataset) + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=False, + ) + + condition_iterator = preprocessor.consume_once( + "condition", components={}, data_iterator=data_iterator, cache_samples=True + ) + latent_iterator = preprocessor.consume_once( + "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True + ) + + self.assertFalse(preprocessor.requires_data) + for _ in range(self.num_items): + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + self.assertIn("caption", condition_item) + self.assertIn("video", latent_item) + self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura") + self.assertEqual(latent_item["video"].shape[-2:], (16, 16)) + self.assertFalse(preprocessor.requires_data) + + def test_precomputed_preprocessor_consume(self): + data_iterator = iter(self.dataset) + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=True, + ) + + condition_iterator = preprocessor.consume( + "condition", components={}, data_iterator=data_iterator, cache_samples=True + ) + latent_iterator = preprocessor.consume( + "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True + ) + + precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR) + condition_file_list = find_files(precomputed_data_dir, "condition-*") + latent_file_list = find_files(precomputed_data_dir, "latent-*") + self.assertEqual(len(condition_file_list), 3) + self.assertEqual(len(latent_file_list), 3) + + self.assertFalse(preprocessor.requires_data) + for _ in range(self.num_items): + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + self.assertIn("caption", condition_item) + self.assertIn("video", latent_item) + self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura") + self.assertEqual(latent_item["video"].shape[-2:], (16, 16)) + self.assertTrue(preprocessor.requires_data) + + def test_precomputed_preprocessor_consume_once(self): + data_iterator = iter(self.dataset) + preprocessor = initialize_preprocessor( + self.rank, + self.world_size, + self.num_items, + self.processor_fn, + self.save_dir.name, + enable_precomputation=True, + ) + + condition_iterator = preprocessor.consume_once( + "condition", components={}, data_iterator=data_iterator, cache_samples=True + ) + latent_iterator = preprocessor.consume_once( + "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True + ) + + precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR) + condition_file_list = find_files(precomputed_data_dir, "condition-*") + latent_file_list = find_files(precomputed_data_dir, "latent-*") + self.assertEqual(len(condition_file_list), 3) + self.assertEqual(len(latent_file_list), 3) + + self.assertFalse(preprocessor.requires_data) + for _ in range(self.num_items): + condition_item = next(condition_iterator) + latent_item = next(latent_iterator) + self.assertIn("caption", condition_item) + self.assertIn("video", latent_item) + self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura") + self.assertEqual(latent_item["video"].shape[-2:], (16, 16)) + self.assertFalse(preprocessor.requires_data) diff --git a/docs/finetrainers-src-codebase/tests/data/utils.py b/docs/finetrainers-src-codebase/tests/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4cc1fe38bbd311ce718d5f35f91eb2c12c313e --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/data/utils.py @@ -0,0 +1,53 @@ +import pathlib +from typing import List + +from diffusers.utils import export_to_video +from PIL import Image + +from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES # noqa + + +def create_dummy_directory_structure( + directory_structure: List[str], tmpdir, num_data_files: int, caption: str, metadata_extension: str +): + for item in directory_structure: + # TODO(aryan): this should be improved + if item in COMMON_CAPTION_FILES: + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + for _ in range(num_data_files): + f.write(f"{caption}\n") + elif item in COMMON_IMAGE_FILES: + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + for i in range(num_data_files): + f.write(f"images/{i}.jpg\n") + elif item in COMMON_VIDEO_FILES: + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + for i in range(num_data_files): + f.write(f"videos/{i}.mp4\n") + elif item == "metadata.csv": + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + f.write("file_name,caption\n") + for i in range(num_data_files): + f.write(f"{i}.{metadata_extension},{caption}\n") + elif item == "metadata.jsonl": + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + for i in range(num_data_files): + f.write(f'{{"file_name": "{i}.{metadata_extension}", "caption": "{caption}"}}\n') + elif item.endswith(".txt"): + data_file = pathlib.Path(tmpdir.name) / item + with open(data_file.as_posix(), "w") as f: + f.write(caption) + elif item.endswith(".jpg") or item.endswith(".png"): + data_file = pathlib.Path(tmpdir.name) / item + Image.new("RGB", (64, 64)).save(data_file.as_posix()) + elif item.endswith(".mp4"): + data_file = pathlib.Path(tmpdir.name) / item + export_to_video([Image.new("RGB", (64, 64))] * 4, data_file.as_posix(), fps=2) + else: + data_file = pathlib.Path(tmpdir.name, item) + data_file.mkdir(exist_ok=True, parents=True) diff --git a/docs/finetrainers-src-codebase/tests/models/__init__.py b/docs/finetrainers-src-codebase/tests/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py b/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..1de978d248b2733bab0e049555a78cd3d1832ce0 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py @@ -0,0 +1,363 @@ +import os +import random +import unittest + +import numpy as np +import torch +from torch.nn.functional import scaled_dot_product_attention + +from finetrainers.models.attention_dispatch import ( + AttentionProvider, + _AttentionProviderRegistry, + _set_context_parallel_options, + attention_dispatch, + attention_provider, + flash_attn_flash_attention, + native_cudnn_attention, + native_efficient_attention, + native_flash_attention, +) +from finetrainers.parallel.ptd import _EquipartitionSharder + + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + return int(os.environ.get("WORLD_SIZE", 1)) + + +class AttentionDispatchTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + set_seed(0) + + def test_forward(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available") + cuda_capability = torch.cuda.get_device_capability() + + query, key, value = self._create_dummy_inputs() + + all_providers = [ + (AttentionProvider._NATIVE_MATH, 0), + (AttentionProvider.NATIVE, 5e-3), + (AttentionProvider.FLASH, 5e-3), + (AttentionProvider.FLASH_VARLEN, 5e-3), + (AttentionProvider.FLEX, 2e-2), + (AttentionProvider._NATIVE_CUDNN, 5e-3), + (AttentionProvider._NATIVE_EFFICIENT, 5e-3), + (AttentionProvider._NATIVE_FLASH, 5e-3), + (AttentionProvider.SAGE, 1e-1), + (AttentionProvider.SAGE_VARLEN, 2e-0), + (AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, 2e-0), # TODO: look into the high difference threshold + (AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, 2e-0), + (AttentionProvider.XFORMERS, 5e-3), + ] + + if cuda_capability >= (8, 9): + all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, 2e-0)) + if cuda_capability >= (9, 0): + all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA_SM90, 2e-0)) + + ref_output = None + for i, (provider, threshold) in enumerate(all_providers): + try: + output = self._check_forward_pass(provider, query, key, value) + if i == 0: + ref_output = output.detach().clone() + else: + self.assertTrue( + torch.allclose(output, ref_output, atol=threshold), f"Forward pass mismatch for {provider}" + ) + except Exception as e: + print(f"Warning: Forward pass test failed for {provider} with error: {e}") + + def test_backward(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available") + + query, key, value = self._create_dummy_inputs() + + selected_providers = [ + AttentionProvider.FLASH, + AttentionProvider.FLASH_VARLEN, + AttentionProvider.FLEX, + AttentionProvider.NATIVE, + AttentionProvider.XFORMERS, + ] + + ref_output = None + for i, provider in enumerate(selected_providers): + try: + output = self._check_backward_pass(provider, query, key, value) + if i == 0: + ref_output = output.detach().clone() + else: + if provider == AttentionProvider.FLEX: + threshold = 1e-2 + else: + threshold = 1e-3 + self.assertTrue( + torch.allclose(output, ref_output, atol=threshold), f"Backward pass mismatch for {provider}" + ) + except Exception as e: + print(f"Warning: Backward pass test failed for {provider} with error: {e}") + + def _create_dummy_inputs( + self, batch_size=2, num_heads=8, seq_len=256, head_dim=64, dtype=torch.bfloat16, device="cuda" + ): + torch.manual_seed(0) + query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device) + key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device) + value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device) + return query, key, value + + def _check_forward_pass(self, provider: AttentionProvider, query, key, value): + kwargs = {} + if provider == AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA: + kwargs["pv_accum_dtype"] = "fp32" + with attention_provider(provider): + output = attention_dispatch(query, key, value, attention_kwargs=kwargs) + self.assertIsNotNone(output) + self.assertEqual(output.shape, query.shape) + return output + + def _check_backward_pass(self, provider: AttentionProvider, query, key, value): + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + with attention_provider(provider): + output = attention_dispatch(query, key, value) + loss = output.mean() + loss.backward() + + self.assertTrue(query.grad is not None) + self.assertTrue(key.grad is not None) + self.assertTrue(value.grad is not None) + + query.grad.zero_() + key.grad.zero_() + value.grad.zero_() + return output + + +class RingAttentionTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.distributed.init_process_group(backend="nccl") + rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size() + + cls.rank = rank + cls.world_size = world_size + torch.cuda.set_device(rank) + cls.mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,)) + + set_seed(0) + cls.batch_size = 2 + cls.num_heads = 8 + cls.seq_len = 256 + cls.head_dim = 64 + cls.dtype = torch.bfloat16 + cls.device = "cuda" + + _AttentionProviderRegistry._set_context_parallel( + mesh=cls.mesh, convert_to_fp32=True, rotate_method="allgather" + ) + _set_context_parallel_options(is_causal=False) + + cls.full_query = torch.randn( + cls.batch_size, + cls.num_heads, + cls.seq_len * cls.world_size, + cls.head_dim, + dtype=cls.dtype, + device=cls.device, + requires_grad=True, + ) + cls.full_key = torch.randn( + cls.batch_size, + cls.num_heads, + cls.seq_len * cls.world_size, + cls.head_dim, + dtype=cls.dtype, + device=cls.device, + requires_grad=True, + ) + cls.full_value = torch.randn( + cls.batch_size, + cls.num_heads, + cls.seq_len * cls.world_size, + cls.head_dim, + dtype=cls.dtype, + device=cls.device, + requires_grad=True, + ) + + # Ensure all ranks have the same data + with torch.no_grad(): + torch.distributed.broadcast(cls.full_query, src=0) + torch.distributed.broadcast(cls.full_key, src=0) + torch.distributed.broadcast(cls.full_value, src=0) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + reference_output = scaled_dot_product_attention(cls.full_query, cls.full_key, cls.full_value) + + cls.reference_output = reference_output.detach().clone() + reference_output.sum().backward() + + cls.query, cls.key, cls.value = ( + _EquipartitionSharder.shard(x, dim=2, mesh=cls.mesh).detach().clone() + for x in (cls.full_query, cls.full_key, cls.full_value) + ) + + @classmethod + def tearDownClass(cls): + torch.distributed.destroy_process_group() + + def _test_forward_native_cudnn_attention(self, atol: float = 1e-3): + output = native_cudnn_attention(self.query, self.key, self.value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + self.assertEqual(output.shape, self.reference_output.shape) + self.assertTrue(torch.allclose(output, self.reference_output, atol=atol)) + + def _test_forward_native_efficient_attention(self, atol: float = 1e-3): + output = native_efficient_attention(self.query, self.key, self.value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + self.assertEqual(output.shape, self.reference_output.shape) + self.assertTrue(torch.allclose(output, self.reference_output, atol=atol)) + + def _test_forward_native_flash_attention(self, atol: float = 1e-3): + output = native_flash_attention(self.query, self.key, self.value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + self.assertEqual(output.shape, self.reference_output.shape) + self.assertTrue(torch.allclose(output, self.reference_output, atol=atol)) + + def _test_forward_flash_attn_flash_attention(self, atol: float = 1e-3): + output = flash_attn_flash_attention(self.query, self.key, self.value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + self.assertEqual(output.shape, self.reference_output.shape) + self.assertTrue(torch.allclose(output, self.reference_output, atol=atol)) + + def _test_backward_native_cudnn_attention(self, atol: float = 1e-3): + query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value)) + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + output = native_cudnn_attention(query, key, value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + output.sum().backward() + with torch.no_grad(): + q_g, k_g, v_g = ( + _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh) + for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad) + ) + self.assertTrue(torch.allclose(query.grad, q_g, atol=atol)) + self.assertTrue(torch.allclose(key.grad, k_g, atol=atol)) + self.assertTrue(torch.allclose(value.grad, v_g, atol=atol)) + + def _test_backward_native_efficient_attention(self, atol: float = 1e-3): + query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value)) + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + output = native_efficient_attention(query, key, value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + output.sum().backward() + with torch.no_grad(): + q_g, k_g, v_g = ( + _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh) + for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad) + ) + self.assertTrue(torch.allclose(query.grad, q_g, atol=atol)) + self.assertTrue(torch.allclose(key.grad, k_g, atol=atol)) + self.assertTrue(torch.allclose(value.grad, v_g, atol=atol)) + + def _test_backward_native_flash_attention(self, atol: float = 1e-3): + query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value)) + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + output = native_flash_attention(query, key, value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + output.sum().backward() + with torch.no_grad(): + q_g, k_g, v_g = ( + _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh) + for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad) + ) + self.assertTrue(torch.allclose(query.grad, q_g, atol=atol)) + self.assertTrue(torch.allclose(key.grad, k_g, atol=atol)) + self.assertTrue(torch.allclose(value.grad, v_g, atol=atol)) + + def _test_backward_flash_attn_flash_attention(self, atol: float = 1e-3): + query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value)) + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + output = flash_attn_flash_attention(query, key, value) + output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh) + output.sum().backward() + with torch.no_grad(): + q_g, k_g, v_g = ( + _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh) + for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad) + ) + self.assertTrue(torch.allclose(query.grad, q_g, atol=atol)) + self.assertTrue(torch.allclose(key.grad, k_g, atol=atol)) + self.assertTrue(torch.allclose(value.grad, v_g, atol=atol)) + + +class RingAttentionCPTesterMixin: + def test_forward_native_cudnn_attention(self): + self._test_forward_native_cudnn_attention(atol=1e-2) + + def test_forward_native_efficient_attention(self): + self._test_forward_native_efficient_attention(atol=1e-2) + + def test_forward_native_flash_attention(self): + self._test_forward_native_flash_attention(atol=1e-2) + + def test_forward_flash_attn_flash_attention(self): + self._test_forward_flash_attn_flash_attention(atol=1e-2) + + def test_backward_native_cudnn_attention(self): + atol = 1e-2 * self.world_size # TODO: make bounds more strict + self._test_backward_native_cudnn_attention(atol=atol) + + def test_backward_native_efficient_attention(self): + atol = 1e-2 * self.world_size # TODO: make bounds more strict + self._test_backward_native_efficient_attention(atol=atol) + + def test_backward_native_flash_attention(self): + atol = 1e-2 * self.world_size # TODO: make bounds more strict + self._test_backward_native_flash_attention(atol=atol) + + @unittest.skip( + """query diff: 0.298828125, key diff: 2.09375, value diff: 0.68359375; Needs further investigation""" + ) + def test_backward_flash_attn_flash_attention(self): + # Seems to require much higher bound for some reason + atol = 1.5e-1 * self.world_size # TODO: make bounds more strict + self._test_backward_flash_attn_flash_attention(atol=atol) + + +@unittest.skipIf( + not torch.cuda.is_available() or get_world_size() != 2, "CUDA is not available or world size is not 2" +) +class RingAttentionCP2Test(RingAttentionTest, RingAttentionCPTesterMixin): + pass + + +@unittest.skipIf( + not torch.cuda.is_available() or get_world_size() != 4, "CUDA is not available or world size is not 4" +) +class RingAttentionCP4Test(RingAttentionTest, RingAttentionCPTesterMixin): + pass diff --git a/docs/finetrainers-src-codebase/tests/models/cogvideox/__init__.py b/docs/finetrainers-src-codebase/tests/models/cogvideox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py b/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..c02fd5e434246fafbc18e271404f7cb4a6edfd1c --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py @@ -0,0 +1,71 @@ +import torch +from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXTransformer3DModel +from transformers import AutoTokenizer, T5EncoderModel + +from finetrainers.models.cogvideox import CogVideoXModelSpecification + + +class DummyCogVideoXModelSpecification(CogVideoXModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + text_encoder = T5EncoderModel.from_pretrained( + "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + vae.to(self.vae_dtype) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3DModel( + num_attention_heads=4, + attention_head_dim=16, + in_channels=4, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, + num_layers=2, + sample_width=24, + sample_height=24, + sample_frames=9, + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, + use_rotary_positional_embeddings=True, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + transformer.to(self.transformer_dtype) + self.transformer_config = transformer.config + scheduler = CogVideoXDDIMScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/__init__.py b/docs/finetrainers-src-codebase/tests/models/cogview4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py b/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4f634c17c3dae14321994a129f8af005f15663 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py @@ -0,0 +1,35 @@ +import torch +from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from transformers import AutoTokenizer, GlmModel + +from finetrainers.models.cogview4 import CogView4ModelSpecification + + +class DummyCogView4ModelSpecification(CogView4ModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + text_encoder = GlmModel.from_pretrained( + "hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype + ) + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True + ) + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKL.from_pretrained( + "hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype + ) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = CogView4Transformer2DModel.from_pretrained( + "hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype + ) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py b/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..eab8c670ce307adb4b9dc0bbf8b28850d37514e8 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py @@ -0,0 +1,61 @@ +import torch +from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from transformers import AutoTokenizer, GlmConfig, GlmModel + +from finetrainers.models.cogview4 import CogView4ControlModelSpecification +from finetrainers.models.utils import _expand_linear_with_zeroed_weights + + +class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # This needs to be updated for the test to work correctly. + # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded + # with ModelSpecification::_load_configs + self.transformer_config.in_channels = 4 + + def load_condition_models(self): + text_encoder_config = GlmConfig( + hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8 + ) + text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype) + # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer + tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True) + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ).to(self.vae_dtype) + return {"vae": vae} + + def load_diffusion_models(self, new_in_features: int): + torch.manual_seed(0) + transformer = CogView4Transformer2DModel( + patch_size=2, + in_channels=4, + num_layers=2, + attention_head_dim=4, + num_attention_heads=4, + out_channels=4, + text_embed_dim=32, + time_embed_dim=8, + condition_dim=4, + ).to(self.transformer_dtype) + actual_new_in_features = new_in_features * transformer.config.patch_size**2 + transformer.patch_embed.proj = _expand_linear_with_zeroed_weights( + transformer.patch_embed.proj, new_in_features=actual_new_in_features + ) + transformer.register_to_config(in_channels=new_in_features) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/flux/__init__.py b/docs/finetrainers-src-codebase/tests/models/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py b/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..59a2391535527e7f078c0c9759808cb2fa1b0d00 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py @@ -0,0 +1,6 @@ +from finetrainers.models.flux import FluxModelSpecification + + +class DummyFluxModelSpecification(FluxModelSpecification): + def __init__(self, **kwargs): + super().__init__(pretrained_model_name_or_path="hf-internal-testing/tiny-flux-pipe", **kwargs) diff --git a/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py b/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..620a0ef1b9b431c7397a1db9d56cb6ebb17b9e6f --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py @@ -0,0 +1,119 @@ +import torch +from diffusers import AutoencoderKLHunyuanVideo, FlowMatchEulerDiscreteScheduler, HunyuanVideoTransformer3DModel +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + LlamaConfig, + LlamaModel, + LlamaTokenizer, +) + +from finetrainers.models.hunyuan_video import HunyuanVideoModelSpecification + + +class DummyHunyuanVideoModelSpecification(HunyuanVideoModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder.to(self.text_encoder_dtype) + text_encoder_2.to(self.text_encoder_2_dtype) + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + vae.to(self.vae_dtype) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=2, + num_single_layers=2, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + transformer.to(self.transformer_dtype) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..c3432d716a88a98026620654d22a4a5bbcc63ae7 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py @@ -0,0 +1,245 @@ +import copy + +import torch +import torch.distributed as dist +from diffusers import LTXVideoTransformer3DModel +from torch._utils import _get_device_module +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel.api import parallelize_module +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + RowwiseParallel, +) + + +# from torch.utils._python_dispatch import TorchDispatchMode + + +DEVICE_TYPE = "cuda" +PG_BACKEND = "nccl" +DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count() + + +def main(world_size: int, rank: int): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(rank) + + CHANNELS = 128 + CROSS_ATTENTION_DIM = 2048 + CAPTION_CHANNELS = 4096 + NUM_LAYERS = 28 + NUM_ATTENTION_HEADS = 32 + ATTENTION_HEAD_DIM = 64 + + # CHANNELS = 4 + # CROSS_ATTENTION_DIM = 32 + # CAPTION_CHANNELS = 64 + # NUM_LAYERS = 1 + # NUM_ATTENTION_HEADS = 4 + # ATTENTION_HEAD_DIM = 8 + + config = { + "in_channels": CHANNELS, + "out_channels": CHANNELS, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": NUM_ATTENTION_HEADS, + "attention_head_dim": ATTENTION_HEAD_DIM, + "cross_attention_dim": CROSS_ATTENTION_DIM, + "num_layers": NUM_LAYERS, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": CAPTION_CHANNELS, + "attention_bias": True, + "attention_out_bias": True, + } + + # Normal model + torch.manual_seed(0) + model = LTXVideoTransformer3DModel(**config).to(DEVICE_TYPE) + + # TP model + model_tp = copy.deepcopy(model) + device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size)) + print(f"Device mesh: {device_mesh}") + + transformer_tp_plan = { + # ===== Condition embeddings ===== + # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(), + # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)), + # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()), + # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())), + # "caption_projection.linear_1": ColwiseParallel(), + # "caption_projection.linear_2": RowwiseParallel(), + # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), + # ===== ===== + } + + for block in model_tp.transformer_blocks: + block_tp_plan = {} + + # ===== Attention ===== + # 8 all-to-all, 3 all-reduce + # block_tp_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn1.norm_q"] = SequenceParallel() + # block_tp_plan["attn1.norm_k"] = SequenceParallel() + # block_tp_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # block_tp_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False) + # block_tp_plan["attn2.norm_q"] = SequenceParallel() + # block_tp_plan["attn2.norm_k"] = SequenceParallel() + # block_tp_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # ===== ===== + + block_tp_plan["ff.net.0.proj"] = ColwiseParallel() + block_tp_plan["ff.net.2"] = RowwiseParallel() + parallelize_module(block, device_mesh, block_tp_plan) + + parallelize_module(model_tp, device_mesh, transformer_tp_plan) + + comm_mode = CommDebugMode() + + batch_size = 2 + num_frames, height, width = 49, 512, 512 + temporal_compression_ratio, spatial_compression_ratio = 8, 32 + latent_num_frames, latent_height, latent_width = ( + (num_frames - 1) // temporal_compression_ratio + 1, + height // spatial_compression_ratio, + width // spatial_compression_ratio, + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + caption_sequence_length = 64 + + hidden_states = torch.randn(batch_size, video_sequence_length, CHANNELS, device=DEVICE_TYPE) + encoder_hidden_states = torch.randn(batch_size, caption_sequence_length, CAPTION_CHANNELS, device=DEVICE_TYPE) + encoder_attention_mask = None + timestep = torch.randint(0, 1000, (batch_size, 1), device=DEVICE_TYPE) + inputs = { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "timestep": timestep, + "num_frames": latent_num_frames, + "height": latent_height, + "width": latent_width, + "rope_interpolation_scale": [1 / (8 / 25), 8, 8], + "return_dict": False, + } + + output = model(**inputs)[0] + + with comm_mode: + output_tp = model_tp(**inputs)[0] + + output_tp = ( + output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local() + if isinstance(output_tp, DTensor) + else output_tp + ) + + print("Output shapes:", output.shape, output_tp.shape) + print( + "Comparing output:", + rank, + torch.allclose(output, output_tp, atol=1e-5, rtol=1e-5), + (output - output_tp).abs().max(), + ) + print(f"Max memory reserved ({rank=}): {torch.cuda.max_memory_reserved(rank) / 1024**3:.2f} GB") + + if rank == 0: + print() + print("get_comm_counts:", comm_mode.get_comm_counts()) + # print() + # print("get_parameter_info:", comm_mode.get_parameter_info()) # Too much noise + print() + print("Sharding info:\n" + "".join(f"{k} - {v}\n" for k, v in comm_mode.get_sharding_info().items())) + print() + print("get_total_counts:", comm_mode.get_total_counts()) + comm_mode.generate_json_dump("dump_comm_mode_log.json", noise_level=1) + comm_mode.log_comm_debug_tracing_table_to_file("dump_comm_mode_tracing_table.txt", noise_level=1) + + +dist.init_process_group(PG_BACKEND) +WORLD_SIZE = dist.get_world_size() +RANK = dist.get_rank() + +torch.cuda.set_device(RANK) + +if RANK == 0: + print(f"World size: {WORLD_SIZE}") + print(f"Device count: {DEVICE_COUNT}") + +try: + with torch.no_grad(): + main(WORLD_SIZE, RANK) +finally: + dist.destroy_process_group() + + +# LTXVideoTransformer3DModel( +# (proj_in): Linear(in_features=128, out_features=2048, bias=True) +# (time_embed): AdaLayerNormSingle( +# (emb): PixArtAlphaCombinedTimestepSizeEmbeddings( +# (time_proj): Timesteps() +# (timestep_embedder): TimestepEmbedding( +# (linear_1): Linear(in_features=256, out_features=2048, bias=True) +# (act): SiLU() +# (linear_2): Linear(in_features=2048, out_features=2048, bias=True) +# ) +# ) +# (silu): SiLU() +# (linear): Linear(in_features=2048, out_features=12288, bias=True) +# ) +# (caption_projection): PixArtAlphaTextProjection( +# (linear_1): Linear(in_features=4096, out_features=2048, bias=True) +# (act_1): GELU(approximate='tanh') +# (linear_2): Linear(in_features=2048, out_features=2048, bias=True) +# ) +# (rope): LTXVideoRotaryPosEmbed() +# (transformer_blocks): ModuleList( +# (0-27): 28 x LTXVideoTransformerBlock( +# (norm1): RMSNorm() +# (attn1): Attention( +# (norm_q): RMSNorm() +# (norm_k): RMSNorm() +# (to_q): Linear(in_features=2048, out_features=2048, bias=True) +# (to_k): Linear(in_features=2048, out_features=2048, bias=True) +# (to_v): Linear(in_features=2048, out_features=2048, bias=True) +# (to_out): ModuleList( +# (0): Linear(in_features=2048, out_features=2048, bias=True) +# (1): Dropout(p=0.0, inplace=False) +# ) +# ) +# (norm2): RMSNorm() +# (attn2): Attention( +# (norm_q): RMSNorm() +# (norm_k): RMSNorm() +# (to_q): Linear(in_features=2048, out_features=2048, bias=True) +# (to_k): Linear(in_features=2048, out_features=2048, bias=True) +# (to_v): Linear(in_features=2048, out_features=2048, bias=True) +# (to_out): ModuleList( +# (0): Linear(in_features=2048, out_features=2048, bias=True) +# (1): Dropout(p=0.0, inplace=False) +# ) +# ) +# (ff): FeedForward( +# (net): ModuleList( +# (0): GELU( +# (proj): Linear(in_features=2048, out_features=8192, bias=True) +# ) +# (1): Dropout(p=0.0, inplace=False) +# (2): Linear(in_features=8192, out_features=2048, bias=True) +# ) +# ) +# ) +# ) +# (norm_out): LayerNorm((2048,), eps=1e-06, elementwise_affine=False) +# (proj_out): Linear(in_features=2048, out_features=128, bias=True) +# ) diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8e65aee6b4b7ee2b679e6e9c1cdf12d3597d97 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py @@ -0,0 +1,63 @@ +import torch +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXVideoTransformer3DModel +from transformers import AutoTokenizer, T5EncoderModel + +from finetrainers.models.ltx_video import LTXVideoModelSpecification + + +class DummyLTXVideoModelSpecification(LTXVideoModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + text_encoder = T5EncoderModel.from_pretrained( + "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + vae.to(self.vae_dtype) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + transformer.to(self.transformer_dtype) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/wan/__init__.py b/docs/finetrainers-src-codebase/tests/models/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py b/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..59bb4809edb6702672334e10e176b955b3ef6a5e --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py @@ -0,0 +1,54 @@ +import torch +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel +from transformers import AutoTokenizer, T5EncoderModel + +from finetrainers.models.wan import WanModelSpecification + + +class DummyWanModelSpecification(WanModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + text_encoder = T5EncoderModel.from_pretrained( + "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + vae.to(self.vae_dtype) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + transformer.to(self.transformer_dtype) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py b/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0498a144d54422b6ae38c00ce83f4e201d67ae --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py @@ -0,0 +1,66 @@ +import torch +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel +from transformers import AutoTokenizer, T5EncoderModel + +from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights +from finetrainers.models.wan import WanControlModelSpecification + + +class DummyWanControlModelSpecification(WanControlModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # This needs to be updated for the test to work correctly. + # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded + # with ModelSpecification::_load_configs + self.transformer_config.in_channels = 16 + + def load_condition_models(self): + text_encoder = T5EncoderModel.from_pretrained( + "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + vae.to(self.vae_dtype) + self.vae_config = vae.config + return {"vae": vae} + + def load_diffusion_models(self, new_in_features: int): + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ).to(self.transformer_dtype) + + transformer.patch_embedding = _expand_conv3d_with_zeroed_weights( + transformer.patch_embedding, new_in_channels=new_in_features + ) + transformer.register_to_config(in_channels=new_in_features) + + # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this. + # Doing so overrides things like _keep_in_fp32_modules + transformer.to(self.transformer_dtype) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/docs/finetrainers-src-codebase/tests/test_lora_inference.py b/docs/finetrainers-src-codebase/tests/test_lora_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c0439b19b16fbba898e61545a2d65a8c649393 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/test_lora_inference.py @@ -0,0 +1,44 @@ +""" +Run this test in Lora adpater checking: + +```shell +python3 test_lora_inference.py --prompt "A girl is ridding a bike." --model_path "THUDM/CogVideoX-5B" --lora_path "path/to/lora" --lora_name "lora_adapter" --output_file "output.mp4" --fps 8 +``` + +""" + +import argparse + +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + + +def generate_video(model_path, prompt, lora_path, lora_name, output_file, fps): + pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda") + pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=lora_name) + pipe.set_adapters([lora_name], [1.0]) + pipe.enable_model_cpu_offload() + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + + video = pipe(prompt=prompt).frames[0] + export_to_video(video, output_file, fps=fps) + + +def main(): + parser = argparse.ArgumentParser(description="Generate video using CogVideoX and LoRA weights") + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for the video generation") + parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5B", help="Base Model path or HF ID") + parser.add_argument("--lora_path", type=str, required=True, help="Path to the LoRA weights") + parser.add_argument("--lora_name", type=str, default="lora_adapter", help="Name of the LoRA adapter") + parser.add_argument("--output_file", type=str, default="output.mp4", help="Output video file name") + parser.add_argument("--fps", type=int, default=8, help="Frames per second for the output video") + + args = parser.parse_args() + + generate_video(args.prompt, args.lora_path, args.lora_name, args.output_file, args.fps) + + +if __name__ == "__main__": + main() diff --git a/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh b/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh new file mode 100755 index 0000000000000000000000000000000000000000..ebcab61605afadf1a984953547c771ea8e034f36 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# This shell script is for the maintainers and contributors to QUICKLY check +# if the major changes they're introducing still work with the rest of the models supported +# in `finetrainers`. It DOES NOT give a sense of implementation correctness as that requires +# much longer training runs but it DOES ensure basic functionalities work in the large training +# setup. + +# It should be run as so from the root of `finetrainers`: `bash tests/test_model_runs_minimally_lora.sh` + +###################################################### +# Set common variables. +###################################################### + +ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)" +export ROOT_DIR +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL=DEBUG + +echo "Using $ROOT_DIR as rootdir." + +###################################################### +# Download Disney dataset. +###################################################### + +# Ensure dataset is downloaded +DATA_ROOT="$ROOT_DIR/video-dataset-disney" +if [ ! -d "$DATA_ROOT" ]; then + echo "Downloading Disney dataset to $DATA_ROOT..." + huggingface-cli download \ + --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \ + --local-dir "$DATA_ROOT" +else + echo "Dataset already exists at $DATA_ROOT. Skipping download." +fi + +###################################################### +# Run models +###################################################### + +# Define models to test +models=("dummy_ltx_video_lora" "dummy_cogvideox_lora" "dummy_hunyuanvideo_lora") +for model_script in "${models[@]}"; do + echo "Running $model_script test..." + bash $ROOT_DIR/tests/scripts/$model_script.sh +done \ No newline at end of file diff --git a/docs/finetrainers-src-codebase/tests/test_trackers.py b/docs/finetrainers-src-codebase/tests/test_trackers.py new file mode 100644 index 0000000000000000000000000000000000000000..c9fee180a53669983b3e8709f68a9240131dcf48 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/test_trackers.py @@ -0,0 +1,26 @@ +import logging +import os +import pathlib +import tempfile +import unittest + +from diffusers.utils.testing_utils import CaptureLogger + +from finetrainers.trackers import WandbTracker + + +os.environ["WANDB_MODE"] = "offline" + + +class WandbFastTests(unittest.TestCase): + def test_wandb_logdir(self): + logger = logging.getLogger("finetrainers") + + with tempfile.TemporaryDirectory() as tempdir, CaptureLogger(logger) as cap_log: + tracker = WandbTracker("finetrainers-experiment", log_dir=tempdir, config={}) + tracker.log({"loss": 0.1}, step=0) + tracker.log({"loss": 0.2}, step=1) + tracker.finish() + self.assertTrue(pathlib.Path(tempdir).exists()) + + self.assertTrue("WandB logging enabled" in cap_log.out) diff --git a/docs/finetrainers-src-codebase/tests/trainer/__init__.py b/docs/finetrainers-src-codebase/tests/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py b/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4325904b81bcc077a16921adb2f47a999e44e62 --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py @@ -0,0 +1,274 @@ +# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py + +import json +import os +import pathlib +import tempfile +import time +import unittest + +import pytest +from diffusers.utils import export_to_video +from parameterized import parameterized +from PIL import Image + +from finetrainers import BaseArgs, ControlTrainer, TrainingType, get_logger +from finetrainers.trainer.control_trainer.config import ControlType + + +os.environ["WANDB_MODE"] = "disabled" +os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" + +from ..models.cogview4.control_specification import DummyCogView4ControlModelSpecification # noqa +from ..models.wan.control_specification import DummyWanControlModelSpecification # noqa + + +logger = get_logger() + + +@pytest.fixture(autouse=True) +def slow_down_tests(): + yield + # Sleep between each test so that process groups are cleaned and resources are released. + # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually. + # !!!Look into this in future!!! + time.sleep(5) + + +class ControlTrainerFastTestsMixin: + model_specification_cls = None + num_data_files = 4 + num_frames = 4 + height = 64 + width = 64 + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.data_files = [] + for i in range(self.num_data_files): + data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4" + export_to_video( + [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2 + ) + self.data_files.append(data_file.as_posix()) + + csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv" + with open(csv_filename.as_posix(), "w") as f: + f.write("file_name,caption\n") + for i in range(self.num_data_files): + prompt = f"A cat ruling the world - {i}" + f.write(f'{i}.mp4,"{prompt}"\n') + + dataset_config = { + "datasets": [ + { + "data_root": self.tmpdir.name, + "dataset_type": "video", + "id_token": "TEST", + "video_resolution_buckets": [[self.num_frames, self.height, self.width]], + "reshape_mode": "bicubic", + } + ] + } + + self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json" + with open(self.dataset_config_filename.as_posix(), "w") as f: + json.dump(dataset_config, f) + + def tearDown(self): + self.tmpdir.cleanup() + + def get_base_args(self) -> BaseArgs: + args = BaseArgs() + args.dataset_config = self.dataset_config_filename.as_posix() + args.train_steps = 10 + args.max_data_samples = 25 + args.batch_size = 1 + args.gradient_checkpointing = True + args.output_dir = self.tmpdir.name + args.checkpointing_steps = 6 + args.enable_precomputation = False + args.precomputation_items = self.num_data_files + args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed") + args.compile_scopes = "regional" # This will only be in effect when `compile_modules` is set + + args.control_type = ControlType.CANNY + args.train_qk_norm = True + args.frame_conditioning_type = "random" + args.frame_conditioning_index = None + args.frame_conditioning_concatenate_mask = False + + return args + + def get_args(self) -> BaseArgs: + raise NotImplementedError("`get_args` must be implemented in the subclass.") + + def _test_training(self, args: BaseArgs): + model_specification = self.model_specification_cls() + trainer = ControlTrainer(args, model_specification) + trainer.run() + + +class ControlTrainerLoRATestsMixin___PTD(ControlTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "ptd" + args.training_type = TrainingType.CONTROL_LORA + args.rank = 4 + args.lora_alpha = 4 + args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.tp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class ControlTrainerFullFinetuneTestsMixin___PTD(ControlTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "ptd" + args.training_type = TrainingType.CONTROL_FULL_FINETUNE + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.tp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class ControlTrainerCogView4LoRATests___PTD(ControlTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogView4ControlModelSpecification + + +class ControlTrainerCogView4FullFinetuneTests___PTD(ControlTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogView4ControlModelSpecification + + +class ControlTrainerWanLoRATests___PTD(ControlTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyWanControlModelSpecification + + +class ControlTrainerWanFullFinetuneTests___PTD(ControlTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyWanControlModelSpecification diff --git a/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py b/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..96a09ce47275d307354e2cac69899c44139a719a --- /dev/null +++ b/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py @@ -0,0 +1,537 @@ +# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py + +import json +import os +import pathlib +import tempfile +import time +import unittest + +import pytest +import torch +from diffusers.utils import export_to_video +from parameterized import parameterized +from PIL import Image + +from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger + + +os.environ["WANDB_MODE"] = "disabled" +os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" + +from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification # noqa +from ..models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa +from ..models.flux.base_specification import DummyFluxModelSpecification # noqa +from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification # noqa +from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification # noqa +from ..models.wan.base_specification import DummyWanModelSpecification # noqa + + +logger = get_logger() + + +@pytest.fixture(autouse=True) +def slow_down_tests(): + yield + # Sleep between each test so that process groups are cleaned and resources are released. + # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually. + # !!!Look into this in future!!! + time.sleep(5) + + +class SFTTrainerFastTestsMixin: + model_specification_cls = None + num_data_files = 4 + num_frames = 4 + height = 64 + width = 64 + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.data_files = [] + for i in range(self.num_data_files): + data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4" + export_to_video( + [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2 + ) + self.data_files.append(data_file.as_posix()) + + csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv" + with open(csv_filename.as_posix(), "w") as f: + f.write("file_name,caption\n") + for i in range(self.num_data_files): + prompt = f"A cat ruling the world - {i}" + f.write(f'{i}.mp4,"{prompt}"\n') + + dataset_config = { + "datasets": [ + { + "data_root": self.tmpdir.name, + "dataset_type": "video", + "id_token": "TEST", + "video_resolution_buckets": [[self.num_frames, self.height, self.width]], + "reshape_mode": "bicubic", + } + ] + } + + self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json" + with open(self.dataset_config_filename.as_posix(), "w") as f: + json.dump(dataset_config, f) + + def tearDown(self): + self.tmpdir.cleanup() + # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually + # make sure to destroy it here. + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + time.sleep(3) + + def get_base_args(self) -> BaseArgs: + args = BaseArgs() + args.dataset_config = self.dataset_config_filename.as_posix() + args.train_steps = 10 + args.max_data_samples = 25 + args.batch_size = 1 + args.gradient_checkpointing = True + args.output_dir = self.tmpdir.name + args.checkpointing_steps = 6 + args.enable_precomputation = False + args.precomputation_items = self.num_data_files + args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed") + args.compile_scopes = "regional" # This will only be in effect when `compile_modules` is set + # args.attn_provider_training = ["transformer:_native_cudnn"] + # args.attn_provider_inference = ["transformer:_native_cudnn"] + return args + + def get_args(self) -> BaseArgs: + raise NotImplementedError("`get_args` must be implemented in the subclass.") + + def _test_training(self, args: BaseArgs): + model_specification = self.model_specification_cls() + trainer = SFTTrainer(args, model_specification) + trainer.run() + + +# =============== =============== + + +class SFTTrainerLoRATestsMixin___Accelerate(SFTTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "accelerate" + args.training_type = TrainingType.LORA + args.rank = 4 + args.lora_alpha = 4 + args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.layerwise_upcasting_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class SFTTrainerFullFinetuneTestsMixin___Accelerate(SFTTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "accelerate" + args.training_type = TrainingType.FULL_FINETUNE + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class SFTTrainerCogVideoXLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyCogVideoXModelSpecification + + +class SFTTrainerCogVideoXFullFinetuneTests___Accelerate( + SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase +): + model_specification_cls = DummyCogVideoXModelSpecification + + +class SFTTrainerCogView4LoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyCogView4ModelSpecification + + +class SFTTrainerCogView4FullFinetuneTests___Accelerate( + SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase +): + model_specification_cls = DummyCogView4ModelSpecification + + +class SFTTrainerFluxLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyFluxModelSpecification + + +class SFTTrainerFluxFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyFluxModelSpecification + + +class SFTTrainerHunyuanVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerHunyuanVideoFullFinetuneTests___Accelerate( + SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase +): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerLTXVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyLTXVideoModelSpecification + + +class SFTTrainerLTXVideoFullFinetuneTests___Accelerate( + SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase +): + model_specification_cls = DummyLTXVideoModelSpecification + + +class SFTTrainerWanLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyWanModelSpecification + + +class SFTTrainerWanFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase): + model_specification_cls = DummyWanModelSpecification + + +# =============== =============== + +# =============== =============== + + +class SFTTrainerLoRATestsMixin___PTD(SFTTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "ptd" + args.training_type = TrainingType.LORA + args.rank = 4 + args.lora_alpha = 4 + args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.layerwise_upcasting_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.compile_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.layerwise_upcasting_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.compile_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.tp_degree = 2 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @unittest.skip( + "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test." + ) + @parameterized.expand([(True,)]) + def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.cp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @unittest.skip( + "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test." + ) + @parameterized.expand([(True,)]) + def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.cp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class SFTTrainerFullFinetuneTestsMixin___PTD(SFTTrainerFastTestsMixin): + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.parallel_backend = "ptd" + args.training_type = TrainingType.FULL_FINETUNE + return args + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.compile_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 1 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + args.compile_modules = ["transformer"] + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.batch_size = 2 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(True,)]) + def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.dp_shards = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @parameterized.expand([(False,), (True,)]) + def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool): + args = self.get_args() + args.tp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @unittest.skip( + "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test." + ) + @parameterized.expand([(True,)]) + def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.cp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + @unittest.skip( + "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test." + ) + @parameterized.expand([(True,)]) + def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool): + args = self.get_args() + args.dp_degree = 2 + args.cp_degree = 2 + args.batch_size = 1 + args.enable_precomputation = enable_precomputation + self._test_training(args) + + +class SFTTrainerCogVideoXLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogVideoXModelSpecification + + +class SFTTrainerCogVideoXFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogVideoXModelSpecification + + +class SFTTrainerCogView4LoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogView4ModelSpecification + + +class SFTTrainerCogView4FullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyCogView4ModelSpecification + + +class SFTTrainerFluxLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyFluxModelSpecification + + +class SFTTrainerFluxFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyFluxModelSpecification + + +class SFTTrainerHunyuanVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyLTXVideoModelSpecification + + +class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyLTXVideoModelSpecification + + +class SFTTrainerWanLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyWanModelSpecification + + +class SFTTrainerWanFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyWanModelSpecification + + +# =============== =============== diff --git a/docs/finetrainers-src-codebase/train.py b/docs/finetrainers-src-codebase/train.py new file mode 100644 index 0000000000000000000000000000000000000000..183ba1eddebf98c0d268442e9d8d79543b311c08 --- /dev/null +++ b/docs/finetrainers-src-codebase/train.py @@ -0,0 +1,86 @@ +import sys +import traceback + +from finetrainers import BaseArgs, ControlTrainer, SFTTrainer, TrainingType, get_logger +from finetrainers.config import _get_model_specifiction_cls +from finetrainers.trainer.control_trainer.config import ControlFullRankConfig, ControlLowRankConfig +from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig + + +logger = get_logger() + + +def main(): + try: + import multiprocessing + + multiprocessing.set_start_method("fork") + except Exception as e: + logger.error( + f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. ' + f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n" + f"Error: {e}" + ) + + try: + args = BaseArgs() + + argv = [y.strip() for x in sys.argv for y in x.split()] + training_type_index = argv.index("--training_type") + if training_type_index == -1: + raise ValueError("Training type not provided in command line arguments.") + + training_type = argv[training_type_index + 1] + training_cls = None + if training_type == TrainingType.LORA: + training_cls = SFTLowRankConfig + elif training_type == TrainingType.FULL_FINETUNE: + training_cls = SFTFullRankConfig + elif training_type == TrainingType.CONTROL_LORA: + training_cls = ControlLowRankConfig + elif training_type == TrainingType.CONTROL_FULL_FINETUNE: + training_cls = ControlFullRankConfig + else: + raise ValueError(f"Training type {training_type} not supported.") + + args.register_args(training_cls()) + args = args.parse_args() + + model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type) + model_specification = model_specification_cls( + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + tokenizer_id=args.tokenizer_id, + tokenizer_2_id=args.tokenizer_2_id, + tokenizer_3_id=args.tokenizer_3_id, + text_encoder_id=args.text_encoder_id, + text_encoder_2_id=args.text_encoder_2_id, + text_encoder_3_id=args.text_encoder_3_id, + transformer_id=args.transformer_id, + vae_id=args.vae_id, + text_encoder_dtype=args.text_encoder_dtype, + text_encoder_2_dtype=args.text_encoder_2_dtype, + text_encoder_3_dtype=args.text_encoder_3_dtype, + transformer_dtype=args.transformer_dtype, + vae_dtype=args.vae_dtype, + revision=args.revision, + cache_dir=args.cache_dir, + ) + + if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]: + trainer = SFTTrainer(args, model_specification) + elif args.training_type in [TrainingType.CONTROL_LORA, TrainingType.CONTROL_FULL_FINETUNE]: + trainer = ControlTrainer(args, model_specification) + else: + raise ValueError(f"Training type {args.training_type} not supported.") + + trainer.run() + + except KeyboardInterrupt: + logger.info("Received keyboard interrupt. Exiting...") + except Exception as e: + logger.error(f"An error occurred during training: {e}") + logger.error(traceback.format_exc()) + + +if __name__ == "__main__": + main() diff --git a/docs/finetrainers/documentation_global_README.md b/docs/finetrainers/documentation_global_README.md deleted file mode 100644 index 052d7d45af231d88beb2f9a3d946e6c593556fa3..0000000000000000000000000000000000000000 --- a/docs/finetrainers/documentation_global_README.md +++ /dev/null @@ -1,99 +0,0 @@ -# finetrainers 🧪 - -FineTrainers is a work-in-progress library to support (accessible) training of video models. Our first priority is to support LoRA training for all popular video models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc. - -`cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will be eventually moved and supported under `finetrainers`. - - - - - -
- -## News - -- 🔥 **2025-03-03**: Wan T2V support added! -- 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, model specification format (externally usable for training custom models), FSDP & more. -- 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details! -- 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image to video training for LTX-Video and STG guidance for inference. -- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB upto specific resolutions. -- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative! -- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added! -- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254). -- 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added! - -## Table of Contents - -- [Quickstart](#quickstart) -- [Support Matrix](#support-matrix) -- [Featured Projects](#featured-projects) -- [Acknowledgements](#acknowledgements) - -## Quickstart - -Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install `diffusers` from source by `pip install git+https://github.com/huggingface/diffusers`. The requirements specify `diffusers>=0.32.1`, but it is always recommended to use the `main` branch of Diffusers for the latest features and bugfixes. Note that the `main` branch for `finetrainers` is also the development branch, and stable support should be expected from the release tags. - -Checkout to the latest release tag: - -```bash -git fetch --all --tags -git checkout tags/v0.0.1 -``` - -Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the release tag. - -#### Using the main branch - -To get started quickly with example training scripts on the main development branch, refer to the following: -- [LTX-Video Pika Effects Crush](./examples/training/sft/ltx_video/crush_smol_lora/) -- [CogVideoX Pika Effects Crush](./examples/training/sft/cogvideox/crush_smol_lora/) -- [Wan T2V Pika Effects Crush](./examples/training/sft/wan/crush_smol_lora/) - -The following are some simple datasets/HF orgs with good datasets to test training with quickly: -- [Disney Video Generation Dataset](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) -- [bigdatapw Video Dataset Collection](https://huggingface.co/bigdata-pw) -- [Finetrainers HF Dataset Collection](https://huggingface.co/finetrainers) - -Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./examples/training/) to learn more about supported models for training & example reproducible training launch scripts. - -> [!IMPORTANT] -> It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. For fully reproducible training, please use the same environment as mentioned in [environment.md](./docs/environment.md). - -## Support Matrix - -> [!NOTE] -> The following numbers were obtained from the [release branch](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1). The `main` branch is unstable at the moment and may use higher memory. - -
- -| **Model Name** | **Tasks** | **Min. LoRA VRAM*** | **Min. Full Finetuning VRAM^** | -|:----------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:| -| [LTX-Video](./docs/models/ltx_video.md) | Text-to-Video | 5 GB | 21 GB | -| [HunyuanVideo](./docs/models/hunyuan_video.md) | Text-to-Video | 32 GB | OOM | -| [CogVideoX-5b](./docs/models/cogvideox.md) | Text-to-Video | 18 GB | 53 GB | -| [Wan](./docs/models/wan.md) | Text-to-Video | TODO | TODO | - -
- -*Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).
-^Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing. - -If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md). - -## Featured Projects 🔥 - -Checkout some amazing projects citing `finetrainers`: -- [Diffusion as Shader](https://github.com/IGL-HKUST/DiffusionAsShader) -- [SkyworkAI's SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1) -- [eisneim's LTX Image-to-Video](https://github.com/eisneim/ltx_lora_training_i2v_t2v/) -- [wileewang's TransPixar](https://github.com/wileewang/TransPixar) -- [Feizc's Video-In-Context](https://github.com/feizc/Video-In-Context) - -Checkout the following UIs built for `finetrainers`: -- [jbilcke's VideoModelStudio](https://github.com/jbilcke-hf/VideoModelStudio) -- [neph1's finetrainers-ui](https://github.com/neph1/finetrainers-ui) - -## Acknowledgements - -* `finetrainers` builds on top of & takes inspiration from great open-source libraries - `transformers`, `accelerate`, `torchtune`, `torchtitan`, `peft`, `diffusers`, `bitsandbytes`, `torchao` and `deepspeed` - to name a few. -* Some of the design choices of `finetrainers` were inspired by [`SimpleTuner`](https://github.com/bghira/SimpleTuner). diff --git a/vms/ui/project/tabs/manage_tab.py b/vms/ui/project/tabs/manage_tab.py index 48a617f4017072a6dd6e94a229070ecd3cabb10f..4f1827f5466672f52efbb489c538ed230494c1be 100644 --- a/vms/ui/project/tabs/manage_tab.py +++ b/vms/ui/project/tabs/manage_tab.py @@ -79,7 +79,8 @@ class ManageTab(BaseTab): self.components["download_output_btn"] = gr.DownloadButton( "📁 Download output directory (.zip)", variant="secondary", - size="lg" + size="lg", + visible=False ) with gr.Row(): with gr.Column():