jbilcke-hf (HF Staff) committed
Commit 9fd1204 · Parent: 76a0a50

we are going to hack into finetrainers

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml +51 -0
  2. docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml +34 -0
  3. docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml +30 -0
  4. docs/finetrainers-src-codebase/.gitignore +179 -0
  5. docs/finetrainers-src-codebase/CONTRIBUTING.md +41 -0
  6. docs/finetrainers-src-codebase/LICENSE +201 -0
  7. docs/finetrainers-src-codebase/Makefile +11 -0
  8. docs/{finetrainers/documentation_README.md → finetrainers-src-codebase/README.md} +6 -3
  9. docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml +22 -0
  10. docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml +23 -0
  11. docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml +17 -0
  12. docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml +17 -0
  13. docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml +17 -0
  14. docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml +17 -0
  15. docs/finetrainers-src-codebase/assets/contribute.md +16 -0
  16. docs/finetrainers-src-codebase/assets/contribute_zh.md +16 -0
  17. docs/finetrainers-src-codebase/assets/dataset_zh.md +72 -0
  18. docs/finetrainers-src-codebase/assets/sft_2b.png +0 -0
  19. docs/finetrainers-src-codebase/assets/sft_5b.png +0 -0
  20. docs/finetrainers-src-codebase/assets/tests/metadata.csv +2 -0
  21. docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md +20 -0
  22. docs/{finetrainers/documentation_args.md → finetrainers-src-codebase/docs/args.md} +44 -5
  23. docs/{finetrainers/documentation_dataset_README.md → finetrainers-src-codebase/docs/dataset/README.md} +11 -4
  24. docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md +44 -0
  25. docs/{finetrainers/documentation_environment.md → finetrainers-src-codebase/docs/environment.md} +11 -0
  26. docs/{finetrainers/documentation_models_README.md → finetrainers-src-codebase/docs/models/README.md} +0 -0
  27. docs/finetrainers-src-codebase/docs/models/attention.md +263 -0
  28. docs/{finetrainers/documentation_models_cogvideox.md → finetrainers-src-codebase/docs/models/cogvideox.md} +6 -6
  29. docs/finetrainers-src-codebase/docs/models/cogview4.md +94 -0
  30. docs/finetrainers-src-codebase/docs/models/flux.md +53 -0
  31. docs/{finetrainers/documentation_models_hunyuan_video.md → finetrainers-src-codebase/docs/models/hunyuan_video.md} +3 -3
  32. docs/{finetrainers/documentation_models_ltx_video.md → finetrainers-src-codebase/docs/models/ltx_video.md} +3 -3
  33. docs/{finetrainers/documentation_models_optimization.md → finetrainers-src-codebase/docs/models/optimization.md} +0 -0
  34. docs/{finetrainers/documentation_models_wan.md → finetrainers-src-codebase/docs/models/wan.md} +11 -1
  35. docs/{finetrainers/documentation_optimizers.md → finetrainers-src-codebase/docs/optimizer.md} +0 -0
  36. docs/{finetrainers/documentation_parallel_processing_README.md → finetrainers-src-codebase/docs/parallel/README.md} +8 -3
  37. docs/{finetrainers/documentation_trainers_control_trainer.md → finetrainers-src-codebase/docs/trainer/control_trainer.md} +0 -0
  38. docs/{finetrainers/documentation_trainers_sft_trainer.md → finetrainers-src-codebase/docs/trainer/sft_trainer.md} +0 -0
  39. docs/finetrainers-src-codebase/examples/_legacy/training/README.md +459 -0
  40. docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md +455 -0
  41. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py +0 -0
  42. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py +484 -0
  43. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py +1016 -0
  44. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py +947 -0
  45. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py +955 -0
  46. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py +917 -0
  47. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py +428 -0
  48. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py +669 -0
  49. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py +1 -0
  50. docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py +99 -0
docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,51 @@
+ name: "\U0001F41B Bug Report"
+ description: Submit a bug report to help us improve CogVideoX-Factory / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX-Factory 开源框架
+ body:
+   - type: textarea
+     id: system-info
+     attributes:
+       label: System Info / 系統信息
+       description: Your operating environment / 您的运行环境信息
+       placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
+     validations:
+       required: true
+ 
+   - type: checkboxes
+     id: information-scripts-examples
+     attributes:
+       label: Information / 问题信息
+       description: 'The problem arises when using: / 问题出现在'
+       options:
+         - label: "The official example scripts / 官方的示例脚本"
+         - label: "My own modified scripts / 我自己修改的脚本和任务"
+ 
+   - type: textarea
+     id: reproduction
+     validations:
+       required: true
+     attributes:
+       label: Reproduction / 复现过程
+       description: |
+         Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
+         If you have code snippets, error messages, stack traces, please provide them here as well.
+         Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+         Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
+ 
+         请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
+         如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
+         请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+         请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
+       placeholder: |
+         Steps to reproduce the behavior/复现Bug的步骤:
+ 
+         1.
+         2.
+         3.
+ 
+   - type: textarea
+     id: expected-behavior
+     validations:
+       required: true
+     attributes:
+       label: Expected behavior / 期待表现
+       description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml ADDED
@@ -0,0 +1,34 @@
+ name: "\U0001F680 Feature request"
+ description: Submit a request for a new CogVideoX-Factory feature / 提交一个新的 CogVideoX-Factory 开源项目的功能建议
+ labels: [ "feature" ]
+ body:
+   - type: textarea
+     id: feature-request
+     validations:
+       required: true
+     attributes:
+       label: Feature request / 功能建议
+       description: |
+         A brief description of the feature proposal. Links to corresponding papers and code are desirable.
+         对功能建议的简述。最好提供对应的论文和代码链接。
+ 
+   - type: textarea
+     id: motivation
+     validations:
+       required: true
+     attributes:
+       label: Motivation / 动机
+       description: |
+         Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
+         您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
+ 
+   - type: textarea
+     id: contribution
+     validations:
+       required: true
+     attributes:
+       label: Your contribution / 您的贡献
+       description: |
+ 
+         Your PR link or any other link you can help with.
+         您的PR链接或者其他您能提供帮助的链接。
docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml ADDED
@@ -0,0 +1,30 @@
+ name: Fast tests for PRs
+ 
+ on:
+   pull_request:
+     branches:
+       - main
+ 
+ concurrency:
+   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+   cancel-in-progress: true
+ 
+ jobs:
+   check_code_quality:
+     runs-on: ubuntu-22.04
+     steps:
+       - uses: actions/checkout@v3
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.8"
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install ruff==0.9.10
+       - name: Check quality
+         run: make quality
+       - name: Check if failure
+         if: ${{ failure() }}
+         run: |
+           echo "Quality check failed. Please install ruff: `pip install ruff` and then run `make style && make quality` from the root of the repository." >> $GITHUB_STEP_SUMMARY
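The failure message above points contributors at the same checks the workflow runs. With the ruff version pinned in this workflow, a local reproduction would look roughly like the following (the `make` targets are defined in the Makefile added later in this commit):

```bash
# Reproduce the CI quality gate locally before opening a PR.
pip install ruff==0.9.10
make style    # auto-fixes lint/format issues
make quality  # the check that this workflow runs
```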
docs/finetrainers-src-codebase/.gitignore ADDED
@@ -0,0 +1,179 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # JetBrains
7
+ .idea
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
166
+
167
+ # manually added
168
+ wandb/
169
+ *.txt
170
+ dump*
171
+ outputs*
172
+ *.slurm
173
+ .vscode/
174
+ *dummy*
175
+ *curated*
176
+ validation_dataset/
177
+ wan-framepack/
178
+
179
+ !requirements.txt
docs/finetrainers-src-codebase/CONTRIBUTING.md ADDED
@@ -0,0 +1,41 @@
+ # How to contribute to Finetrainers
+ 
+ Finetrainers is an early-stage library for training diffusion models. Everyone is welcome to contribute - models, algorithms, refactors, docs, etc. - but due to the early stage of the project, we recommend that bigger contributions be discussed in an issue before submitting a PR. Eventually, we will have a better process for this!
+ 
+ ## How to contribute
+ 
+ ### Adding a new model
+ 
+ If you would like to add a new model, please follow these steps:
+ 
+ - Create a new file in the `finetrainers/models` directory with the model name (if it's new), or use the same directory if it's a variant of an existing model.
+ - Implement the model specification in the file. For more details on what a model specification should look like, see the [ModelSpecification](TODO(aryan): add link) documentation.
+ - Update the supported configs in `finetrainers/config.py` to include the new model and the training types supported.
+ - Add a dummy model specification in the `tests/models` directory.
+ - Make sure to test training with the following settings (example launch commands are sketched after this list):
+   - Single GPU
+   - 2x GPU with `--dp_degree 2 --dp_shards 1`
+   - 2x GPU with `--dp_degree 1 --dp_shards 2`
+ 
+   For `SFTTrainer` additions, please make sure to train with at least 1000 steps (at least 2000 data points) to ensure the model training is working as expected.
+ - Open a PR with your changes. Please make sure to share your wandb logs for the above training settings in the PR description. This will help us verify that training is working as expected.
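As a rough sketch of the three required test settings, the launches below use `torchrun` (the same launcher used in the debugging notes elsewhere in this commit) with `train.py` at the repository root; `<your training args>` is a placeholder for whatever model/dataset arguments your new model specification needs, and only `--dp_degree`/`--dp_shards` are taken from the list above.

```bash
# Hypothetical launch commands for the three required test settings.

# Single GPU
torchrun --standalone --nnodes=1 --nproc_per_node=1 train.py <your training args>

# 2x GPU, data-parallel replication
torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py <your training args> --dp_degree 2 --dp_shards 1

# 2x GPU, data-parallel sharding
torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py <your training args> --dp_degree 1 --dp_shards 2
```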
+ 
+ ### Adding a new algorithm
+ 
+ Currently, we are not accepting algorithm contributions. We will update this section once we are ready 🤗
+ 
+ ### Refactors
+ 
+ The library is in a very early stage. There are many instances of dead code, poorly written abstractions, and other issues. If you would like to refactor/clean up a part of the codebase, please open an issue to discuss the changes before submitting a PR.
+ 
+ ### Dataset improvements
+ 
+ Any changes to dataset/dataloader implementations can be submitted directly. The improvements and the reasons for the changes should be conveyed appropriately so that we can move quickly 🤗
+ 
+ ### Documentation
+ 
+ Due to the early stage of the project, the documentation is not as comprehensive as we would like. Any improvements/refactors are welcome directly!
+ 
+ ## Asking for help
+ 
+ If you have any questions, feel free to open an issue and we will be sure to help you out as soon as possible! Please make sure to describe your issue in either English (preferred) or Chinese. Any other language will make it hard for us to help you, so we will most likely close such issues without an explanation/answer.
docs/finetrainers-src-codebase/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
docs/finetrainers-src-codebase/Makefile ADDED
@@ -0,0 +1,11 @@
+ .PHONY: quality style
+ 
+ check_dirs := finetrainers tests examples train.py setup.py
+ 
+ quality:
+ 	ruff check $(check_dirs) --exclude examples/_legacy
+ 	ruff format --check $(check_dirs) --exclude examples/_legacy
+ 
+ style:
+ 	ruff check $(check_dirs) --fix --exclude examples/_legacy
+ 	ruff format $(check_dirs) --exclude examples/_legacy
docs/{finetrainers/documentation_README.md → finetrainers-src-codebase/README.md} RENAMED
@@ -30,10 +30,10 @@ Checkout to the latest stable release tag:
 
  ```bash
  git fetch --all --tags
- git checkout tags/v0.1.0
+ git checkout tags/v0.2.0
  ```
 
- Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.1.0) for the latest stable release.
+ Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.2.0-release) for the latest stable release.
 
  #### Using the main branch
 
@@ -54,9 +54,10 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
 
  ## Features
 
- - DDP, FSDP-2 & HSDP support for all models
+ - DDP, FSDP-2 & HSDP, CP support
  - LoRA and full-rank finetuning; Conditional Control training
  - Memory-efficient single-GPU training
+ - Multiple attention backends supported - `flash`, `flex`, `sage`, `xformers` (see [attention](./docs/models/attention.md) docs)
  - Auto-detection of commonly used dataset formats
  - Combined image/video datasets, multiple chainable local/remote datasets, multi-resolution bucketing & more
  - Memory-efficient precomputation support with/without on-the-fly precomputation for large scale datasets
@@ -65,6 +66,8 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
 
  ## News
 
+ - 🔥 **2025-04-25**: Support for different attention providers added!
+ - 🔥 **2025-04-21**: Wan I2V support added!
  - 🔥 **2025-04-12**: Channel-concatenated control conditioning support added for CogView4 and Wan!
  - 🔥 **2025-04-08**: `torch.compile` support added!
  - 🔥 **2025-04-06**: Flux support added!
docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ dynamo_config:
+   dynamo_backend: INDUCTOR
+   dynamo_mode: max-autotune
+   dynamo_use_dynamic: true
+   dynamo_use_fullgraph: false
+ enable_cpu_affinity: false
+ gpu_ids: '3'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
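For context, configs like the one above are consumed by `accelerate launch`. A minimal sketch, assuming `train.py` at the repository root and placeholder training arguments, might look like:

```bash
# Sketch: launch a single-GPU run with the torch.compile-enabled config above.
# <your training args> is a placeholder for the actual finetrainers arguments.
accelerate launch --config_file accelerate_configs/compiled_1.yaml train.py <your training args>
```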
docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml ADDED
@@ -0,0 +1,23 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   gradient_accumulation_steps: 1
+   gradient_clipping: 1.0
+   offload_optimizer_device: cpu
+   offload_param_device: cpu
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: '3'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: 0,1
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: 0,1,2,3
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
docs/finetrainers-src-codebase/assets/contribute.md ADDED
@@ -0,0 +1,16 @@
+ # Contributions Welcome
+ 
+ This project is in a very early stage, and we welcome contributions from everyone. We hope to receive contributions and support in the following areas:
+ 
+ 1. Support for more models. In addition to CogVideoX models, we also highly encourage contributions supporting other models.
+ 2. Support for richer datasets. In our example, we used a Disney video generation dataset, but we hope to support more datasets, as the current one is too limited for deeper fine-tuning exploration.
+ 3. Anything mentioned in the `TODO` items of our README.md
+ 
+ ## How to Submit
+ 
+ We welcome you to create a new PR describing the corresponding contribution. We will review it as soon as possible.
+ 
+ ## Naming Conventions
+ 
+ - Please use English for naming; avoid using pinyin or other languages. All comments should be in English.
+ - Strictly follow PEP8 conventions and use underscores to separate words. Please avoid using names like a, b, c.
docs/finetrainers-src-codebase/assets/contribute_zh.md ADDED
@@ -0,0 +1,16 @@
+ # 欢迎你们的贡献
+ 
+ 本项目属于非常初级的阶段,欢迎大家进行贡献。我们希望在以下方面得到贡献和支持:
+ 
+ 1. 支持更多的模型,除了 CogVideoX 模型之外的模型,我们也非常支持。
+ 2. 更丰富的数据集支持。在我们的例子中,我们使用了一个 Disney 视频生成数据集,但是我们希望能够支持更多的数据集,这个数据集太少了,并不足以进行更深的微调探索。
+ 3. 任何我们在README中`TODO`提到的内容。
+ 
+ ## 提交方式
+ 
+ 我们欢迎您直接创建一个新的PR,并说明对应的贡献,我们将第一时间查看。
+ 
+ ## 命名规范
+ 
+ - 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
+ - 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
docs/finetrainers-src-codebase/assets/dataset_zh.md ADDED
@@ -0,0 +1,72 @@
+ ## 数据集格式
+ 
+ ### 提示词数据集要求
+ 
+ 创建 `prompt.txt` 文件,文件应包含逐行分隔的提示。请注意,提示必须是英文,并且建议使用 [提示润色脚本](https://github.com/THUDM/CogVideo/blob/main/inference/convert_demo.py) 进行润色。或者可以使用 [CogVideo-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) 进行数据标注:
+ 
+ ```
+ A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
+ A black and white animated sequence on a ship’s deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language...
+ ...
+ ```
+ 
+ ### 视频数据集要求
+ 
+ 该框架支持的分辨率和帧数需要满足以下条件:
+ 
+ - **支持的分辨率(宽 * 高)**:
+   - 任意分辨率且必须能被32整除。例如,`720 * 480`, `1920 * 1020` 等分辨率。
+ 
+ - **支持的帧数(Frames)**:
+   - 必须是 `4 * k` 或 `4 * k + 1`(例如:16, 32, 49, 81)
+ 
+ 所有的视频建议放在一个文件夹中。
+ 
+ 
+ 接着,创建 `videos.txt` 文件。 `videos.txt` 文件应包含逐行分隔的视频文件路径。请注意,路径必须相对于 `--data_root` 目录。格式如下:
+ 
+ ```
+ videos/00000.mp4
+ videos/00001.mp4
+ ...
+ ```
+ 
+ 对于有兴趣了解更多细节的开发者,您可以查看相关的 `BucketSampler` 代码。
+ 
+ ### 数据集结构
+ 
+ 您的数据集结构应如下所示,通过运行`tree`命令,你能看到:
+ 
+ ```
+ dataset
+ ├── prompt.txt
+ ├── videos.txt
+ ├── videos
+ ├── videos/00000.mp4
+ ├── videos/00001.mp4
+ ├── ...
+ ```
+ 
+ ### 使用数据集
+ 
+ 当使用此格式时,`--caption_column` 应为 `prompt.txt`,`--video_column` 应为 `videos.txt`。如果您的数据存储在 CSV
+ 文件中,也可以指定 `--dataset_file` 为 CSV 文件的路径,`--caption_column` 和 `--video_column` 为 CSV
+ 文件中的实际列名。请参考 [test_dataset](../tests/test_dataset.py) 文件中的一些简单示例。
+ 
+ 例如,使用 [这个](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) Disney 数据集进行微调。下载可通过🤗
+ Hugging Face CLI 完成:
+ 
+ ```
+ huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
+ ```
+ 
+ 该数据集已按照预期格式准备好,可直接使用。但是,直接使用视频数据集可能会导致较小 VRAM 的 GPU 出现
+ OOM(内存不足),因为它需要加载 [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae)
+ (将视频编码为潜在空间)和大型 [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/)
+ 
+ 文本编码器。为了降低内存需求,您可以使用 `training/prepare_dataset.py` 脚本预先计算潜在变量和嵌入。
+ 
+ 填写或修改 `prepare_dataset.sh` 中的参数并执行它以获得预先计算的潜在变量和嵌入(请确保指定 `--save_latents_and_embeddings`
+ 以保存预计算的工件)。如果准备图像到视频的训练,请确保传递 `--save_image_latents`,它会对图像进行编码,将图像潜在值与视频一起保存。
+ 在训练期间使用这些工件时,确保指定 `--load_tensors` 标志,否则将直接使用视频并需要加载文本编码器和
+ VAE。该脚本还支持 PyTorch DDP,以便可以使用多个 GPU 并行编码大型数据集(修改 `NUM_GPUS` 参数)。
docs/finetrainers-src-codebase/assets/sft_2b.png ADDED
docs/finetrainers-src-codebase/assets/sft_5b.png ADDED
docs/finetrainers-src-codebase/assets/tests/metadata.csv ADDED
@@ -0,0 +1,2 @@
+ video,caption
+ "videos/hiker.mp4","""A hiker standing at the top of a mountain, triumphantly, high quality"""
docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md ADDED
@@ -0,0 +1,20 @@
+ # Notes for Future Me
+ 
+ > [!NOTE]
+ > This doc page is intended for developers and contributors.
+ 
+ FSDP dump:
+ - https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
+ - https://github.com/pytorch/pytorch/issues/114299
+   - Using FSDP1 requires that all FSDP flat parameters are of the same dtype. For LoRA training, we default LoRA parameters to fp32 and transformer parameters to the dtype chosen by the user. There seems to be no easier workaround than performing LoRA training in the same dtype.
+ - https://github.com/pytorch/pytorch/issues/100945
+ - https://github.com/pytorch/torchtune/blob/9b3836028fd0b48f593ea43474b86880c49a4d74/recipes/lora_finetune_distributed.py
+ - https://github.com/KellerJordan/modded-nanogpt/pull/68
+ - https://github.com/pytorch/pytorch/pull/125394: monkey-patch method for FSDP pre/post-hooks to be triggered for methods other than `forward`
+ - https://github.com/pytorch/pytorch/pull/127786
+ - https://github.com/pytorch/pytorch/pull/130949
+ - Sanity saver: create optimizers after parallelizing/activation-checkpointing models
+ 
+ DTensor:
+ - https://github.com/pytorch/pytorch/issues/88838
+ - https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/parallel/test_parallelize_api.py
docs/{finetrainers/documentation_args.md → finetrainers-src-codebase/docs/args.md} RENAMED
@@ -75,7 +75,10 @@ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embe
  naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
  by default, and recommend adding more layers to the default list based on the model architecture.
  compile_modules (`List[str]`, defaults to `[]`):
- Modules that should be regionally compiled with `torch.compile`. Choose one or more from ['transformer'].
+ Modules that should be regionally compiled with `torch.compile`.
+ compile_scopes (`str`, defaults to `None`):
+ The scope of compilation for each `--compile_modules`. Choose between ['regional', 'full']. Must have the same length as
+ `--compile_modules`. If `None`, will default to `regional` for all modules.
 
  DATASET ARGUMENTS
  -----------------
@@ -109,6 +112,9 @@ dataset_config (`str`):
  dataset_shuffle_buffer_size (`int`, defaults to `1`):
  The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
  value of `1` means that the dataset will not be shuffled.
+ enable_precomputation (`bool`, defaults to `False`):
+ Whether or not to precompute the embeddings for the dataset. This is useful for faster training. If set to `True`,
+ the embeddings will be precomputed and saved to disk and loaded as required.
  precomputation_items (`int`, defaults to `512`):
  Number of data samples to precompute at once for memory-efficient training. The higher this value,
  the more disk memory will be used to save the precomputed samples (conditions and latents).
@@ -118,8 +124,16 @@ precomputation_dir (`str`, defaults to `None`):
  precomputation_once (`bool`, defaults to `False`):
  Precompute embeddings from all datasets at once before training. This is useful to save time during training
  with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during
- training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size
- of your dataset(s).
+ training when required (that is, computing embeddings of more data samples once `precomputation_items` of them
+ have been exhausted across all distributed ranks). Make sure to set `precomputation_items` to a reasonable value
+ in line with the size of your dataset(s).
+ precomputation_reuse (`bool`, defaults to `False`):
+ Reuse precomputed embeddings from previous training runs. This is useful to save time during training
+ with medium/large datasets. By default, old precomputed embeddings that exist in the specified precomputation
+ directory, or the default precomputation dir `{output_dir}/precomputed`, will be deleted if this is not set to `True`.
+ This flag is ignored if `enable_precomputation` is `False`. The topology of the distributed training run must be
+ the same as the one used to precompute the embeddings for this to work correctly (this limitation will be
+ addressed in the future).
 
  DATALOADER_ARGUMENTS
  --------------------
@@ -248,8 +262,6 @@ logging_dir (`str`, defaults to `logs`):
  The directory where the logs will be stored.
  logging_steps (`int`, defaults to `1`):
  Training logs will be tracked every `logging_steps` steps.
- allow_tf32 (`bool`, defaults to `False`):
- Whether or not to allow the use of TF32 matmul on compatible hardware.
  nccl_timeout (`int`, defaults to `1800`):
  Timeout for the NCCL communication.
  report_to (`str`, defaults to `wandb`):
@@ -260,6 +272,33 @@ verbose (`int`, defaults to `1`):
  - 1: Diffusers/Transformers info logging on local main process only
  - 2: Diffusers/Transformers debug logging on local main process only
  - 3: Diffusers/Transformers debug logging on all processes
+ 
+ TORCH CONFIG ARGUMENTS
+ ----------------------
+ allow_tf32 (`bool`, defaults to `False`):
+ Whether or not to allow the use of TF32 matmul on compatible hardware.
+ float32_matmul_precision (`str`, defaults to `highest`):
+ The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].
+ ```
+ 
+ ### Attention Provider
+ 
+ These arguments are relevant to setting the attention provider for different modeling components. The attention providers may be set differently for training and validation/inference.
+ 
+ ```
+ attn_provider_training (`str`, defaults to "native"):
+ The attention provider to use for training. Choose between
+ [
+ 'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
+ '_native_math'
+ ]
+ attn_provider_inference (`str`, defaults to "native"):
+ The attention provider to use for validation. Choose between
+ [
+ 'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
+ '_native_math', 'sage', 'sage_varlen', '_sage_qk_int8_pv_fp8_cuda', '_sage_qk_int8_pv_fp8_cuda_sm90',
+ '_sage_qk_int8_pv_fp16_cuda', '_sage_qk_int8_pv_fp16_triton', 'xformers'
+ ]
  ```
 
  ## SFT training
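For orientation, here is a hedged sketch of how the compilation, TF32, and attention-provider arguments documented above might be combined on a command line; `python train.py` and `<your training args>` are placeholders for the actual entrypoint and remaining arguments, and the chosen values are only illustrative.

```bash
# Sketch only: flag names come from the argument docs above; values are examples.
python train.py <your training args> \
  --compile_modules transformer \
  --compile_scopes regional \
  --attn_provider_training flash \
  --attn_provider_inference sage \
  --allow_tf32 \
  --float32_matmul_precision high
```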
docs/{finetrainers/documentation_dataset_README.md → finetrainers-src-codebase/docs/dataset/README.md} RENAMED
@@ -57,6 +57,8 @@ dataset
57
 
58
  #### CSV/JSON/JSONL format
59
 
 
 
60
  > [!NOTE]
61
  > Relevant classes to look for implementation:
62
  > - ImageFolderDataset
@@ -75,6 +77,8 @@ Any dataset loadable via the [🤗 HF datasets] directly should work (not widely
75
 
76
  Any dataset loadable via the [🤗 HF datasets] directly should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.
77
 
 
 
78
  ## Validation Dataset Format
79
 
80
  Arguments related to validation are:
@@ -148,18 +152,21 @@ For memory efficient training, it is important to precompute conditional and lat
148
 
149
  The following is a high-level overview of how datasets are loaded and preprocessed:
150
 
151
- - Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed). Users can chain together multiple datasets too! For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and chain the data together. Finetrainers makes this easier by allowing you to specify multiple different, or same, datasets with different resolutions.
 
152
  - The dataset is split across data replicas (GPUs groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
153
- - If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way yet].
154
  - The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1.
155
- - The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--precompute_batch_size`. This is done to avoid exhausting disk space. The smaller this value, the more number of times conditioning models will be loaded upon precomputation exhaustion. The larger this value, the more disk space will be used.
156
  - When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually]
157
 
158
  ## Understanding how datasets are precomputed
159
 
160
- There are 3 arguments related to precomputation:
 
161
  - `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion.
162
  - `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in the `--output_dir/precomputed`.
163
  - `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard coded logic (can be removed manually by users for now).
 
164
 
165
  Batching is not yet supported for precomputation. This will be added in the future.
 
57
 
58
  #### CSV/JSON/JSONL format
59
 
60
+ - Supported names are: `metadata.json`, `metadata.jsonl`, `metadata.csv`
61
+
62
  > [!NOTE]
63
  > Relevant classes to look for implementation:
64
  > - ImageFolderDataset
 
77
 
78
  Any dataset loadable via the [🤗 HF datasets] directly should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.
79
 
80
+
81
+
82
  ## Validation Dataset Format
83
 
84
  Arguments related to validation are:
 
152
 
153
  The following is a high-level overview of how datasets are loaded and preprocessed:
154
 
155
+ - Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed). Multiple datasets can be chained together. For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and create a new copy of the dataset containing multiresolution data. Finetrainers makes this easier by allowing you to specify multiple different, or same, datasets with different resolutions.
156
+ - When chaining multiple different datasets, make sure they are roughly the same size to avoid having smaller datasets repeatedly being used in the training loop. This is because the datasets are loaded in a round-robin fashion. For example, if you have 2 datasets of size 1000 and 2000, the first dataset will be fully seen twice before the second dataset is fully seen once by the model.
157
  - The dataset is split across data replicas (GPUs groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
158
+ - If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way].
159
  - The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1.
160
+ - The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--precomputation_items` to avoid exhausting disk space. The smaller this value, the more number of times conditioning models will be loaded upon precomputation exhaustion. The larger this value, the more disk space will be used.
161
  - When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually]
162
 
163
  ## Understanding how datasets are precomputed
164
 
165
+ There are 4 arguments related to precomputation:
166
+ - `--enable_precomputation`: If set, precomputation will be enabled. The parameters that follow are only relevant if this flag is set. If this flag is not set, all models will be loaded in memory and training will take place without first precomputing embeddings.
167
  - `--precomputation_items`: The number of data points to precompute and store to disk at a time. This enables memory-efficient training without exhausting disk space, as would happen if embeddings for the entire dataset(s) were precomputed at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion.
168
  - `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in `--output_dir/precomputed`.
169
  - `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hardcoded logic (which users can remove manually for now).
170
+ - `--precomputation_reuse`: If you're working with medium/large-size datasets and want to precompute all embeddings and re-use them across different training runs, make sure to set this flag.
171
 
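The rolling behaviour of `--precomputation_items` can be pictured with a small self-contained sketch (illustrative only, not the finetrainers implementation; the encoders below are random stand-ins for the text encoder and VAE):

```python
# Illustrative sketch of rolling precomputation (not the finetrainers implementation).
import itertools
import os
import torch

def encode_text(caption: str) -> torch.Tensor:  # stand-in for the text encoder
    return torch.randn(1, 16)

def encode_media(sample: dict) -> torch.Tensor:  # stand-in for the VAE
    return torch.randn(1, 4, 8, 8)

def precompute_chunk(dataset_iter, precomputation_items: int, precomputation_dir: str) -> int:
    """Encode the next `precomputation_items` samples and store them to disk."""
    os.makedirs(precomputation_dir, exist_ok=True)
    count = 0
    for i, sample in enumerate(itertools.islice(dataset_iter, precomputation_items)):
        payload = {"text_embedding": encode_text(sample["caption"]), "latents": encode_media(sample)}
        torch.save(payload, os.path.join(precomputation_dir, f"{i}.pt"))
        count += 1
    return count

# Toy stream; in practice this would be the streaming training dataset.
stream = iter([{"caption": f"sample {i}"} for i in range(1000)])
while precompute_chunk(stream, precomputation_items=512, precomputation_dir="output/precomputed") > 0:
    # Train on the freshly precomputed chunk here, then precompute the next one.
    pass
```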
172
  Batching is not yet supported for precomputation. This will be added in the future.
docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md ADDED
@@ -0,0 +1,44 @@
1
+ # Distributed dataset debugging
2
+
3
+ > [!NOTE]
4
+ > This doc page is intended for developers and contributors.
5
+
6
+ If the number of samples in the dataset is lower than the number of processes per node, training will hang indefinitely. I haven't been able to pin down how this could be fixed due to limited time, but basically:
7
+ - Start training with `--dp_degree 2` and `torchrun --standalone --nnodes=1 --nproc_per_node=2`. This launches training with DDP across 2 ranks.
8
+ - The dataset has `< dp_degree` samples
9
+ - When `datasets.distributed.split_dataset_by_node` is called, the data is distributed correctly to one rank, but the other rank hangs indefinitely. Due to this edge case, fast tests seem to fail.
10
+ - For now, we should just use `>= dp_degree` samples in the test dataset. However, this should be fixed in the future.
11
+
12
+ Minimal reproducer:
13
+
14
+ ```python
15
+ import torch
16
+ import torch.distributed as dist
17
+ from datasets import Dataset
18
+ from datasets.distributed import split_dataset_by_node
19
+ from torch.utils.data import DataLoader
20
+
21
+ ds = Dataset.from_dict({"x": [1]}).to_iterable_dataset()
22
+
23
+ dist.init_process_group()
24
+ rank, world_size = dist.get_rank(), dist.get_world_size()
25
+ ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
26
+ dl = DataLoader(ds)
27
+
28
+ exhausted = torch.zeros(world_size, dtype=torch.bool)
29
+
30
+ def loop():
31
+ while True:
32
+ print(rank, "hello", flush=True)
33
+ yield from dl
34
+ yield "end"
35
+
36
+ for x in loop():
37
+ if x == "end":
38
+ exhausted[rank] = True
39
+ continue
40
+ dist.all_reduce(exhausted)
41
+ if torch.all(exhausted):
42
+ break
43
+ print(f"{rank} {x}", flush=True)
44
+ ```
docs/{finetrainers/documentation_environment.md → finetrainers-src-codebase/docs/environment.md} RENAMED
@@ -26,3 +26,14 @@ NVIDIA A100-SXM4-80GB, 81920 MiB
26
  ```
27
 
28
  Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices.
26
  ```
27
 
28
  Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices.
29
+
30
+ > [!IMPORTANT]
31
+ >
32
+ > For context parallelism, PyTorch 2.6+ is required.
33
+
34
+ ## Configuration
35
+
36
+ The following environment variables may be configured to change the default behaviour of finetrainers:
37
+
38
+ - `FINETRAINERS_ATTN_PROVIDER`: Sets the default attention provider for training/validation. Defaults to `native`, as in native PyTorch SDPA. See [attention docs](./models/attention.md) for more information.
39
+ - `FINETRAINERS_ATTN_CHECKS`: Whether or not to run basic sanity checks when using different attention providers. This is useful for debugging, but you should leave it disabled for longer training runs. Defaults to `"0"`. Can be set to any truthy env value.
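For example, to run with the flash attention provider and enable the sanity checks for a short debugging session (assuming `flash-attn` is installed; see the attention docs linked above):

```bash
# Select the attention provider and enable sanity checks for a debugging run
export FINETRAINERS_ATTN_PROVIDER="flash"
export FINETRAINERS_ATTN_CHECKS="1"
```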
docs/{finetrainers/documentation_models_README.md → finetrainers-src-codebase/docs/models/README.md} RENAMED
File without changes
docs/finetrainers-src-codebase/docs/models/attention.md ADDED
@@ -0,0 +1,263 @@
1
+ # Attention backends
2
+
3
+ Finetrainers supports multiple attention backends to accommodate different hardware and different trade-offs between speed and memory usage. The following attention implementations are supported:
4
+ - Training:
5
+ - If model uses attention masks: `flash_varlen`, `flex`, `native`
6
+ - If model does not use attention masks: `flash`, `flex`, `native`, `xformers`
7
+ - Inference:
8
+ - If model uses attention masks: `flash_varlen`, `flex`, `native`, `sage_varlen`
9
+ - If model does not use attention masks: `flash`, `flash_varlen`, `flex`, `native`, `sage`, `sage_varlen`, `xformers`
10
+
11
+ Additionally, some specialized methods are available for debugging-specific purposes: `_native_cudnn`, `_native_efficient`, `_native_flash`, `_native_math`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`. With time, more attention-specific optimizations and custom implementations will be supported. Contributions are welcome!
12
+
13
+ Unfortunately, due to limited time for testing, only specific versions of the packages that provide these implementations are supported. Other versions may work. The minimum supported versions will gradually be lowered for more flexibility, but for now, please use the following versions:
14
+ - `flash-attn>=2.6.3`
15
+ - `sageattention>=2.1.1`
16
+ - `xformers>=0.0.29.post3`
17
+
18
+ This guide will help you quickly install flash-attn, sageattention, and xformers to make your models run faster and use less memory for training/inference. We'll cover installation on Linux (Ubuntu 22.04) and Windows (using WSL).
19
+
20
+ Before you start, make sure to use a clean Python virtual environment, so that conflicting dependencies or a failed installation do not leave your system environment in a hard-to-recover state.
21
+
22
+ ### Flash attention
23
+
24
+ Providers covered: `flash`, `flash_varlen`
25
+
26
+ The installation steps have only been tested on Ubuntu 22.04 with CUDA 12.2 and 12.6.
27
+ - Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
28
+ - You might need the following packages: `pip install packaging ninja`
29
+ - Linux: Run: `pip install flash-attn --no-build-isolation`. Verify the version with `pip show flash-attn`
30
+ - WSL: The same instructions as above should work. Native Windows might require building from source - check community guides and follow the instructions [here](https://github.com/Dao-AILab/flash-attention).
31
+
32
+ ### Sage attention
33
+
34
+ Providers covered: `sage`, `sage_varlen`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`
35
+
36
+ FP8 implementations require a CUDA compute capability of 9.0 or higher (H100, RTX 5090, etc.). Some may also work on compute capability 8.9 (RTX 4090, for example). FP16 implementations require a compute capability of at least 8.0 (A100, RTX 3090, etc.). For other GPUs, FP16 implementations may or may not work (untested).
37
+
38
+ - Check your compute capability with the following command:
39
+ ```bash
40
+ python -c "import torch; print(torch.cuda.get_device_capability())"
41
+ ```
42
+ - Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
43
+ - You might need the following packages: `pip install triton`. For Windows, check out the [triton-windows](https://github.com/woct0rdho/triton-windows) project.
44
+ - Linux/WSL: Run: `pip install git+https://github.com/thu-ml/SageAttention`. Verify the version with `pip show sageattention`.
45
+ - Make sure to look at the official installation guide in [SageAttention](https://github.com/thu-ml/SageAttention) too!
46
+
47
+ ### xformers
48
+
49
+ Providers covered: `xformers`
50
+
51
+ - Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
52
+ - Linux/WSL: Run: `pip install -U xformers --index-url https://download.pytorch.org/whl/cu126` (assuming CUDA 12.6). Verify the version with `pip show xformers`.
53
+ - Make sure to look at the official installation guide in [xformers](https://github.com/facebookresearch/xformers) too!
54
+
55
+ ----------
56
+
57
+ All other providers are either native PyTorch implementations or require a specific PyTorch version (for example, Flex Attention requires a torch version of at least 2.5.0).
58
+
59
+ ----------
60
+
61
+ ## Usage
62
+
63
+ There are two ways to use the attention dispatcher mechanism:
64
+ - Replace `scaled_dot_product_attention` globally:
65
+ ```python
66
+ import torch.nn.functional as F
67
+ from finetrainers.models.attention_dispatch import attention_dispatch
68
+
69
+ F.scaled_dot_product_attention = attention_dispatch
70
+ ```
71
+ - Replace all occurrences of `scaled_dot_product_attention` in your code with `attention_dispatch`.
72
+
73
+ ```python
74
+ # Use dispatcher directly
75
+ from finetrainers.models.attention_dispatch import attention_provider, AttentionProvider
76
+
77
+ with attention_provider(AttentionProvider.FLASH_VARLEN):
78
+ model(...)
79
+
80
+ # or,
81
+ with attention_provider("sage_varlen"):
82
+ model(...)
83
+ ```
84
+
85
+ ## Context Parallel
86
+
87
+ References and reading material:
88
+ - https://docs.pytorch.org/tutorials/prototype/context_parallel.html
89
+ - https://insujang.github.io/2024-09-20/introducing-context-parallelism/
90
+ - https://www.youtube.com/watch?v=ws7angQYIxI
91
+ - https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
92
+ - https://arxiv.org/abs/2309.14509
93
+
94
+ There are three steps to enabling context parallelism with any model:
95
+ - Defining the context parallel plan: This is a dictionary that specifies which tensors to split and gather across the CP region at different layers in the model.
96
+ - Applying the CP plan with the `apply_context_parallel` function: This registers the necessary hooks to split and gather tensors at the right places in the model, without having to manually modify the model code.
97
+ - Running model under the `attention_provider` context manager
98
+
99
+ For a quick example, refer to the [inference example](#inference) below.
100
+
101
+ The CP plan is a dictionary that maps module names to `CPInput`/`CPOutput` specifications. The keys are the names of the internal modules in the model, and the values are typically dictionaries that map a parameter identifier (either an argument index or a keyword argument as used in the forward method) to a `CPInput` or `CPOutput` object (a plain list of `CPOutput` objects can also be used, as in the example below). The `CPInput` object specifies the input tensor to be split, and the `CPOutput` object specifies the output tensor to be gathered.
102
+
103
+ ```python
104
+ class ParamId:
105
+ name: Optional[str] = None
106
+ index: Optional[int] = None
107
+
108
+ class CPInput:
109
+ split_dim: int
110
+ expected_dims: Optional[int] = None
111
+ split_output: bool = False
112
+
113
+ class CPOutput:
114
+ gather_dim: int
115
+ expected_dims: Optional[int] = None
116
+ ```
117
+
118
+ - The `split_dim` and `gather_dim` parameters specify the dimension along which to split or gather the tensor. When using CP with native scaled dot product attention from pytorch, the tensor shape is `[B, N, S, D]`, so the `split_dim` and `gather_dim` parameters should be set to `2` as it is the sequence dimension.
119
+ - The `expected_dims` parameter is optional and is used to sanity-check that the tensor has the expected number of dimensions.
120
+ - By default, `CPInput`s are split in a pre-forward hook and `CPOutput`s are gathered in a post-forward hook. If you want to split the output of a module, set the `split_output` parameter to `True`; the tensor will then be split in the post-forward hook instead of the pre-forward hook.
121
+
122
+ - Attention providers supported for training with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`
123
+ - Attention providers supported for inference with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`
124
+
125
+ ### Training
126
+
127
+ To enable training with context parallelism, you need to make sure a suitable CP plan is registered for the model you are using and launch training with `--cp_degree 2`. For models supported in finetrainers, this is internally done in the [transformer metadata](https://github.com/a-r-r-o-w/finetrainers/tree/main/finetrainers/models/_metadata/transformer.py) file. For custom models, make sure to pass the `plan` argument to the `apply_context_parallel` function.
128
+
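For example, a 2-GPU context-parallel run might be launched roughly as follows (a sketch; substitute your actual training script and the rest of your usual training arguments):

```bash
# Sketch: enable context parallelism across 2 GPUs with torchrun.
torchrun --standalone --nnodes=1 --nproc_per_node=2 train.py \
  <your-usual-training-args> \
  --cp_degree 2
```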
129
+ Currently supported models include: CogVideoX, CogView4, Flux, Wan 2.1. Support for more models and attention providers is in progress.
130
+
131
+ ### Inference
132
+
133
+ The following example shows how to run context parallel inference. For more examples and ready-to-use inference scripts, check out the [examples/inference](https://github.com/a-r-r-o-w/finetrainers/tree/main/examples/inference/) folder.
134
+
135
+ <details>
136
+ <summary> Example </summary>
137
+
138
+ ```python
139
+ import torch
140
+ import torch.distributed as dist
141
+ from diffusers import AutoencoderKLWan, WanPipeline
142
+ from diffusers.utils import export_to_video
143
+
144
+ from finetrainers._metadata import ParamId, CPInput, CPOutput
145
+ from finetrainers.parallel.ptd import apply_context_parallel
146
+ from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch
147
+
148
+ torch.nn.functional.scaled_dot_product_attention = attention_dispatch
149
+
150
+
151
+ def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
152
+ r"""Apply torch.compile to a model or its submodules if not already compiled."""
153
+ if getattr(model, "_torch_compiled", False):
154
+ return model # Already compiled
155
+
156
+ if compile_scope == "full":
157
+ model = torch.compile(model)
158
+ setattr(model, "_torch_compiled", True)
159
+ elif compile_scope == "regional":
160
+ if isinstance(model, torch.nn.ModuleList):
161
+ for name, module in model.named_children():
162
+ if not getattr(module, "_torch_compiled", False):
163
+ compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False)
164
+ setattr(compiled_module, "_torch_compiled", True)
165
+ model.register_module(name, compiled_module)
166
+ else:
167
+ for name, module in model.named_children():
168
+ apply_compile(module, compile_scope)
169
+ else:
170
+ raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")
171
+
172
+ return model
173
+
174
+
175
+ torch.manual_seed(0)
176
+ dist.init_process_group("nccl")
177
+ rank, world_size = dist.get_rank(), dist.get_world_size()
178
+ torch.cuda.set_device(rank)
179
+ cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])
180
+
181
+ cp_plan = {
182
+ "rope": {
183
+ ParamId(index=0): CPInput(2, 4, split_output=True),
184
+ },
185
+ "blocks.*": {
186
+ ParamId("encoder_hidden_states", 1): CPInput(1, 3),
187
+ },
188
+ "blocks.0": {
189
+ ParamId("hidden_states", 0): CPInput(1, 3),
190
+ },
191
+ "proj_out": [CPOutput(1, 3)],
192
+ }
193
+
194
+ try:
195
+ model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
196
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
197
+ pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
198
+ pipe.to("cuda")
199
+
200
+ apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan)
201
+
202
+ apply_compile(pipe.transformer, compile_scope="regional")
203
+
204
+ prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
205
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
206
+
207
+ with torch.no_grad():
208
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
209
+ prompt=prompt, negative_prompt=negative_prompt, device="cuda",
210
+ )
211
+
212
+ attention_backend = "_native_flash"
213
+ generator = torch.Generator().manual_seed(0)
214
+
215
+ # Warmup for compilation
216
+ with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"):
217
+ latents = pipe(
218
+ prompt_embeds=prompt_embeds,
219
+ negative_prompt_embeds=negative_prompt_embeds,
220
+ height=480,
221
+ width=832,
222
+ num_frames=81,
223
+ num_inference_steps=2,
224
+ guidance_scale=5.0,
225
+ output_type="latent",
226
+ generator=generator,
227
+ ).frames[0]
228
+
229
+ # Inference
230
+ with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"):
231
+ latents = pipe(
232
+ prompt_embeds=prompt_embeds,
233
+ negative_prompt_embeds=negative_prompt_embeds,
234
+ height=480,
235
+ width=832,
236
+ num_frames=81,
237
+ guidance_scale=5.0,
238
+ num_inference_steps=30,
239
+ output_type="latent",
240
+ generator=generator,
241
+ ).frames[0]
242
+
243
+ with torch.no_grad():
244
+ latents = latents.to(pipe.vae.dtype)
245
+ latents_mean = (
246
+ torch.tensor(pipe.vae.config.latents_mean)
247
+ .view(1, pipe.vae.config.z_dim, 1, 1, 1)
248
+ .to(latents.device, latents.dtype)
249
+ )
250
+ latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
251
+ latents.device, latents.dtype
252
+ )
253
+ latents = latents / latents_std + latents_mean
254
+ video = pipe.vae.decode(latents, return_dict=False)[0]
255
+ video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
256
+
257
+ if rank == 0:
258
+ export_to_video(video, "output.mp4", fps=16)
259
+ finally:
260
+ dist.destroy_process_group()
261
+ ```
262
+
263
+ </details>
docs/{finetrainers/documentation_models_cogvideox.md → finetrainers-src-codebase/docs/models/cogvideox.md} RENAMED
@@ -20,9 +20,9 @@ On Windows, you will have to modify the script to a compatible format to run it.
20
 
21
  CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working:
22
 
23
- * [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b)
24
- * [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B)
25
- * [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B)
26
 
27
  ## Inference
28
 
@@ -45,6 +45,6 @@ export_to_video(video, "output.mp4")
45
 
46
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
47
 
48
- * [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox)
49
- * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
50
- * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
 
20
 
21
  CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working:
22
 
23
+ - [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b)
24
+ - [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B)
25
+ - [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B)
26
 
27
  ## Inference
28
 
 
45
 
46
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
47
 
48
+ - [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox)
49
+ - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
50
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers-src-codebase/docs/models/cogview4.md ADDED
@@ -0,0 +1,94 @@
1
+ # CogView4
2
+
3
+ ## Training
4
+
5
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
+
7
+ Examples available:
8
+ - [Raider White Tarot cards style](../../examples/training/sft/cogview4/raider_white_tarot/)
9
+ - [Omni Edit Control LoRA](../../examples/training/control/cogview4/omni_edit/)
10
+ - [Canny Control LoRA](../../examples/training/control/cogview4/canny/)
11
+
12
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
13
+
14
+ ```bash
15
+ chmod +x ./examples/training/sft/cogview4/raider_white_tarot/train.sh
16
+ ./examples/training/sft/cogview4/raider_white_tarot/train.sh
17
+ ```
18
+
19
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
20
+
21
+ ## Supported checkpoints
22
+
23
+ The following checkpoints were tested with `finetrainers` and are known to be working:
24
+
25
+ - [THUDM/CogView4-6B](https://huggingface.co/THUDM/CogView4-6B)
26
+
27
+ ## Inference
28
+
29
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
30
+
31
+ ```diff
32
+ import torch
33
+ from diffusers import CogView4Pipeline
34
+ from diffusers.utils import export_to_video
35
+
36
+ pipe = CogView4Pipeline.from_pretrained(
37
+ "THUDM/CogView4-6B", torch_dtype=torch.bfloat16
38
+ ).to("cuda")
39
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogview4-lora")
40
+ + pipe.set_adapters(["cogview4-lora"], [0.9])
41
+
42
+ image = pipe("<my-awesome-prompt>").images[0]
43
+ image.save("output.png")
44
+ ```
45
+
46
+ To use trained Control LoRAs, the following can be used for inference (ideally, native support for this should be requested in Diffusers):
47
+
48
+ <details>
49
+ <summary> Control Lora inference </summary>
50
+
51
+ ```python
52
+ import torch
53
+ from diffusers import CogView4Pipeline
54
+ from diffusers.utils import load_image
55
+ from finetrainers.models.utils import _expand_linear_with_zeroed_weights
56
+ from finetrainers.patches import load_lora_weights
57
+ from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
58
+
59
+ dtype = torch.bfloat16
60
+ device = torch.device("cuda")
61
+ generator = torch.Generator().manual_seed(0)
62
+
63
+ pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype)
64
+
65
+ in_channels = pipe.transformer.config.in_channels
66
+ patch_channels = pipe.transformer.patch_embed.proj.in_features
67
+ pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels)
68
+
69
+ load_lora_weights(pipe, "/raid/aryan/cogview4-control-lora", "cogview4-lora")
70
+ pipe.to(device)
71
+
72
+ prompt = "Make the image look like it's from an ancient Egyptian mural."
73
+ control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png")
74
+ height, width = 1024, 1024
75
+
76
+ with torch.no_grad():
77
+ latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator)
78
+ control_image = pipe.image_processor.preprocess(control_image, height=height, width=width)
79
+ control_image = control_image.to(device=device, dtype=dtype)
80
+ control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator)
81
+ control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
82
+
83
+ with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
84
+ image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]
85
+
86
+ image.save("output.png")
87
+ ```
88
+ </details>
89
+
90
+ You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
91
+
92
+ - [CogView4 in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4)
93
+ - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
94
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers-src-codebase/docs/models/flux.md ADDED
@@ -0,0 +1,53 @@
1
+ # Flux
2
+
3
+ ## Training
4
+
5
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
+
7
+ Examples available:
8
+ - [Raider White Tarot cards style](../../examples/training/sft/flux_dev/raider_white_tarot/)
9
+
10
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
11
+
12
+ ```bash
13
+ chmod +x ./examples/training/sft/flux_dev/raider_white_tarot/train.sh
14
+ ./examples/training/sft/flux_dev/raider_white_tarot/train.sh
15
+ ```
16
+
17
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
18
+
19
+ > [!NOTE]
20
+ > Currently, only FLUX.1-dev is supported for training. It is a guidance-distilled model which directly predicts the outputs of its teacher model when the teacher is run with CFG. To match the output distribution of the distilled model with that of the teacher model, a guidance scale of 1.0 is hardcoded into the codebase. Other values may work too, but this is experimental.
21
+ > FLUX.1-schnell is not supported for training yet. It is a timestep-distilled model. Matching its output distribution for training is significantly more difficult.
22
+
23
+ ## Supported checkpoints
24
+
25
+ The following checkpoints were tested with `finetrainers` and are known to be working:
26
+
27
+ - [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
28
+ - [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
29
+
30
+ ## Inference
31
+
32
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
33
+
34
+ ```diff
35
+ import torch
36
+ from diffusers import FluxPipeline
37
+
38
+ pipe = FluxPipeline.from_pretrained(
39
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
40
+ ).to("cuda")
41
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="flux-lora")
42
+ + pipe.set_adapters(["flux-lora"], [0.9])
43
+
44
+ # Make sure to set guidance_scale to 0.0 when running inference with FLUX.1-schnell or derivative models
45
+ image = pipe("<my-awesome-prompt>").images[0]
46
+ image.save("output.png")
47
+ ```
48
+
49
+ You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
50
+
51
+ - [Flux in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux)
52
+ - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
53
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/{finetrainers/documentation_models_hunyuan_video.md → finetrainers-src-codebase/docs/models/hunyuan_video.md} RENAMED
@@ -50,6 +50,6 @@ export_to_video(output, "output.mp4", fps=15)
50
 
51
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
52
 
53
- * [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video)
54
- * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
55
- * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
 
50
 
51
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
52
 
53
+ - [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video)
54
+ - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
55
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/{finetrainers/documentation_models_ltx_video.md → finetrainers-src-codebase/docs/models/ltx_video.md} RENAMED
@@ -37,6 +37,6 @@ export_to_video(video, "output.mp4", fps=8)
37
 
38
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
39
 
40
- * [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video)
41
- * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
42
- * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
 
37
 
38
  You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
39
 
40
+ - [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video)
41
+ - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
42
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/{finetrainers/documentation_models_optimization.md → finetrainers-src-codebase/docs/models/optimization.md} RENAMED
File without changes
docs/{finetrainers/documentation_models_wan.md → finetrainers-src-codebase/docs/models/wan.md} RENAMED
@@ -18,6 +18,16 @@ chmod +x ./examples/training/sft/wan/crush_smol_lora/train.sh
18
 
19
  On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
20
 
21
  ## Inference
22
 
23
  Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
@@ -102,4 +112,4 @@ You can refer to the following guides to know more about the model pipeline and
102
 
103
  - [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
104
  - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
105
- - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
 
18
 
19
  On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
20
 
21
+ ## Supported checkpoints
22
+
23
+ Wan has multiple checkpoints as one can find [here](https://huggingface.co/Wan-AI). The following checkpoints were tested with `finetrainers` and are known to be working:
24
+
25
+ - [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
26
+ - [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
27
+ - [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
28
+ - [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
29
+ - [Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
30
+
31
  ## Inference
32
 
33
  Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
 
112
 
113
  - [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
114
  - [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
115
+ - [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/{finetrainers/documentation_optimizers.md → finetrainers-src-codebase/docs/optimizer.md} RENAMED
File without changes
docs/{finetrainers/documentation_parallel_processing_README.md → finetrainers-src-codebase/docs/parallel/README.md} RENAMED
@@ -14,11 +14,12 @@ As an experiment for comparing performance of different training backends, Finet
14
 
15
  ## Support matrix
16
 
17
- There are various algorithms for parallel training. Currently, we only support:
18
  - [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
19
  - [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
20
  - [HSDP](https://pytorch.org/docs/stable/fsdp.html)
21
- - [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html)
 
22
 
23
  ## Training
24
 
@@ -28,7 +29,7 @@ The following parameters are relevant for launching training:
28
  - `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
29
  - `dp_degree`: The degree of data parallelism/replicas. Defaults to `1`.
30
  - `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
31
- - `cp_degree`: The degree of context parallelism. Currently unsupported.
32
  - `tp_degree`: The degree of tensor parallelism.
33
 
34
  For launching training with the Pytorch DTensor backend, use the following:
@@ -57,3 +58,7 @@ accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0
57
  # Multi-node - Nx8 GPUs available
58
  # TODO(aryan): Add slurm script
59
  ```
 
 
 
 
 
14
 
15
  ## Support matrix
16
 
17
+ Currently supported parallelizations include:
18
  - [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
19
  - [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
20
  - [HSDP](https://pytorch.org/docs/stable/fsdp.html)
21
+ - [CP](https://docs.pytorch.org/tutorials/prototype/context_parallel.html)
22
+ <!-- - [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) -->
23
 
24
  ## Training
25
 
 
29
  - `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
30
  - `dp_degree`: The degree of data parallelism/replicas. Defaults to `1`.
31
  - `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
32
+ - `cp_degree`: The degree of context parallelism.
33
  - `tp_degree`: The degree of tensor parallelism.
34
 
35
  For launching training with the Pytorch DTensor backend, use the following:
 
58
  # Multi-node - Nx8 GPUs available
59
  # TODO(aryan): Add slurm script
60
  ```
61
+
62
+ ## Inference
63
+
64
+ For inference-only use cases, example implementations can be found in the [examples/inference/](../../examples/inference/) directory.
docs/{finetrainers/documentation_trainers_control_trainer.md → finetrainers-src-codebase/docs/trainer/control_trainer.md} RENAMED
File without changes
docs/{finetrainers/documentation_trainers_sft_trainer.md → finetrainers-src-codebase/docs/trainer/sft_trainer.md} RENAMED
File without changes
docs/finetrainers-src-codebase/examples/_legacy/training/README.md ADDED
@@ -0,0 +1,459 @@
1
+ # CogVideoX Factory 🧪
2
+
3
+ [Read in Chinese](./README_zh.md)
4
+
5
+ Fine-tune the Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼
6
+
7
+ <table align="center">
8
+ <tr>
9
+ <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
10
+ </tr>
11
+ </table>
12
+
13
+ **Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)!
14
+
15
+ ## Quickstart
16
+
17
+ Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`.
18
+
19
+ Then download a dataset:
20
+
21
+ ```bash
22
+ # install `huggingface_hub`
23
+ huggingface-cli download \
24
+ --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
25
+ --local-dir video-dataset-disney
26
+ ```
27
+
28
+ Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice):
29
+
30
+ ```bash
31
+ # For LoRA finetuning of the text-to-video CogVideoX models
32
+ ./train_text_to_video_lora.sh
33
+
34
+ # For full finetuning of the text-to-video CogVideoX models
35
+ ./train_text_to_video_sft.sh
36
+
37
+ # For LoRA finetuning of the image-to-video CogVideoX models
38
+ ./train_image_to_video_lora.sh
39
+ ```
40
+
41
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
42
+
43
+ ```diff
44
+ import torch
45
+ from diffusers import CogVideoXPipeline
46
+ from diffusers.utils import export_to_video
47
+
48
+ pipe = CogVideoXPipeline.from_pretrained(
49
+ "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
50
+ ).to("cuda")
51
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
52
+ + pipe.set_adapters(["cogvideox-lora"], [1.0])
53
+
54
+ video = pipe("<my-awesome-prompt>").frames[0]
55
+ export_to_video(video, "output.mp4", fps=8)
56
+ ```
57
+
58
+ For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):
59
+
60
+ ```python
61
+ from diffusers import CogVideoXImageToVideoPipeline
62
+
63
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
64
+ "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
65
+ ).to("cuda")
66
+
67
+ # ...
68
+
69
+ del pipe.transformer.patch_embed.pos_embedding
70
+ pipe.transformer.patch_embed.use_learned_positional_embeddings = False
71
+ pipe.transformer.config.use_learned_positional_embeddings = False
72
+ ```
73
+
74
+ You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
75
+
76
+ Below, we provide additional sections detailing more of the options explored in this repository. They all attempt to make fine-tuning of video models as accessible as possible by reducing memory requirements.
77
+
78
+ ## Prepare Dataset and Training
79
+
80
+ Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example.
81
+
82
+ - Configure environment variables as per your choice:
83
+
84
+ ```bash
85
+ export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
86
+ export TORCHDYNAMO_VERBOSE=1
87
+ export WANDB_MODE="offline"
88
+ export NCCL_P2P_DISABLE=1
89
+ export TORCH_NCCL_ENABLE_MONITORING=0
90
+ ```
91
+
92
+ - Configure which GPUs to use for training: `GPU_IDS="0,1"`
93
+
94
+ - Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example:
95
+
96
+ ```bash
97
+ LEARNING_RATES=("1e-4" "1e-3")
98
+ LR_SCHEDULES=("cosine_with_restarts")
99
+ OPTIMIZERS=("adamw" "adam")
100
+ MAX_TRAIN_STEPS=("3000")
101
+ ```
102
+
103
+ - Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`.
104
+
105
+ - Specify the absolute paths and columns/files for captions and videos.
106
+
107
+ ```bash
108
+ DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
109
+ CAPTION_COLUMN="prompt.txt"
110
+ VIDEO_COLUMN="videos.txt"
111
+ ```
112
+
113
+ - Launch experiments sweeping different hyperparameters:
114
+ ```bash
115
+ for learning_rate in "${LEARNING_RATES[@]}"; do
116
+ for lr_schedule in "${LR_SCHEDULES[@]}"; do
117
+ for optimizer in "${OPTIMIZERS[@]}"; do
118
+ for steps in "${MAX_TRAIN_STEPS[@]}"; do
119
+ output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
120
+
121
+ cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
122
+ --pretrained_model_name_or_path THUDM/CogVideoX-5b \
123
+ --data_root $DATA_ROOT \
124
+ --caption_column $CAPTION_COLUMN \
125
+ --video_column $VIDEO_COLUMN \
126
+ --id_token BW_STYLE \
127
+ --height_buckets 480 \
128
+ --width_buckets 720 \
129
+ --frame_buckets 49 \
130
+ --dataloader_num_workers 8 \
131
+ --pin_memory \
132
+ --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
133
+ --validation_prompt_separator ::: \
134
+ --num_validation_videos 1 \
135
+ --validation_epochs 10 \
136
+ --seed 42 \
137
+ --rank 128 \
138
+ --lora_alpha 128 \
139
+ --mixed_precision bf16 \
140
+ --output_dir $output_dir \
141
+ --max_num_frames 49 \
142
+ --train_batch_size 1 \
143
+ --max_train_steps $steps \
144
+ --checkpointing_steps 1000 \
145
+ --gradient_accumulation_steps 1 \
146
+ --gradient_checkpointing \
147
+ --learning_rate $learning_rate \
148
+ --lr_scheduler $lr_schedule \
149
+ --lr_warmup_steps 400 \
150
+ --lr_num_cycles 1 \
151
+ --enable_slicing \
152
+ --enable_tiling \
153
+ --optimizer $optimizer \
154
+ --beta1 0.9 \
155
+ --beta2 0.95 \
156
+ --weight_decay 0.001 \
157
+ --max_grad_norm 1.0 \
158
+ --allow_tf32 \
159
+ --report_to wandb \
160
+ --nccl_timeout 1800"
161
+
162
+ echo "Running command: $cmd"
163
+ eval $cmd
164
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
165
+ done
166
+ done
167
+ done
168
+ done
169
+ ```
170
+
171
+ To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`.
172
+
173
+ Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below.
174
+
175
+ ## Memory requirements
176
+
177
+ <table align="center">
178
+ <tr>
179
+ <td align="center" colspan="2"><b>CogVideoX LoRA Finetuning</b></td>
180
+ </tr>
181
+ <tr>
182
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
183
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
184
+ </tr>
185
+ <tr>
186
+ <td align="center"><img src="../assets/lora_2b.png" /></td>
187
+ <td align="center"><img src="../assets/lora_5b.png" /></td>
188
+ </tr>
189
+
190
+ <tr>
191
+ <td align="center" colspan="2"><b>CogVideoX Full Finetuning</b></td>
192
+ </tr>
193
+ <tr>
194
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
195
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
196
+ </tr>
197
+ <tr>
198
+ <td align="center"><img src="../assets/sft_2b.png" /></td>
199
+ <td align="center"><img src="../assets/sft_5b.png" /></td>
200
+ </tr>
201
+ </table>
202
+
203
+ Supported and verified memory optimizations for training include:
204
+
205
+ - `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation are not supported yet either. A minimal construction sketch is shown after this list.
206
+ - Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: test and make the [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work.
207
+ - DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations.
208
+
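A minimal construction sketch for the CPU-offloaded optimizer is shown below (illustrative only; the import path and keyword arguments follow the `torchao` documentation linked above and may differ across versions):

```python
# Illustrative sketch: CPU-offloaded optimizer step with a fused AdamW (per torchao docs).
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # optimizer class whose step runs on the CPU
    offload_gradients=True,  # also keep gradients on the CPU (no gradient accumulation)
    fused=True,              # use the fast fused AdamW kernels for the optimizer step
)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```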
209
+ > [!IMPORTANT]
210
+ > The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
211
+ >
212
+ > If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`.
213
+
214
+ ### LoRA finetuning
215
+
216
+ > [!NOTE]
217
+ > The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
218
+ >
219
+ > Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
220
+ > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
221
+ > or provide a URL to a valid and accessible image.
222
+
223
+ <details>
224
+ <summary> AdamW </summary>
225
+
226
+ **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
227
+
228
+ With `train_batch_size = 1`:
229
+
230
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
231
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
232
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
233
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
234
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
235
+ | THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
236
+ | THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
237
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
238
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
239
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
240
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
241
+
242
+ With `train_batch_size = 4`:
243
+
244
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
245
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
246
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
247
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
248
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
249
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
250
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
251
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
252
+
253
+ </details>
254
+
255
+ <details>
256
+ <summary> AdamW (8-bit bitsandbytes) </summary>
257
+
258
+ **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
259
+
260
+ With `train_batch_size = 1`:
261
+
262
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
263
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
264
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
265
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
266
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
267
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
268
+ | THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
269
+ | THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
270
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
271
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
272
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
273
+
274
+ With `train_batch_size = 4`:
275
+
276
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
277
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
278
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
279
+ | THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
280
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
281
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
282
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
283
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
284
+
285
+ </details>
286
+
287
+ <details>
288
+ <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
289
+
290
+ **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
291
+
292
+ With `train_batch_size = 1`:
293
+
294
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
295
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
296
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
297
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
298
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
299
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
300
+ | THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
301
+ | THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
302
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
303
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
304
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
305
+
306
+ With `train_batch_size = 4`:
307
+
308
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
309
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
310
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
311
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
312
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
313
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
314
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
315
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
316
+
317
+ </details>
318
+
319
+ <details>
320
+ <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
321
+
322
+ **Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
323
+
324
+ With `train_batch_size = 1`:
325
+
326
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
327
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
328
+ | THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
329
+ | THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
330
+
331
+ With `train_batch_size = 4`:
332
+
333
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
334
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
335
+ | THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
336
+ | THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
337
+
338
+ </details>
339
+
340
+ ### Full finetuning
341
+
342
+ > [!NOTE]
343
+ > The memory requirements for image-to-video full finetuning are similar to those of text-to-video full finetuning of `THUDM/CogVideoX-5b`, so they have not been reported explicitly.
344
+ >
345
+ > Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
346
+ > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
347
+ > or provide a URL to a valid and accessible image.
348
+
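+ As a concrete illustration of the frame-extraction route, here is a minimal sketch (not part of the training scripts; the dataset layout and paths are assumptions) that pulls one frame per training video with `ffmpeg` and collects the paths to pass via `--validation_images`:
+
+ ```
+ # Minimal sketch: extract a first frame per training video and collect the paths for
+ # `--validation_images`. Assumes `ffmpeg` is on PATH and a hypothetical dataset layout;
+ # `diffusers.utils.load_image` also accepts URLs if you prefer remote images.
+ import subprocess
+ from pathlib import Path
+
+ from diffusers.utils import load_image
+
+ video_dir = Path("video-dataset-disney/videos")  # hypothetical location of training videos
+ frame_dir = Path("validation_frames")
+ frame_dir.mkdir(exist_ok=True)
+
+ frame_paths = []
+ for video in sorted(video_dir.glob("*.mp4")):
+     frame = frame_dir / f"{video.stem}.png"
+     subprocess.run(["ffmpeg", "-y", "-i", str(video), "-frames:v", "1", str(frame)], check=True)
+     frame_paths.append(str(frame))
+
+ load_image(frame_paths[0])      # sanity-check that the extracted frame loads
+ print(":::".join(frame_paths))  # the scripts separate multiple validation entries with ':::'
+ ```
+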
349
+ > [!NOTE]
350
+ > Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
351
+
352
+ <details>
353
+ <summary> AdamW </summary>
354
+
355
+ With `train_batch_size = 1`:
356
+
357
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
358
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
359
+ | THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
360
+ | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
361
+
362
+ With `train_batch_size = 4`:
363
+
364
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
365
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
366
+ | THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
367
+ | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
368
+
369
+ </details>
370
+
371
+ <details>
372
+ <summary> AdamW (8-bit bitsandbytes) </summary>
373
+
374
+ With `train_batch_size = 1`:
375
+
376
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
377
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
378
+ | THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
379
+ | THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
380
+
381
+ With `train_batch_size = 4`:
382
+
383
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
384
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
385
+ | THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
386
+ | THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
387
+
388
+ </details>
389
+
390
+ <details>
391
+ <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
392
+
393
+ With `train_batch_size = 1`:
394
+
395
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
396
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
397
+ | THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
398
+ | THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
399
+
400
+ With `train_batch_size = 4`:
401
+
402
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
403
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
404
+ | THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
405
+ | THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
406
+
407
+ </details>
408
+
409
+ <details>
410
+ <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
411
+
412
+ **Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
413
+
414
+ With `train_batch_size = 1`:
415
+
416
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
417
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
418
+ | THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
419
+ | THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
420
+
421
+ With `train_batch_size = 4`:
422
+
423
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
424
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
425
+ | THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
426
+ | THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
427
+
428
+ </details>
429
+
430
+ > [!NOTE]
431
+ > - `memory_after_validation` is indicative of the peak memory required for training. This is because, apart from the activations, parameters, and gradients stored for training, you also need to load the VAE and text encoder into memory and spend some memory on performing inference. To reduce the total memory required for training, you can choose to skip validation/testing in the training script.
432
+ >
433
+ > - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing.
434
+
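+ For reference, peak-memory figures like the ones reported above can be collected with PyTorch's CUDA memory counters; the sketch below is illustrative only and not the exact helper used by the training scripts:
+
+ ```
+ # Minimal sketch of how checkpoints such as `memory_before_validation` /
+ # `memory_after_validation` can be measured. The helper name is illustrative.
+ import torch
+
+ def log_peak_memory(tag: str, device: torch.device) -> None:
+     gib = 1024 ** 3
+     allocated = torch.cuda.max_memory_allocated(device) / gib
+     reserved = torch.cuda.max_memory_reserved(device) / gib
+     print(f"{tag}: peak allocated {allocated:.3f} GiB, peak reserved {reserved:.3f} GiB")
+
+ device = torch.device("cuda")
+ torch.cuda.reset_peak_memory_stats(device)
+ # ... training / validation happens here ...
+ log_peak_memory("memory_after_validation", device)
+ ```
+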
435
+ <table align="center">
436
+ <tr>
437
+ <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
438
+ </tr>
439
+ <tr>
440
+ <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
441
+ </tr>
442
+ </table>
443
+
444
+ ## TODOs
445
+
446
+ - [x] Make scripts compatible with DDP
447
+ - [ ] Make scripts compatible with FSDP
448
+ - [x] Make scripts compatible with DeepSpeed
449
+ - [ ] vLLM-powered captioning script
450
+ - [x] Multi-resolution/frame support in `prepare_dataset.py`
451
+ - [ ] Analyzing traces for potential speedups and removing as many syncs as possible
452
+ - [x] Test scripts with memory-efficient optimizer from bitsandbytes
453
+ - [x] Test scripts with CPUOffloadOptimizer, etc.
454
+ - [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao))
455
+ - [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out)
456
+ - [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100)
457
+
458
+ > [!IMPORTANT]
459
+ > Since our goal is to make the scripts as memory-friendly as possible, we don't guarantee multi-GPU training.
docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md ADDED
@@ -0,0 +1,455 @@
1
+ # CogVideoX Factory 🧪
2
+
3
+ [Read in English](./README.md)
4
+
5
+ 在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼
6
+
7
+ <table align="center">
8
+ <tr>
9
+ <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">您的浏览器不支持视频标签。</video></td>
10
+ </tr>
11
+ </table>
12
+
13
+ ## 快速开始
14
+
15
+ 克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。
16
+
17
+ 接着下载数据集:
18
+
19
+ ```
20
+ # 安装 `huggingface_hub`
21
+ huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
22
+ ```
23
+
24
+ 然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项):
25
+
26
+ ```
27
+ # 对 CogVideoX 模型进行文本到视频的 LoRA 微调
28
+ ./train_text_to_video_lora.sh
29
+
30
+ # 对 CogVideoX 模型进行文本到视频的完整微调
31
+ ./train_text_to_video_sft.sh
32
+
33
+ # 对 CogVideoX 模型进行图像到视频的 LoRA 微调
34
+ ./train_image_to_video_lora.sh
35
+ ```
36
+
37
+ 假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理:
38
+
39
+ ```
40
+ import torch
41
+ from diffusers import CogVideoXPipeline
42
+ from diffusers.utils import export_to_video
43
+
44
+ pipe = CogVideoXPipeline.from_pretrained(
45
+ "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
46
+ ).to("cuda")
47
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
48
+ + pipe.set_adapters(["cogvideox-lora"], [1.0])
49
+
50
+ video = pipe("<my-awesome-prompt>").frames[0]
51
+ export_to_video(video, "output.mp4", fps=8)
52
+ ```
53
+
54
+ 你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。
55
+
56
+ **注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装
57
+ diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。
58
+
59
+ 以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。
60
+
61
+ ## 训练
62
+
63
+ 在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。
64
+
65
+ - 根据你的需求配置环境变量:
66
+
67
+ ```
68
+ export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
69
+ export TORCHDYNAMO_VERBOSE=1
70
+ export WANDB_MODE="offline"
71
+ export NCCL_P2P_DISABLE=1
72
+ export TORCH_NCCL_ENABLE_MONITORING=0
73
+ ```
74
+
75
+ - 配置用于训练的 GPU:`GPU_IDS="0,1"`
76
+
77
+ - 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例:
78
+
79
+ ```
80
+ LEARNING_RATES=("1e-4" "1e-3")
81
+ LR_SCHEDULES=("cosine_with_restarts")
82
+ OPTIMIZERS=("adamw" "adam")
83
+ MAX_TRAIN_STEPS=("3000")
84
+ ```
85
+
86
+ - 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`
87
+ 。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed
88
+ 等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。
89
+
90
+ - 指定字幕和视频的绝对路径以及列/文件。
91
+
92
+ ```
93
+ DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
94
+ CAPTION_COLUMN="prompt.txt"
95
+ VIDEO_COLUMN="videos.txt"
96
+ ```
97
+
98
+ - 运行实验,遍历不同的超参数:
99
+ ```
100
+ for learning_rate in "${LEARNING_RATES[@]}"; do
101
+ for lr_schedule in "${LR_SCHEDULES[@]}"; do
102
+ for optimizer in "${OPTIMIZERS[@]}"; do
103
+ for steps in "${MAX_TRAIN_STEPS[@]}"; do
104
+ output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
105
+
106
+ cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
107
+ --pretrained_model_name_or_path THUDM/CogVideoX-5b \
108
+ --data_root $DATA_ROOT \
109
+ --caption_column $CAPTION_COLUMN \
110
+ --video_column $VIDEO_COLUMN \
111
+ --id_token BW_STYLE \
112
+ --height_buckets 480 \
113
+ --width_buckets 720 \
114
+ --frame_buckets 49 \
115
+ --dataloader_num_workers 8 \
116
+ --pin_memory \
117
+ --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
118
+ --validation_prompt_separator ::: \
119
+ --num_validation_videos 1 \
120
+ --validation_epochs 10 \
121
+ --seed 42 \
122
+ --rank 128 \
123
+ --lora_alpha 128 \
124
+ --mixed_precision bf16 \
125
+ --output_dir $output_dir \
126
+ --max_num_frames 49 \
127
+ --train_batch_size 1 \
128
+ --max_train_steps $steps \
129
+ --checkpointing_steps 1000 \
130
+ --gradient_accumulation_steps 1 \
131
+ --gradient_checkpointing \
132
+ --learning_rate $learning_rate \
133
+ --lr_scheduler $lr_schedule \
134
+ --lr_warmup_steps 400 \
135
+ --lr_num_cycles 1 \
136
+ --enable_slicing \
137
+ --enable_tiling \
138
+ --optimizer $optimizer \
139
+ --beta1 0.9 \
140
+ --beta2 0.95 \
141
+ --weight_decay 0.001 \
142
+ --max_grad_norm 1.0 \
143
+ --allow_tf32 \
144
+ --report_to wandb \
145
+ --nccl_timeout 1800"
146
+
147
+ echo "Running command: $cmd"
148
+ eval $cmd
149
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
150
+ done
151
+ done
152
+ done
153
+ done
154
+ ```
155
+
156
+ 要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。
157
+
158
+ 注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。
159
+
160
+ ## 内存需求
161
+
162
+ <table align="center">
163
+ <tr>
164
+ <td align="center" colspan="2"><b>CogVideoX LoRA 微调</b></td>
165
+ </tr>
166
+ <tr>
167
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
168
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
169
+ </tr>
170
+ <tr>
171
+ <td align="center"><img src="assets/lora_2b.png" /></td>
172
+ <td align="center"><img src="assets/lora_5b.png" /></td>
173
+ </tr>
174
+
175
+ <tr>
176
+ <td align="center" colspan="2"><b>CogVideoX 全量微调</b></td>
177
+ </tr>
178
+ <tr>
179
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
180
+ <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
181
+ </tr>
182
+ <tr>
183
+ <td align="center"><img src="assets/sft_2b.png" /></td>
184
+ <td align="center"><img src="assets/sft_5b.png" /></td>
185
+ </tr>
186
+ </table>
187
+
188
+ 支持和验证的训练内存优化包括:
189
+
190
+ - `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao)
191
+ 。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload)
192
+ 阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU
193
+ 优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile`
194
+ 。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。
195
+ - 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers)
196
+ 的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。
197
+ - DeepSpeed Zero2:由于我们依赖 `accelerate`
198
+ ,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed
199
+ Zero2 优化训练。
200
+
201
+ > [!IMPORTANT]
202
+ > 内存需求是运行 `training/prepare_dataset.py` 后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。
205
+ >
206
+ > 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。
207
+
208
+ ### LoRA微调
209
+
210
+ > [!IMPORTANT]
211
+ > 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。
212
+ >
213
+ > 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧:
214
+ > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
215
+ > 或提供一个有效且可访问的图像URL。
216
+
217
+ <details>
218
+ <summary> AdamW </summary>
219
+
220
+ **注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。
221
+
222
+ 当 `train_batch_size = 1` 时:
223
+
224
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
225
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
226
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
227
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
228
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
229
+ | THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
230
+ | THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
231
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
232
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
233
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
234
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
235
+
236
+ 当 `train_batch_size = 4` 时:
237
+
238
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
239
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
240
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
241
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
242
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
243
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
244
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
245
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
246
+
247
+ </details>
248
+
249
+ <details>
250
+ <summary> AdamW (8-bit bitsandbytes) </summary>
251
+
252
+ **注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
253
+
254
+ 当 `train_batch_size = 1` 时:
255
+
256
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
257
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
258
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
259
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
260
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
261
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
262
+ | THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
263
+ | THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
264
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
265
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
266
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
267
+
268
+ 当 `train_batch_size = 4` 时:
269
+
270
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
271
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
272
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
273
+ | THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
274
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
275
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
276
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
277
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
278
+
279
+ </details>
280
+
281
+ <details>
282
+ <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
283
+
284
+ **注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
285
+
286
+ 当 `train_batch_size = 1` 时:
287
+
288
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
289
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
290
+ | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
291
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
292
+ | THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
293
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
294
+ | THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
295
+ | THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
296
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
297
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
298
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
299
+
300
+ 当 `train_batch_size = 4` 时:
301
+
302
+ | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
303
+ |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
304
+ | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
305
+ | THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
306
+ | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
307
+ | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
308
+ | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
309
+ | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
310
+
311
+ </details>
312
+
313
+ <details>
314
+ <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
315
+
316
+ **注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。
317
+
318
+ 当 `train_batch_size = 1` 时:
319
+
320
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
321
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
322
+ | THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
323
+ | THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
324
+
325
+ 当 `train_batch_size = 4` 时:
326
+
327
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
328
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
329
+ | THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
330
+ | THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
331
+
332
+ </details>
333
+
334
+ ### Full finetuning
335
+
336
+ > [!NOTE]
337
+ > 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。
338
+ >
339
+ > 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧:
340
+ > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
341
+ > 或提供一个有效且可访问的图像 URL。
342
+
343
+ > [!NOTE]
344
+ > 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。
345
+
346
+ <details>
347
+ <summary> AdamW </summary>
348
+
349
+ 当 `train_batch_size = 1` 时:
350
+
351
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
352
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
353
+ | THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
354
+ | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
355
+
356
+ 当 `train_batch_size = 4` 时:
357
+
358
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
359
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
360
+ | THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
361
+ | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
362
+
363
+ </details>
364
+
365
+ <details>
366
+ <summary> AdamW (8-bit 量化) </summary>
367
+
368
+ 当 `train_batch_size = 1` 时:
369
+
370
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
371
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
372
+ | THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
373
+ | THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
374
+
375
+ 当 `train_batch_size = 4` 时:
376
+
377
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
378
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
379
+ | THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
380
+ | THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
381
+
382
+ </details>
383
+
384
+ <details>
385
+ <summary> AdamW + CPUOffloadOptimizer(带有梯度卸载)</summary>
386
+
387
+ 当 `train_batch_size = 1` 时:
388
+
389
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
390
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
391
+ | THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
392
+ | THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
393
+
394
+ 当 `train_batch_size = 4` 时:
395
+
396
+ | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
397
+ |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
398
+ | THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
399
+ | THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
400
+
401
+ </details>
402
+
403
+ <details>
404
+ <summary> DeepSpeed(AdamW + CPU/参数卸载) </summary>
405
+
406
+ **注意:** 结果是在启用 `gradient_checkpointing`(梯度检查点)功能,并在 2 台 A100 显卡上运行时报告的。
407
+
408
+ 当 `train_batch_size = 1` 时:
409
+
410
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
411
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
412
+ | THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
413
+ | THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
414
+
415
+ 当 `train_batch_size = 4` 时:
416
+
417
+ | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
418
+ |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
419
+ | THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
420
+ | THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
421
+
422
+ </details>
423
+
424
+ > [!NOTE]
425
+ > - `memory_after_validation`(验证后内存) 表示训练所需的峰值内存。这是因为除了存储训练过程中需要的激活、参数和梯度之外,还需要加载
426
+ VAE 和文本编码器到内存中,并且执行推理操作也会消耗一定内存。为了减少训练所需的总内存,您可以选择在训练脚本中不执行验证/测试。
427
+ >
428
+ > - 如果选择不进行验证/测试,`memory_before_validation`(验证前内存) 才是训练所需内存的真实指示器。
429
+
430
+ <table align="center">
431
+ <tr>
432
+ <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
433
+ </tr>
434
+ <tr>
435
+ <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
436
+ </tr>
437
+ </table>
438
+
439
+ ## 待办事项
440
+
441
+ - [x] 使脚本兼容 DDP
442
+ - [ ] 使脚本兼容 FSDP
443
+ - [x] 使脚本兼容 DeepSpeed
444
+ - [ ] 基于 vLLM 的字幕脚本
445
+ - [x] 在 `prepare_dataset.py` 中支持多分辨率/帧数
446
+ - [ ] 分析性能瓶颈并尽可能减少同步操作
447
+ - [ ] 支持 QLoRA(优先),以及其他高使用率的 LoRA 方法
448
+ - [x] 使用 bitsandbytes 的节省内存优化器测试脚本
449
+ - [x] 使用 CPUOffloadOptimizer 等测试脚本
450
+ - [ ] 使用 torchao 量化和低位内存优化器测试脚本(目前在 AdamW(8/4-bit torchao)上报错)
451
+ - [ ] 使用 AdamW(8-bit bitsandbytes)+ CPUOffloadOptimizer(带有梯度卸载)的测试脚本(目前报错)
452
+ - [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (与作者合作支持反向传播,并针对 A100 进行优化)
453
+
454
+ > [!IMPORTANT]
455
+ > 由于我们的目标是使脚本尽可能节省内存,因此我们不保证支持多 GPU 训练。
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py ADDED
File without changes
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py ADDED
@@ -0,0 +1,484 @@
1
+ import argparse
2
+
3
+
4
+ def _get_model_args(parser: argparse.ArgumentParser) -> None:
5
+ parser.add_argument(
6
+ "--pretrained_model_name_or_path",
7
+ type=str,
8
+ default=None,
9
+ required=True,
10
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
11
+ )
12
+ parser.add_argument(
13
+ "--revision",
14
+ type=str,
15
+ default=None,
16
+ required=False,
17
+ help="Revision of pretrained model identifier from huggingface.co/models.",
18
+ )
19
+ parser.add_argument(
20
+ "--variant",
21
+ type=str,
22
+ default=None,
23
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
24
+ )
25
+ parser.add_argument(
26
+ "--cache_dir",
27
+ type=str,
28
+ default=None,
29
+ help="The directory where the downloaded models and datasets will be stored.",
30
+ )
31
+
32
+
33
+ def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
34
+ parser.add_argument(
35
+ "--data_root",
36
+ type=str,
37
+ default=None,
38
+ help=("A folder containing the training data."),
39
+ )
40
+ parser.add_argument(
41
+ "--dataset_file",
42
+ type=str,
43
+ default=None,
44
+ help=("Path to a CSV file if loading prompts/video paths using this format."),
45
+ )
46
+ parser.add_argument(
47
+ "--video_column",
48
+ type=str,
49
+ default="video",
50
+ help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
51
+ )
52
+ parser.add_argument(
53
+ "--caption_column",
54
+ type=str,
55
+ default="text",
56
+ help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
57
+ )
58
+ parser.add_argument(
59
+ "--id_token",
60
+ type=str,
61
+ default=None,
62
+ help="Identifier token appended to the start of each prompt if provided.",
63
+ )
64
+ parser.add_argument(
65
+ "--height_buckets",
66
+ nargs="+",
67
+ type=int,
68
+ default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
69
+ )
70
+ parser.add_argument(
71
+ "--width_buckets",
72
+ nargs="+",
73
+ type=int,
74
+ default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
75
+ )
76
+ parser.add_argument(
77
+ "--frame_buckets",
78
+ nargs="+",
79
+ type=int,
80
+ default=[49],
81
+ help="CogVideoX1.5 needs ((num_frames - 1) // vae_scale_factor_temporal + 1) % patch_size_t == 0 to hold, e.g. num_frames = 53.",
82
+ )
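+ # Worked example for the constraint in the help text above (descriptive comment; assumes
+ # vae_scale_factor_temporal = 4 and patch_size_t = 2 for CogVideoX1.5):
+ #   num_frames = 53 -> ((53 - 1) // 4 + 1) = 14 and 14 % 2 == 0, so 53 is a valid bucket
+ #   num_frames = 49 -> ((49 - 1) // 4 + 1) = 13 and 13 % 2 == 1, so 49 does not satisfy it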
83
+ parser.add_argument(
84
+ "--load_tensors",
85
+ action="store_true",
86
+ help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
87
+ )
88
+ parser.add_argument(
89
+ "--random_flip",
90
+ type=float,
91
+ default=None,
92
+ help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
93
+ )
94
+ parser.add_argument(
95
+ "--dataloader_num_workers",
96
+ type=int,
97
+ default=0,
98
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
99
+ )
100
+ parser.add_argument(
101
+ "--pin_memory",
102
+ action="store_true",
103
+ help="Whether or not to use the pinned memory setting in pytorch dataloader.",
104
+ )
105
+
106
+
107
+ def _get_validation_args(parser: argparse.ArgumentParser) -> None:
108
+ parser.add_argument(
109
+ "--validation_prompt",
110
+ type=str,
111
+ default=None,
112
+ help="One or more prompt(s) used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
113
+ )
114
+ parser.add_argument(
115
+ "--validation_images",
116
+ type=str,
117
+ default=None,
118
+ help="One or more image path(s)/URL(s) used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
119
+ )
120
+ parser.add_argument(
121
+ "--validation_prompt_separator",
122
+ type=str,
123
+ default=":::",
124
+ help="String that separates multiple validation prompts",
125
+ )
126
+ parser.add_argument(
127
+ "--num_validation_videos",
128
+ type=int,
129
+ default=1,
130
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
131
+ )
132
+ parser.add_argument(
133
+ "--validation_epochs",
134
+ type=int,
135
+ default=None,
136
+ help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
137
+ )
138
+ parser.add_argument(
139
+ "--validation_steps",
140
+ type=int,
141
+ default=None,
142
+ help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
143
+ )
144
+ parser.add_argument(
145
+ "--guidance_scale",
146
+ type=float,
147
+ default=6,
148
+ help="The guidance scale to use while sampling validation videos.",
149
+ )
150
+ parser.add_argument(
151
+ "--use_dynamic_cfg",
152
+ action="store_true",
153
+ default=False,
154
+ help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
155
+ )
156
+ parser.add_argument(
157
+ "--enable_model_cpu_offload",
158
+ action="store_true",
159
+ default=False,
160
+ help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
161
+ )
162
+
163
+
164
+ def _get_training_args(parser: argparse.ArgumentParser) -> None:
165
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
166
+ parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
167
+ parser.add_argument(
168
+ "--lora_alpha",
169
+ type=int,
170
+ default=64,
171
+ help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
172
+ )
173
+ parser.add_argument(
174
+ "--mixed_precision",
175
+ type=str,
176
+ default=None,
177
+ choices=["no", "fp16", "bf16"],
178
+ help=(
179
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. "
180
+ "Defaults to the value of the accelerate config of the current system or the flag passed with the `accelerate launch` command. Use this "
181
+ "argument to override the accelerate config."
182
+ ),
183
+ )
184
+ parser.add_argument(
185
+ "--output_dir",
186
+ type=str,
187
+ default="cogvideox-sft",
188
+ help="The output directory where the model predictions and checkpoints will be written.",
189
+ )
190
+ parser.add_argument(
191
+ "--height",
192
+ type=int,
193
+ default=480,
194
+ help="All input videos are resized to this height.",
195
+ )
196
+ parser.add_argument(
197
+ "--width",
198
+ type=int,
199
+ default=720,
200
+ help="All input videos are resized to this width.",
201
+ )
202
+ parser.add_argument(
203
+ "--video_reshape_mode",
204
+ type=str,
205
+ default=None,
206
+ help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
207
+ )
208
+ parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
209
+ parser.add_argument(
210
+ "--max_num_frames",
211
+ type=int,
212
+ default=49,
213
+ help="All input videos will be truncated to these many frames.",
214
+ )
215
+ parser.add_argument(
216
+ "--skip_frames_start",
217
+ type=int,
218
+ default=0,
219
+ help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
220
+ )
221
+ parser.add_argument(
222
+ "--skip_frames_end",
223
+ type=int,
224
+ default=0,
225
+ help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
226
+ )
227
+ parser.add_argument(
228
+ "--train_batch_size",
229
+ type=int,
230
+ default=4,
231
+ help="Batch size (per device) for the training dataloader.",
232
+ )
233
+ parser.add_argument("--num_train_epochs", type=int, default=1)
234
+ parser.add_argument(
235
+ "--max_train_steps",
236
+ type=int,
237
+ default=None,
238
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
239
+ )
240
+ parser.add_argument(
241
+ "--checkpointing_steps",
242
+ type=int,
243
+ default=500,
244
+ help=(
245
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
246
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
247
+ " training using `--resume_from_checkpoint`."
248
+ ),
249
+ )
250
+ parser.add_argument(
251
+ "--checkpoints_total_limit",
252
+ type=int,
253
+ default=None,
254
+ help=("Max number of checkpoints to store."),
255
+ )
256
+ parser.add_argument(
257
+ "--resume_from_checkpoint",
258
+ type=str,
259
+ default=None,
260
+ help=(
261
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
262
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
263
+ ),
264
+ )
265
+ parser.add_argument(
266
+ "--gradient_accumulation_steps",
267
+ type=int,
268
+ default=1,
269
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
270
+ )
271
+ parser.add_argument(
272
+ "--gradient_checkpointing",
273
+ action="store_true",
274
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
275
+ )
276
+ parser.add_argument(
277
+ "--learning_rate",
278
+ type=float,
279
+ default=1e-4,
280
+ help="Initial learning rate (after the potential warmup period) to use.",
281
+ )
282
+ parser.add_argument(
283
+ "--scale_lr",
284
+ action="store_true",
285
+ default=False,
286
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
287
+ )
288
+ parser.add_argument(
289
+ "--lr_scheduler",
290
+ type=str,
291
+ default="constant",
292
+ help=(
293
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
294
+ ' "constant", "constant_with_warmup"]'
295
+ ),
296
+ )
297
+ parser.add_argument(
298
+ "--lr_warmup_steps",
299
+ type=int,
300
+ default=500,
301
+ help="Number of steps for the warmup in the lr scheduler.",
302
+ )
303
+ parser.add_argument(
304
+ "--lr_num_cycles",
305
+ type=int,
306
+ default=1,
307
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
308
+ )
309
+ parser.add_argument(
310
+ "--lr_power",
311
+ type=float,
312
+ default=1.0,
313
+ help="Power factor of the polynomial scheduler.",
314
+ )
315
+ parser.add_argument(
316
+ "--enable_slicing",
317
+ action="store_true",
318
+ default=False,
319
+ help="Whether or not to use VAE slicing for saving memory.",
320
+ )
321
+ parser.add_argument(
322
+ "--enable_tiling",
323
+ action="store_true",
324
+ default=False,
325
+ help="Whether or not to use VAE tiling for saving memory.",
326
+ )
327
+ parser.add_argument(
328
+ "--noised_image_dropout",
329
+ type=float,
330
+ default=0.05,
331
+ help="Image condition dropout probability when finetuning image-to-video.",
332
+ )
333
+ parser.add_argument(
334
+ "--ignore_learned_positional_embeddings",
335
+ action="store_true",
336
+ default=False,
337
+ help=(
338
+ "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
339
+ "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
340
+ "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
341
+ ),
342
+ )
343
+
344
+
345
+ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
346
+ parser.add_argument(
347
+ "--optimizer",
348
+ type=lambda s: s.lower(),
349
+ default="adam",
350
+ choices=["adam", "adamw", "prodigy", "came"],
351
+ help=("The optimizer type to use."),
352
+ )
353
+ parser.add_argument(
354
+ "--use_8bit",
355
+ action="store_true",
356
+ help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.",
357
+ )
358
+ parser.add_argument(
359
+ "--use_4bit",
360
+ action="store_true",
361
+ help="Whether or not to use 4-bit optimizers from `torchao`.",
362
+ )
363
+ parser.add_argument(
364
+ "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers."
365
+ )
366
+ parser.add_argument(
367
+ "--beta1",
368
+ type=float,
369
+ default=0.9,
370
+ help="The beta1 parameter for the Adam and Prodigy optimizers.",
371
+ )
372
+ parser.add_argument(
373
+ "--beta2",
374
+ type=float,
375
+ default=0.95,
376
+ help="The beta2 parameter for the Adam and Prodigy optimizers.",
377
+ )
378
+ parser.add_argument(
379
+ "--beta3",
380
+ type=float,
381
+ default=None,
382
+ help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
383
+ )
384
+ parser.add_argument(
385
+ "--prodigy_decouple",
386
+ action="store_true",
387
+ help="Use AdamW style decoupled weight decay.",
388
+ )
389
+ parser.add_argument(
390
+ "--weight_decay",
391
+ type=float,
392
+ default=1e-04,
393
+ help="Weight decay to use for optimizer.",
394
+ )
395
+ parser.add_argument(
396
+ "--epsilon",
397
+ type=float,
398
+ default=1e-8,
399
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
400
+ )
401
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
402
+ parser.add_argument(
403
+ "--prodigy_use_bias_correction",
404
+ action="store_true",
405
+ help="Turn on Adam's bias correction.",
406
+ )
407
+ parser.add_argument(
408
+ "--prodigy_safeguard_warmup",
409
+ action="store_true",
410
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
411
+ )
412
+ parser.add_argument(
413
+ "--use_cpu_offload_optimizer",
414
+ action="store_true",
415
+ help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.",
416
+ )
417
+ parser.add_argument(
418
+ "--offload_gradients",
419
+ action="store_true",
420
+ help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.",
421
+ )
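+ # Illustrative sketch of how these two flags are typically consumed downstream (the actual
+ # wiring lives in the `get_optimizer` helper imported from utils by the training scripts),
+ # assuming the torchao prototype API:
+ #   from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
+ #   optimizer = CPUOffloadOptimizer(
+ #       params_to_optimize, torch.optim.AdamW, fused=True,
+ #       offload_gradients=args.offload_gradients,
+ #   )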
422
+
423
+
424
+ def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
425
+ parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
426
+ parser.add_argument(
427
+ "--push_to_hub",
428
+ action="store_true",
429
+ help="Whether or not to push the model to the Hub.",
430
+ )
431
+ parser.add_argument(
432
+ "--hub_token",
433
+ type=str,
434
+ default=None,
435
+ help="The token to use to push to the Model Hub.",
436
+ )
437
+ parser.add_argument(
438
+ "--hub_model_id",
439
+ type=str,
440
+ default=None,
441
+ help="The name of the repository to keep in sync with the local `output_dir`.",
442
+ )
443
+ parser.add_argument(
444
+ "--logging_dir",
445
+ type=str,
446
+ default="logs",
447
+ help="Directory where logs are stored.",
448
+ )
449
+ parser.add_argument(
450
+ "--allow_tf32",
451
+ action="store_true",
452
+ help=(
453
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
454
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
455
+ ),
456
+ )
457
+ parser.add_argument(
458
+ "--nccl_timeout",
459
+ type=int,
460
+ default=600,
461
+ help="Maximum time (in seconds) to wait before allgather and related collective operations fail in multi-GPU/multi-node training settings.",
462
+ )
463
+ parser.add_argument(
464
+ "--report_to",
465
+ type=str,
466
+ default=None,
467
+ help=(
468
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
469
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
470
+ ),
471
+ )
472
+
473
+
474
+ def get_args():
475
+ parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
476
+
477
+ _get_model_args(parser)
478
+ _get_dataset_args(parser)
479
+ _get_training_args(parser)
480
+ _get_validation_args(parser)
481
+ _get_optimizer_args(parser)
482
+ _get_configuration_args(parser)
483
+
484
+ return parser.parse_args()
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py ADDED
@@ -0,0 +1,1016 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ import shutil
22
+ from datetime import timedelta
23
+ from pathlib import Path
24
+ from typing import Any, Dict
25
+
26
+ import diffusers
27
+ import torch
28
+ import transformers
29
+ import wandb
30
+ from accelerate import Accelerator, DistributedType
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import (
33
+ DistributedDataParallelKwargs,
34
+ InitProcessGroupKwargs,
35
+ ProjectConfiguration,
36
+ set_seed,
37
+ )
38
+ from diffusers import (
39
+ AutoencoderKLCogVideoX,
40
+ CogVideoXDPMScheduler,
41
+ CogVideoXImageToVideoPipeline,
42
+ CogVideoXTransformer3DModel,
43
+ )
44
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
45
+ from diffusers.optimization import get_scheduler
46
+ from diffusers.training_utils import cast_training_params
47
+ from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
48
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
49
+ from huggingface_hub import create_repo, upload_folder
50
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
51
+ from torch.utils.data import DataLoader
52
+ from tqdm.auto import tqdm
53
+ from transformers import AutoTokenizer, T5EncoderModel
54
+
55
+
56
+ from args import get_args # isort:skip
57
+ from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
58
+ from text_encoder import compute_prompt_embeddings # isort:skip
59
+ from utils import (
60
+ get_gradient_norm,
61
+ get_optimizer,
62
+ prepare_rotary_positional_embeddings,
63
+ print_memory,
64
+ reset_memory,
65
+ unwrap_model,
66
+ )
67
+
68
+
69
+ logger = get_logger(__name__)
70
+
71
+
72
+ def save_model_card(
73
+ repo_id: str,
74
+ videos=None,
75
+ base_model: str = None,
76
+ validation_prompt=None,
77
+ repo_folder=None,
78
+ fps=8,
79
+ ):
80
+ widget_dict = []
81
+ if videos is not None:
82
+ for i, video in enumerate(videos):
83
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
84
+ widget_dict.append(
85
+ {
86
+ "text": validation_prompt if validation_prompt else " ",
87
+ "output": {"url": f"final_video_{i}.mp4"},
88
+ }
89
+ )
90
+
91
+ model_description = f"""
92
+ # CogVideoX LoRA Finetune
93
+
94
+ <Gallery />
95
+
96
+ ## Model description
97
+
98
+ This is a lora finetune of the CogVideoX model `{base_model}`.
99
+
100
+ The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
101
+
102
+ ## Download model
103
+
104
+ [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
105
+
106
+ ## Usage
107
+
108
+ Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
109
+
110
+ ```py
111
+ import torch
112
+ from diffusers import CogVideoXImageToVideoPipeline
113
+ from diffusers.utils import export_to_video, load_image
114
+
115
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
116
+ pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
117
+
118
+ # The LoRA adapter weights are determined by what was used for training.
119
+ # In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
120
+ # It can be made lower or higher from what was used in training to decrease or amplify the effect
121
+ # of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
122
+ pipe.set_adapters(["cogvideox-lora"], [32 / 64])
123
+
124
+ image = load_image("/path/to/image.png")
125
+ video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
126
+ export_to_video(video, "output.mp4", fps=8)
127
+ ```
128
+
129
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
130
+
131
+ ## License
132
+
133
+ Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
134
+ """
135
+ model_card = load_or_create_model_card(
136
+ repo_id_or_path=repo_id,
137
+ from_training=True,
138
+ license="other",
139
+ base_model=base_model,
140
+ prompt=validation_prompt,
141
+ model_description=model_description,
142
+ widget=widget_dict,
143
+ )
144
+ tags = [
145
+ "text-to-video",
146
+ "image-to-video",
147
+ "diffusers-training",
148
+ "diffusers",
149
+ "lora",
150
+ "cogvideox",
151
+ "cogvideox-diffusers",
152
+ "template:sd-lora",
153
+ ]
154
+
155
+ model_card = populate_model_card(model_card, tags=tags)
156
+ model_card.save(os.path.join(repo_folder, "README.md"))
157
+
158
+
159
+ def log_validation(
160
+ accelerator: Accelerator,
161
+ pipe: CogVideoXImageToVideoPipeline,
162
+ args: Dict[str, Any],
163
+ pipeline_args: Dict[str, Any],
164
+ is_final_validation: bool = False,
165
+ ):
166
+ logger.info(
167
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
168
+ )
169
+
170
+ pipe = pipe.to(accelerator.device)
171
+
172
+ # run inference
173
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
174
+
175
+ videos = []
176
+ for _ in range(args.num_validation_videos):
177
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
178
+ videos.append(video)
179
+
180
+ for tracker in accelerator.trackers:
181
+ phase_name = "test" if is_final_validation else "validation"
182
+ if tracker.name == "wandb":
183
+ video_filenames = []
184
+ for i, video in enumerate(videos):
185
+ prompt = (
186
+ pipeline_args["prompt"][:25]
187
+ .replace(" ", "_")
188
+ .replace(" ", "_")
189
+ .replace("'", "_")
190
+ .replace('"', "_")
191
+ .replace("/", "_")
192
+ )
193
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
194
+ export_to_video(video, filename, fps=8)
195
+ video_filenames.append(filename)
196
+
197
+ tracker.log(
198
+ {
199
+ phase_name: [
200
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
201
+ for i, filename in enumerate(video_filenames)
202
+ ]
203
+ }
204
+ )
205
+
206
+ return videos
207
+
208
+
209
+ def run_validation(
210
+ args: Dict[str, Any],
211
+ accelerator: Accelerator,
212
+ transformer,
213
+ scheduler,
214
+ model_config: Dict[str, Any],
215
+ weight_dtype: torch.dtype,
216
+ ) -> None:
217
+ accelerator.print("===== Memory before validation =====")
218
+ print_memory(accelerator.device)
219
+ torch.cuda.synchronize(accelerator.device)
220
+
221
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
222
+ args.pretrained_model_name_or_path,
223
+ transformer=unwrap_model(accelerator, transformer),
224
+ scheduler=scheduler,
225
+ revision=args.revision,
226
+ variant=args.variant,
227
+ torch_dtype=weight_dtype,
228
+ )
229
+
230
+ if args.enable_slicing:
231
+ pipe.vae.enable_slicing()
232
+ if args.enable_tiling:
233
+ pipe.vae.enable_tiling()
234
+ if args.enable_model_cpu_offload:
235
+ pipe.enable_model_cpu_offload()
236
+
237
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
238
+ validation_images = args.validation_images.split(args.validation_prompt_separator)
239
+ for validation_image, validation_prompt in zip(validation_images, validation_prompts):
240
+ pipeline_args = {
241
+ "image": load_image(validation_image),
242
+ "prompt": validation_prompt,
243
+ "guidance_scale": args.guidance_scale,
244
+ "use_dynamic_cfg": args.use_dynamic_cfg,
245
+ "height": args.height,
246
+ "width": args.width,
247
+ "max_sequence_length": model_config.max_text_seq_length,
248
+ }
249
+
250
+ log_validation(
251
+ pipe=pipe,
252
+ args=args,
253
+ accelerator=accelerator,
254
+ pipeline_args=pipeline_args,
255
+ )
256
+
257
+ accelerator.print("===== Memory after validation =====")
258
+ print_memory(accelerator.device)
259
+ reset_memory(accelerator.device)
260
+
261
+ del pipe
262
+ gc.collect()
263
+ torch.cuda.empty_cache()
264
+ torch.cuda.synchronize(accelerator.device)
265
+
266
+
267
+ class CollateFunction:
268
+ def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
269
+ self.weight_dtype = weight_dtype
270
+ self.load_tensors = load_tensors
271
+
272
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
273
+ prompts = [x["prompt"] for x in data[0]]
274
+
275
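+ # When precomputed tensors are used (`load_tensors`), each "prompt" entry is already a
+ # text-encoder embedding, so the prompts can be stacked directly into a tensor batch.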
+ if self.load_tensors:
276
+ prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
277
+
278
+ images = [x["image"] for x in data[0]]
279
+ images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
280
+
281
+ videos = [x["video"] for x in data[0]]
282
+ videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
283
+
284
+ return {
285
+ "images": images,
286
+ "videos": videos,
287
+ "prompts": prompts,
288
+ }
289
+
290
+
291
+ def main(args):
292
+ if args.report_to == "wandb" and args.hub_token is not None:
293
+ raise ValueError(
294
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
295
+ " Please use `huggingface-cli login` to authenticate with the Hub."
296
+ )
297
+
298
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
299
+ # due to pytorch#99272, MPS does not yet support bfloat16.
300
+ raise ValueError(
301
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
302
+ )
303
+
304
+ logging_dir = Path(args.output_dir, args.logging_dir)
305
+
306
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
307
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
308
+ init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
309
+ accelerator = Accelerator(
310
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
311
+ mixed_precision=args.mixed_precision,
312
+ log_with=args.report_to,
313
+ project_config=accelerator_project_config,
314
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
315
+ )
316
+
317
+ # Disable AMP for MPS.
318
+ if torch.backends.mps.is_available():
319
+ accelerator.native_amp = False
320
+
321
+ # Make one log on every process with the configuration for debugging.
322
+ logging.basicConfig(
323
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
324
+ datefmt="%m/%d/%Y %H:%M:%S",
325
+ level=logging.INFO,
326
+ )
327
+ logger.info(accelerator.state, main_process_only=False)
328
+ if accelerator.is_local_main_process:
329
+ transformers.utils.logging.set_verbosity_warning()
330
+ diffusers.utils.logging.set_verbosity_info()
331
+ else:
332
+ transformers.utils.logging.set_verbosity_error()
333
+ diffusers.utils.logging.set_verbosity_error()
334
+
335
+ # If passed along, set the training seed now.
336
+ if args.seed is not None:
337
+ set_seed(args.seed)
338
+
339
+ # Handle the repository creation
340
+ if accelerator.is_main_process:
341
+ if args.output_dir is not None:
342
+ os.makedirs(args.output_dir, exist_ok=True)
343
+
344
+ if args.push_to_hub:
345
+ repo_id = create_repo(
346
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
347
+ exist_ok=True,
348
+ ).repo_id
349
+
350
+ # Prepare models and scheduler
351
+ tokenizer = AutoTokenizer.from_pretrained(
352
+ args.pretrained_model_name_or_path,
353
+ subfolder="tokenizer",
354
+ revision=args.revision,
355
+ )
356
+
357
+ text_encoder = T5EncoderModel.from_pretrained(
358
+ args.pretrained_model_name_or_path,
359
+ subfolder="text_encoder",
360
+ revision=args.revision,
361
+ )
362
+
363
+ # CogVideoX-2b weights are stored in float16
364
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
365
+ load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
366
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
367
+ args.pretrained_model_name_or_path,
368
+ subfolder="transformer",
369
+ torch_dtype=load_dtype,
370
+ revision=args.revision,
371
+ variant=args.variant,
372
+ )
373
+
374
+ # These changes will also be required when trying to run inference with the trained lora
375
+ if args.ignore_learned_positional_embeddings:
376
+ del transformer.patch_embed.pos_embedding
377
+ transformer.patch_embed.use_learned_positional_embeddings = False
378
+ transformer.config.use_learned_positional_embeddings = False
379
+
380
+ vae = AutoencoderKLCogVideoX.from_pretrained(
381
+ args.pretrained_model_name_or_path,
382
+ subfolder="vae",
383
+ revision=args.revision,
384
+ variant=args.variant,
385
+ )
386
+
387
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
388
+
389
+ if args.enable_slicing:
390
+ vae.enable_slicing()
391
+ if args.enable_tiling:
392
+ vae.enable_tiling()
393
+
394
+ # We only train the additional adapter LoRA layers
395
+ text_encoder.requires_grad_(False)
396
+ transformer.requires_grad_(False)
397
+ vae.requires_grad_(False)
398
+
399
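+ # VAE_SCALE_FACTOR_SPATIAL is the spatial compression of the VAE (2 ** number of downsample stages).
+ # The RoPE base height/width are the pixel-space resolution implied by the transformer's sample size,
+ # used later as the base grid for the rotary positional embeddings.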
+ VAE_SCALING_FACTOR = vae.config.scaling_factor
400
+ VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
401
+ RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
402
+ RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
403
+
404
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
405
+ # as these weights are only used for inference; keeping them in full precision is not required.
406
+ weight_dtype = torch.float32
407
+ if accelerator.state.deepspeed_plugin:
408
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
409
+ if (
410
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
411
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
412
+ ):
413
+ weight_dtype = torch.float16
414
+ if (
415
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
416
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
417
+ ):
418
+ weight_dtype = torch.bfloat16
419
+ else:
420
+ if accelerator.mixed_precision == "fp16":
421
+ weight_dtype = torch.float16
422
+ elif accelerator.mixed_precision == "bf16":
423
+ weight_dtype = torch.bfloat16
424
+
425
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
426
+ # due to pytorch#99272, MPS does not yet support bfloat16.
427
+ raise ValueError(
428
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
429
+ )
430
+
431
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
432
+ transformer.to(accelerator.device, dtype=weight_dtype)
433
+ vae.to(accelerator.device, dtype=weight_dtype)
434
+
435
+ if args.gradient_checkpointing:
436
+ transformer.enable_gradient_checkpointing()
437
+
438
+ # now we will add new LoRA weights to the attention layers
439
+ transformer_lora_config = LoraConfig(
440
+ r=args.rank,
441
+ lora_alpha=args.lora_alpha,
442
+ init_lora_weights=True,
443
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
444
+ )
445
+ transformer.add_adapter(transformer_lora_config)
446
+
447
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
448
+ def save_model_hook(models, weights, output_dir):
449
+ if accelerator.is_main_process:
450
+ transformer_lora_layers_to_save = None
451
+
452
+ for model in models:
453
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
454
+ model = unwrap_model(accelerator, model)
455
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
456
+ else:
457
+ raise ValueError(f"Unexpected save model: {model.__class__}")
458
+
459
+ # make sure to pop weight so that corresponding model is not saved again
460
+ if weights:
461
+ weights.pop()
462
+
463
+ CogVideoXImageToVideoPipeline.save_lora_weights(
464
+ output_dir,
465
+ transformer_lora_layers=transformer_lora_layers_to_save,
466
+ )
467
+
468
+ def load_model_hook(models, input_dir):
469
+ transformer_ = None
470
+
471
+ # This is a bit of a hack but I don't know any other solution.
472
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
473
+ while len(models) > 0:
474
+ model = models.pop()
475
+
476
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
477
+ transformer_ = unwrap_model(accelerator, model)
478
+ else:
479
+ raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
480
+ else:
481
+ transformer_ = CogVideoXTransformer3DModel.from_pretrained(
482
+ args.pretrained_model_name_or_path, subfolder="transformer"
483
+ )
484
+ transformer_.add_adapter(transformer_lora_config)
485
+
486
+ lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
487
+
488
+ transformer_state_dict = {
489
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
490
+ }
491
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
492
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
493
+ if incompatible_keys is not None:
494
+ # check only for unexpected keys
495
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
496
+ if unexpected_keys:
497
+ logger.warning(
498
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
499
+ f" {unexpected_keys}. "
500
+ )
501
+
502
+ # Make sure the trainable params are in float32. This is again needed since the base models
503
+ # are in `weight_dtype`. More details:
504
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
505
+ if args.mixed_precision == "fp16":
506
+ # only upcast trainable parameters (LoRA) into fp32
507
+ cast_training_params([transformer_])
508
+
509
+ accelerator.register_save_state_pre_hook(save_model_hook)
510
+ accelerator.register_load_state_pre_hook(load_model_hook)
511
+
512
+ # Enable TF32 for faster training on Ampere GPUs,
513
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
514
+ if args.allow_tf32 and torch.cuda.is_available():
515
+ torch.backends.cuda.matmul.allow_tf32 = True
516
+
517
+ if args.scale_lr:
518
+ args.learning_rate = (
519
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
520
+ )
521
+
522
+ # Make sure the trainable params are in float32.
523
+ if args.mixed_precision == "fp16":
524
+ # only upcast trainable parameters (LoRA) into fp32
525
+ cast_training_params([transformer], dtype=torch.float32)
526
+
527
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
528
+
529
+ # Optimization parameters
530
+ transformer_parameters_with_lr = {
531
+ "params": transformer_lora_parameters,
532
+ "lr": args.learning_rate,
533
+ }
534
+ params_to_optimize = [transformer_parameters_with_lr]
535
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
536
+
537
+ use_deepspeed_optimizer = (
538
+ accelerator.state.deepspeed_plugin is not None
539
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
540
+ )
541
+ use_deepspeed_scheduler = (
542
+ accelerator.state.deepspeed_plugin is not None
543
+ and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
544
+ )
545
+
546
+ optimizer = get_optimizer(
547
+ params_to_optimize=params_to_optimize,
548
+ optimizer_name=args.optimizer,
549
+ learning_rate=args.learning_rate,
550
+ beta1=args.beta1,
551
+ beta2=args.beta2,
552
+ beta3=args.beta3,
553
+ epsilon=args.epsilon,
554
+ weight_decay=args.weight_decay,
555
+ prodigy_decouple=args.prodigy_decouple,
556
+ prodigy_use_bias_correction=args.prodigy_use_bias_correction,
557
+ prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
558
+ use_8bit=args.use_8bit,
559
+ use_4bit=args.use_4bit,
560
+ use_torchao=args.use_torchao,
561
+ use_deepspeed=use_deepspeed_optimizer,
562
+ use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
563
+ offload_gradients=args.offload_gradients,
564
+ )
565
+
566
+ # Dataset and DataLoader
567
+ dataset_init_kwargs = {
568
+ "data_root": args.data_root,
569
+ "dataset_file": args.dataset_file,
570
+ "caption_column": args.caption_column,
571
+ "video_column": args.video_column,
572
+ "max_num_frames": args.max_num_frames,
573
+ "id_token": args.id_token,
574
+ "height_buckets": args.height_buckets,
575
+ "width_buckets": args.width_buckets,
576
+ "frame_buckets": args.frame_buckets,
577
+ "load_tensors": args.load_tensors,
578
+ "random_flip": args.random_flip,
579
+ "image_to_video": True,
580
+ }
581
+ if args.video_reshape_mode is None:
582
+ train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
583
+ else:
584
+ train_dataset = VideoDatasetWithResizeAndRectangleCrop(
585
+ video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
586
+ )
587
+
588
+ collate_fn = CollateFunction(weight_dtype, args.load_tensors)
589
+
590
+ train_dataloader = DataLoader(
591
+ train_dataset,
592
+ batch_size=1,
593
+ sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
594
+ collate_fn=collate_fn,
595
+ num_workers=args.dataloader_num_workers,
596
+ pin_memory=args.pin_memory,
597
+ )
598
+
599
+ # Scheduler and math around the number of training steps.
600
+ overrode_max_train_steps = False
601
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
602
+ if args.max_train_steps is None:
603
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
604
+ overrode_max_train_steps = True
605
+
606
+ if args.use_cpu_offload_optimizer:
607
+ lr_scheduler = None
608
+ accelerator.print(
609
+ "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
610
+ "you are training with those settings, they will be ignored."
611
+ )
612
+ else:
613
+ if use_deepspeed_scheduler:
614
+ from accelerate.utils import DummyScheduler
615
+
616
+ lr_scheduler = DummyScheduler(
617
+ name=args.lr_scheduler,
618
+ optimizer=optimizer,
619
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
620
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
621
+ )
622
+ else:
623
+ lr_scheduler = get_scheduler(
624
+ args.lr_scheduler,
625
+ optimizer=optimizer,
626
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
627
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
628
+ num_cycles=args.lr_num_cycles,
629
+ power=args.lr_power,
630
+ )
631
+
632
+ # Prepare everything with our `accelerator`.
633
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
634
+ transformer, optimizer, train_dataloader, lr_scheduler
635
+ )
636
+
637
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
638
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
639
+ if overrode_max_train_steps:
640
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
641
+ # Afterwards we recalculate our number of training epochs
642
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
643
+
644
+ # We need to initialize the trackers we use, and also store our configuration.
645
+ # The trackers initialize automatically on the main process.
646
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
647
+ tracker_name = args.tracker_name or "cogvideox-lora"
648
+ accelerator.init_trackers(tracker_name, config=vars(args))
649
+
650
+ accelerator.print("===== Memory before training =====")
651
+ reset_memory(accelerator.device)
652
+ print_memory(accelerator.device)
653
+
654
+ # Train!
655
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
656
+
657
+ accelerator.print("***** Running training *****")
658
+ accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
659
+ accelerator.print(f" Num examples = {len(train_dataset)}")
660
+ accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
661
+ accelerator.print(f" Num epochs = {args.num_train_epochs}")
662
+ accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
663
+ accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
664
+ accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
665
+ accelerator.print(f" Total optimization steps = {args.max_train_steps}")
666
+ global_step = 0
667
+ first_epoch = 0
668
+
669
+ # Potentially load in the weights and states from a previous save
670
+ if not args.resume_from_checkpoint:
671
+ initial_global_step = 0
672
+ else:
673
+ if args.resume_from_checkpoint != "latest":
674
+ path = os.path.basename(args.resume_from_checkpoint)
675
+ else:
676
+ # Get the most recent checkpoint
677
+ dirs = os.listdir(args.output_dir)
678
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
679
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
680
+ path = dirs[-1] if len(dirs) > 0 else None
681
+
682
+ if path is None:
683
+ accelerator.print(
684
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
685
+ )
686
+ args.resume_from_checkpoint = None
687
+ initial_global_step = 0
688
+ else:
689
+ accelerator.print(f"Resuming from checkpoint {path}")
690
+ accelerator.load_state(os.path.join(args.output_dir, path))
691
+ global_step = int(path.split("-")[1])
692
+
693
+ initial_global_step = global_step
694
+ first_epoch = global_step // num_update_steps_per_epoch
695
+
696
+ progress_bar = tqdm(
697
+ range(0, args.max_train_steps),
698
+ initial=initial_global_step,
699
+ desc="Steps",
700
+ # Only show the progress bar once on each machine.
701
+ disable=not accelerator.is_local_main_process,
702
+ )
703
+
704
+ # For DeepSpeed training
705
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
706
+
707
+ if args.load_tensors:
708
+ del vae, text_encoder
709
+ gc.collect()
710
+ torch.cuda.empty_cache()
711
+ torch.cuda.synchronize(accelerator.device)
712
+
713
+ alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
714
+
715
+ for epoch in range(first_epoch, args.num_train_epochs):
716
+ transformer.train()
717
+
718
+ for step, batch in enumerate(train_dataloader):
719
+ models_to_accumulate = [transformer]
720
+ logs = {}
721
+
722
+ with accelerator.accumulate(models_to_accumulate):
723
+ images = batch["images"].to(accelerator.device, non_blocking=True)
724
+ videos = batch["videos"].to(accelerator.device, non_blocking=True)
725
+ prompts = batch["prompts"]
726
+
727
+ # Encode videos
728
+ if not args.load_tensors:
729
+ images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
730
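+ # Noise-augment the conditioning image before VAE encoding: sigma is log-normal
+ # (exp of a sample from N(-3.0, 0.5)) and scales Gaussian noise added to the image.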
+ image_noise_sigma = torch.normal(
731
+ mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
732
+ )
733
+ image_noise_sigma = torch.exp(image_noise_sigma)
734
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
735
+ image_latent_dist = vae.encode(noisy_images).latent_dist
736
+
737
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
738
+ latent_dist = vae.encode(videos).latent_dist
739
+ else:
740
+ image_latent_dist = DiagonalGaussianDistribution(images)
741
+ latent_dist = DiagonalGaussianDistribution(videos)
742
+
743
+ image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
744
+ image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
745
+ image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
746
+
747
+ video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
748
+ video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
749
+ video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
750
+
751
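+ # Zero-pad the single-frame image latents along the frame dimension so they match the
+ # video latent length; they are concatenated with the noisy video latents channel-wise below.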
+ padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
752
+ latent_padding = image_latents.new_zeros(padding_shape)
753
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
754
+
755
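+ # Randomly drop the image conditioning (replace it with zeros) so the model also
+ # learns to denoise without an image condition.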
+ if random.random() < args.noised_image_dropout:
756
+ image_latents = torch.zeros_like(image_latents)
757
+
758
+ # Encode prompts
759
+ if not args.load_tensors:
760
+ prompt_embeds = compute_prompt_embeddings(
761
+ tokenizer,
762
+ text_encoder,
763
+ prompts,
764
+ model_config.max_text_seq_length,
765
+ accelerator.device,
766
+ weight_dtype,
767
+ requires_grad=False,
768
+ )
769
+ else:
770
+ prompt_embeds = prompts.to(dtype=weight_dtype)
771
+
772
+ # Sample noise that will be added to the latents
773
+ noise = torch.randn_like(video_latents)
774
+ batch_size, num_frames, num_channels, height, width = video_latents.shape
775
+
776
+ # Sample a random timestep for each image
777
+ timesteps = torch.randint(
778
+ 0,
779
+ scheduler.config.num_train_timesteps,
780
+ (batch_size,),
781
+ dtype=torch.int64,
782
+ device=accelerator.device,
783
+ )
784
+
785
+ # Prepare rotary embeds
786
+ image_rotary_emb = (
787
+ prepare_rotary_positional_embeddings(
788
+ height=height * VAE_SCALE_FACTOR_SPATIAL,
789
+ width=width * VAE_SCALE_FACTOR_SPATIAL,
790
+ num_frames=num_frames,
791
+ vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
792
+ patch_size=model_config.patch_size,
793
+ patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
794
+ attention_head_dim=model_config.attention_head_dim,
795
+ device=accelerator.device,
796
+ base_height=RoPE_BASE_HEIGHT,
797
+ base_width=RoPE_BASE_WIDTH,
798
+ )
799
+ if model_config.use_rotary_positional_embeddings
800
+ else None
801
+ )
802
+
803
+ # Add noise to the model input according to the noise magnitude at each timestep
804
+ # (this is the forward diffusion process)
805
+ noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
806
+ noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
807
+
808
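+ # `ofs_embed_dim` only exists on some CogVideoX I2V checkpoints; pass an ofs embedding
+ # to the transformer only when the model config defines it.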
+ ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None
809
+ ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
810
+ # Predict the noise residual
811
+ model_output = transformer(
812
+ hidden_states=noisy_model_input,
813
+ encoder_hidden_states=prompt_embeds,
814
+ timestep=timesteps,
815
+ ofs=ofs_emb,
816
+ image_rotary_emb=image_rotary_emb,
817
+ return_dict=False,
818
+ )[0]
819
+
820
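+ # Convert the model output into a prediction of the clean video latents (the target below)
+ # and weight the squared error by 1 / (1 - alpha_bar_t) for the sampled timesteps.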
+ model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
821
+
822
+ weights = 1 / (1 - alphas_cumprod[timesteps])
823
+ while len(weights.shape) < len(model_pred.shape):
824
+ weights = weights.unsqueeze(-1)
825
+
826
+ target = video_latents
827
+
828
+ loss = torch.mean(
829
+ (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
830
+ dim=1,
831
+ )
832
+ loss = loss.mean()
833
+ accelerator.backward(loss)
834
+
835
+ if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
836
+ gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
837
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
838
+ gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
839
+ logs.update(
840
+ {
841
+ "gradient_norm_before_clip": gradient_norm_before_clip,
842
+ "gradient_norm_after_clip": gradient_norm_after_clip,
843
+ }
844
+ )
845
+
846
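+ # DeepSpeed performs the optimizer step itself, so step manually only when it is not in use.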
+ if accelerator.state.deepspeed_plugin is None:
847
+ optimizer.step()
848
+ optimizer.zero_grad()
849
+
850
+ if not args.use_cpu_offload_optimizer:
851
+ lr_scheduler.step()
852
+
853
+ # Checks if the accelerator has performed an optimization step behind the scenes
854
+ if accelerator.sync_gradients:
855
+ progress_bar.update(1)
856
+ global_step += 1
857
+
858
+ # Checkpointing
859
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
860
+ if global_step % args.checkpointing_steps == 0:
861
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
862
+ if args.checkpoints_total_limit is not None:
863
+ checkpoints = os.listdir(args.output_dir)
864
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
865
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
866
+
867
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
868
+ if len(checkpoints) >= args.checkpoints_total_limit:
869
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
870
+ removing_checkpoints = checkpoints[0:num_to_remove]
871
+
872
+ logger.info(
873
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
874
+ )
875
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
876
+
877
+ for removing_checkpoint in removing_checkpoints:
878
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
879
+ shutil.rmtree(removing_checkpoint)
880
+
881
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
882
+ accelerator.save_state(save_path)
883
+ logger.info(f"Saved state to {save_path}")
884
+
885
+ # Validation
886
+ should_run_validation = args.validation_prompt is not None and (
887
+ args.validation_steps is not None and global_step % args.validation_steps == 0
888
+ )
889
+ if should_run_validation:
890
+ run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
891
+
892
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
893
+ logs.update(
894
+ {
895
+ "loss": loss.detach().item(),
896
+ "lr": last_lr,
897
+ }
898
+ )
899
+ progress_bar.set_postfix(**logs)
900
+ accelerator.log(logs, step=global_step)
901
+
902
+ if global_step >= args.max_train_steps:
903
+ break
904
+
905
+ if accelerator.is_main_process:
906
+ should_run_validation = args.validation_prompt is not None and (
907
+ args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
908
+ )
909
+ if should_run_validation:
910
+ run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
911
+
912
+ accelerator.wait_for_everyone()
913
+
914
+ if accelerator.is_main_process:
915
+ transformer = unwrap_model(accelerator, transformer)
916
+ dtype = (
917
+ torch.float16
918
+ if args.mixed_precision == "fp16"
919
+ else torch.bfloat16
920
+ if args.mixed_precision == "bf16"
921
+ else torch.float32
922
+ )
923
+ transformer = transformer.to(dtype)
924
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
925
+
926
+ CogVideoXImageToVideoPipeline.save_lora_weights(
927
+ save_directory=args.output_dir,
928
+ transformer_lora_layers=transformer_lora_layers,
929
+ )
930
+
931
+ # Cleanup trained models to save memory
932
+ if args.load_tensors:
933
+ del transformer
934
+ else:
935
+ del transformer, text_encoder, vae
936
+
937
+ gc.collect()
938
+ torch.cuda.empty_cache()
939
+ torch.cuda.synchronize(accelerator.device)
940
+
941
+ accelerator.print("===== Memory before testing =====")
942
+ print_memory(accelerator.device)
943
+ reset_memory(accelerator.device)
944
+
945
+ # Final test inference
946
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
947
+ args.pretrained_model_name_or_path,
948
+ revision=args.revision,
949
+ variant=args.variant,
950
+ torch_dtype=weight_dtype,
951
+ )
952
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
953
+
954
+ if args.enable_slicing:
955
+ pipe.vae.enable_slicing()
956
+ if args.enable_tiling:
957
+ pipe.vae.enable_tiling()
958
+ if args.enable_model_cpu_offload:
959
+ pipe.enable_model_cpu_offload()
960
+
961
+ # Load LoRA weights
962
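+ # Scale the adapter by alpha / rank so inference uses the same effective LoRA strength as training.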
+ lora_scaling = args.lora_alpha / args.rank
963
+ pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
964
+ pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
965
+
966
+ # Run inference
967
+ validation_outputs = []
968
+ if args.validation_prompt and args.num_validation_videos > 0:
969
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
970
+ validation_images = args.validation_images.split(args.validation_prompt_separator)
971
+ for validation_image, validation_prompt in zip(validation_images, validation_prompts):
972
+ pipeline_args = {
973
+ "image": load_image(validation_image),
974
+ "prompt": validation_prompt,
975
+ "guidance_scale": args.guidance_scale,
976
+ "use_dynamic_cfg": args.use_dynamic_cfg,
977
+ "height": args.height,
978
+ "width": args.width,
979
+ }
980
+
981
+ video = log_validation(
982
+ accelerator=accelerator,
983
+ pipe=pipe,
984
+ args=args,
985
+ pipeline_args=pipeline_args,
986
+ is_final_validation=True,
987
+ )
988
+ validation_outputs.extend(video)
989
+
990
+ accelerator.print("===== Memory after testing =====")
991
+ print_memory(accelerator.device)
992
+ reset_memory(accelerator.device)
993
+ torch.cuda.synchronize(accelerator.device)
994
+
995
+ if args.push_to_hub:
996
+ save_model_card(
997
+ repo_id,
998
+ videos=validation_outputs,
999
+ base_model=args.pretrained_model_name_or_path,
1000
+ validation_prompt=args.validation_prompt,
1001
+ repo_folder=args.output_dir,
1002
+ fps=args.fps,
1003
+ )
1004
+ upload_folder(
1005
+ repo_id=repo_id,
1006
+ folder_path=args.output_dir,
1007
+ commit_message="End of training",
1008
+ ignore_patterns=["step_*", "epoch_*"],
1009
+ )
1010
+
1011
+ accelerator.end_training()
1012
+
1013
+
1014
+ if __name__ == "__main__":
1015
+ args = get_args()
1016
+ main(args)
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py ADDED
@@ -0,0 +1,947 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ import shutil
22
+ from datetime import timedelta
23
+ from pathlib import Path
24
+ from typing import Any, Dict
25
+
26
+ import diffusers
27
+ import torch
28
+ import transformers
29
+ import wandb
30
+ from accelerate import Accelerator, DistributedType, init_empty_weights
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import (
33
+ DistributedDataParallelKwargs,
34
+ InitProcessGroupKwargs,
35
+ ProjectConfiguration,
36
+ set_seed,
37
+ )
38
+ from diffusers import (
39
+ AutoencoderKLCogVideoX,
40
+ CogVideoXDPMScheduler,
41
+ CogVideoXImageToVideoPipeline,
42
+ CogVideoXTransformer3DModel,
43
+ )
44
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
45
+ from diffusers.optimization import get_scheduler
46
+ from diffusers.training_utils import cast_training_params
47
+ from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
48
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
49
+ from huggingface_hub import create_repo, upload_folder
50
+ from torch.utils.data import DataLoader
51
+ from tqdm.auto import tqdm
52
+ from transformers import AutoTokenizer, T5EncoderModel
53
+
54
+
55
+ from args import get_args # isort:skip
56
+ from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
57
+ from text_encoder import compute_prompt_embeddings # isort:skip
58
+ from utils import (
59
+ get_gradient_norm,
60
+ get_optimizer,
61
+ prepare_rotary_positional_embeddings,
62
+ print_memory,
63
+ reset_memory,
64
+ unwrap_model,
65
+ )
66
+
67
+
68
+ logger = get_logger(__name__)
69
+
70
+
71
+ def save_model_card(
72
+ repo_id: str,
73
+ videos=None,
74
+ base_model: str = None,
75
+ validation_prompt=None,
76
+ repo_folder=None,
77
+ fps=8,
78
+ ):
79
+ widget_dict = []
80
+ if videos is not None:
81
+ for i, video in enumerate(videos):
82
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
83
+ widget_dict.append(
84
+ {
85
+ "text": validation_prompt if validation_prompt else " ",
86
+ "output": {"url": f"video_{i}.mp4"},
87
+ }
88
+ )
89
+
90
+ model_description = f"""
91
+ # CogVideoX Full Finetune
92
+
93
+ <Gallery />
94
+
95
+ ## Model description
96
+
97
+ This is a full finetune of the CogVideoX model `{base_model}`.
98
+
99
+ ## License
100
+
101
+ Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
102
+ """
103
+ model_card = load_or_create_model_card(
104
+ repo_id_or_path=repo_id,
105
+ from_training=True,
106
+ license="other",
107
+ base_model=base_model,
108
+ prompt=validation_prompt,
109
+ model_description=model_description,
110
+ widget=widget_dict,
111
+ )
112
+ tags = [
113
+ "text-to-video",
114
+ "image-to-video",
115
+ "diffusers-training",
116
+ "diffusers",
117
+ "cogvideox",
118
+ "cogvideox-diffusers",
119
+ ]
120
+
121
+ model_card = populate_model_card(model_card, tags=tags)
122
+ model_card.save(os.path.join(repo_folder, "README.md"))
123
+
124
+
125
+ def log_validation(
126
+ accelerator: Accelerator,
127
+ pipe: CogVideoXImageToVideoPipeline,
128
+ args: Dict[str, Any],
129
+ pipeline_args: Dict[str, Any],
130
+ is_final_validation: bool = False,
131
+ ):
132
+ logger.info(
133
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
134
+ )
135
+
136
+ pipe = pipe.to(accelerator.device)
137
+
138
+ # run inference
139
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
140
+
141
+ videos = []
142
+ for _ in range(args.num_validation_videos):
143
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
144
+ videos.append(video)
145
+
146
+ for tracker in accelerator.trackers:
147
+ phase_name = "test" if is_final_validation else "validation"
148
+ if tracker.name == "wandb":
149
+ video_filenames = []
150
+ for i, video in enumerate(videos):
151
+ prompt = (
152
+ pipeline_args["prompt"][:25]
153
+ .replace(" ", "_")
154
+ .replace(" ", "_")
155
+ .replace("'", "_")
156
+ .replace('"', "_")
157
+ .replace("/", "_")
158
+ )
159
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
160
+ export_to_video(video, filename, fps=8)
161
+ video_filenames.append(filename)
162
+
163
+ tracker.log(
164
+ {
165
+ phase_name: [
166
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
167
+ for i, filename in enumerate(video_filenames)
168
+ ]
169
+ }
170
+ )
171
+
172
+ return videos
173
+
174
+
175
+ def run_validation(
176
+ args: Dict[str, Any],
177
+ accelerator: Accelerator,
178
+ transformer,
179
+ scheduler,
180
+ model_config: Dict[str, Any],
181
+ weight_dtype: torch.dtype,
182
+ ) -> None:
183
+ accelerator.print("===== Memory before validation =====")
184
+ print_memory(accelerator.device)
185
+ torch.cuda.synchronize(accelerator.device)
186
+
187
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
188
+ args.pretrained_model_name_or_path,
189
+ transformer=unwrap_model(accelerator, transformer),
190
+ scheduler=scheduler,
191
+ revision=args.revision,
192
+ variant=args.variant,
193
+ torch_dtype=weight_dtype,
194
+ )
195
+
196
+ if args.enable_slicing:
197
+ pipe.vae.enable_slicing()
198
+ if args.enable_tiling:
199
+ pipe.vae.enable_tiling()
200
+ if args.enable_model_cpu_offload:
201
+ pipe.enable_model_cpu_offload()
202
+
203
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
204
+ validation_images = args.validation_images.split(args.validation_prompt_separator)
205
+ for validation_image, validation_prompt in zip(validation_images, validation_prompts):
206
+ pipeline_args = {
207
+ "image": load_image(validation_image),
208
+ "prompt": validation_prompt,
209
+ "guidance_scale": args.guidance_scale,
210
+ "use_dynamic_cfg": args.use_dynamic_cfg,
211
+ "height": args.height,
212
+ "width": args.width,
213
+ "max_sequence_length": model_config.max_text_seq_length,
214
+ }
215
+
216
+ log_validation(
217
+ pipe=pipe,
218
+ args=args,
219
+ accelerator=accelerator,
220
+ pipeline_args=pipeline_args,
221
+ )
222
+
223
+ accelerator.print("===== Memory after validation =====")
224
+ print_memory(accelerator.device)
225
+ reset_memory(accelerator.device)
226
+
227
+ del pipe
228
+ gc.collect()
229
+ torch.cuda.empty_cache()
230
+ torch.cuda.synchronize(accelerator.device)
231
+
232
+
233
+ class CollateFunction:
234
+ def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
235
+ self.weight_dtype = weight_dtype
236
+ self.load_tensors = load_tensors
237
+
238
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
239
+ prompts = [x["prompt"] for x in data[0]]
240
+
241
+ if self.load_tensors:
242
+ prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
243
+
244
+ images = [x["image"] for x in data[0]]
245
+ images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
246
+
247
+ videos = [x["video"] for x in data[0]]
248
+ videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
249
+
250
+ return {
251
+ "images": images,
252
+ "videos": videos,
253
+ "prompts": prompts,
254
+ }
255
+
256
+
257
+ def main(args):
258
+ if args.report_to == "wandb" and args.hub_token is not None:
259
+ raise ValueError(
260
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
261
+ " Please use `huggingface-cli login` to authenticate with the Hub."
262
+ )
263
+
264
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
265
+ # due to pytorch#99272, MPS does not yet support bfloat16.
266
+ raise ValueError(
267
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
268
+ )
269
+
270
+ logging_dir = Path(args.output_dir, args.logging_dir)
271
+
272
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
273
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
274
+ init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
275
+ accelerator = Accelerator(
276
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
277
+ mixed_precision=args.mixed_precision,
278
+ log_with=args.report_to,
279
+ project_config=accelerator_project_config,
280
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
281
+ )
282
+
283
+ # Disable AMP for MPS.
284
+ if torch.backends.mps.is_available():
285
+ accelerator.native_amp = False
286
+
287
+ # Make one log on every process with the configuration for debugging.
288
+ logging.basicConfig(
289
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
290
+ datefmt="%m/%d/%Y %H:%M:%S",
291
+ level=logging.INFO,
292
+ )
293
+ logger.info(accelerator.state, main_process_only=False)
294
+ if accelerator.is_local_main_process:
295
+ transformers.utils.logging.set_verbosity_warning()
296
+ diffusers.utils.logging.set_verbosity_info()
297
+ else:
298
+ transformers.utils.logging.set_verbosity_error()
299
+ diffusers.utils.logging.set_verbosity_error()
300
+
301
+ # If passed along, set the training seed now.
302
+ if args.seed is not None:
303
+ set_seed(args.seed)
304
+
305
+ # Handle the repository creation
306
+ if accelerator.is_main_process:
307
+ if args.output_dir is not None:
308
+ os.makedirs(args.output_dir, exist_ok=True)
309
+
310
+ if args.push_to_hub:
311
+ repo_id = create_repo(
312
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
313
+ exist_ok=True,
314
+ ).repo_id
315
+
316
+ # Prepare models and scheduler
317
+ tokenizer = AutoTokenizer.from_pretrained(
318
+ args.pretrained_model_name_or_path,
319
+ subfolder="tokenizer",
320
+ revision=args.revision,
321
+ )
322
+
323
+ text_encoder = T5EncoderModel.from_pretrained(
324
+ args.pretrained_model_name_or_path,
325
+ subfolder="text_encoder",
326
+ revision=args.revision,
327
+ )
328
+
329
+ # CogVideoX-2b weights are stored in float16
330
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
331
+ load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
332
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
333
+ args.pretrained_model_name_or_path,
334
+ subfolder="transformer",
335
+ torch_dtype=load_dtype,
336
+ revision=args.revision,
337
+ variant=args.variant,
338
+ )
339
+
340
+ if args.ignore_learned_positional_embeddings:
341
+ del transformer.patch_embed.pos_embedding
342
+ transformer.patch_embed.use_learned_positional_embeddings = False
343
+ transformer.config.use_learned_positional_embeddings = False
344
+
345
+ vae = AutoencoderKLCogVideoX.from_pretrained(
346
+ args.pretrained_model_name_or_path,
347
+ subfolder="vae",
348
+ revision=args.revision,
349
+ variant=args.variant,
350
+ )
351
+
352
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
353
+
354
+ if args.enable_slicing:
355
+ vae.enable_slicing()
356
+ if args.enable_tiling:
357
+ vae.enable_tiling()
358
+
359
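+ # Full finetune: the entire transformer is trainable; the VAE and text encoder stay frozen.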
+ text_encoder.requires_grad_(False)
360
+ vae.requires_grad_(False)
361
+ transformer.requires_grad_(True)
362
+
363
+ VAE_SCALING_FACTOR = vae.config.scaling_factor
364
+ VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
365
+ RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
366
+ RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
367
+
368
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
369
+ # as these weights are only used for inference; keeping them in full precision is not required.
370
+ weight_dtype = torch.float32
371
+ if accelerator.state.deepspeed_plugin:
372
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
373
+ if (
374
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
375
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
376
+ ):
377
+ weight_dtype = torch.float16
378
+ if (
379
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
380
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
381
+ ):
382
+ weight_dtype = torch.bfloat16
383
+ else:
384
+ if accelerator.mixed_precision == "fp16":
385
+ weight_dtype = torch.float16
386
+ elif accelerator.mixed_precision == "bf16":
387
+ weight_dtype = torch.bfloat16
388
+
389
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
390
+ # due to pytorch#99272, MPS does not yet support bfloat16.
391
+ raise ValueError(
392
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
393
+ )
394
+
395
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
396
+ transformer.to(accelerator.device, dtype=weight_dtype)
397
+ vae.to(accelerator.device, dtype=weight_dtype)
398
+
399
+ if args.gradient_checkpointing:
400
+ transformer.enable_gradient_checkpointing()
401
+
402
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
403
+ def save_model_hook(models, weights, output_dir):
404
+ if accelerator.is_main_process:
405
+ for model in models:
406
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
407
+ model = unwrap_model(accelerator, model)
408
+ model.save_pretrained(
409
+ os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
410
+ )
411
+ else:
412
+ raise ValueError(f"Unexpected save model: {model.__class__}")
413
+
414
+ # make sure to pop weight so that corresponding model is not saved again
415
+ if weights:
416
+ weights.pop()
417
+
418
+ def load_model_hook(models, input_dir):
419
+ transformer_ = None
420
+ init_under_meta = False
421
+
422
+ # This is a bit of a hack but I don't know any other solution.
423
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
424
+ while len(models) > 0:
425
+ model = models.pop()
426
+
427
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
428
+ transformer_ = unwrap_model(accelerator, model)
429
+ else:
430
+ raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
431
+ else:
432
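+ # Under DeepSpeed, build the transformer on the meta device from its config and then load
+ # the checkpointed weights with `assign=True` below.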
+ with init_empty_weights():
433
+ transformer_ = CogVideoXTransformer3DModel.from_config(
434
+ args.pretrained_model_name_or_path, subfolder="transformer"
435
+ )
436
+ init_under_meta = True
437
+
438
+ load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
439
+ transformer_.register_to_config(**load_model.config)
440
+ transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
441
+ del load_model
442
+
443
+ # Make sure the trainable params are in float32. This is again needed since the base models
444
+ # are in `weight_dtype`. More details:
445
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
446
+ if args.mixed_precision == "fp16":
447
+ cast_training_params([transformer_])
448
+
449
+ accelerator.register_save_state_pre_hook(save_model_hook)
450
+ accelerator.register_load_state_pre_hook(load_model_hook)
451
+
452
+ # Enable TF32 for faster training on Ampere GPUs,
453
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
454
+ if args.allow_tf32 and torch.cuda.is_available():
455
+ torch.backends.cuda.matmul.allow_tf32 = True
456
+
457
+ if args.scale_lr:
458
+ args.learning_rate = (
459
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
460
+ )
461
+
462
+ # Make sure the trainable params are in float32.
463
+ if args.mixed_precision == "fp16":
464
+ cast_training_params([transformer], dtype=torch.float32)
465
+
466
+ transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
467
+
468
+ # Optimization parameters
469
+ transformer_parameters_with_lr = {
470
+ "params": transformer_parameters,
471
+ "lr": args.learning_rate,
472
+ }
473
+ params_to_optimize = [transformer_parameters_with_lr]
474
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
475
+
476
+ use_deepspeed_optimizer = (
477
+ accelerator.state.deepspeed_plugin is not None
478
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
479
+ )
480
+ use_deepspeed_scheduler = (
481
+ accelerator.state.deepspeed_plugin is not None
482
+ and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
483
+ )
484
+
485
+ optimizer = get_optimizer(
486
+ params_to_optimize=params_to_optimize,
487
+ optimizer_name=args.optimizer,
488
+ learning_rate=args.learning_rate,
489
+ beta1=args.beta1,
490
+ beta2=args.beta2,
491
+ beta3=args.beta3,
492
+ epsilon=args.epsilon,
493
+ weight_decay=args.weight_decay,
494
+ prodigy_decouple=args.prodigy_decouple,
495
+ prodigy_use_bias_correction=args.prodigy_use_bias_correction,
496
+ prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
497
+ use_8bit=args.use_8bit,
498
+ use_4bit=args.use_4bit,
499
+ use_torchao=args.use_torchao,
500
+ use_deepspeed=use_deepspeed_optimizer,
501
+ use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
502
+ offload_gradients=args.offload_gradients,
503
+ )
504
+
505
+ # Dataset and DataLoader
506
+ dataset_init_kwargs = {
507
+ "data_root": args.data_root,
508
+ "dataset_file": args.dataset_file,
509
+ "caption_column": args.caption_column,
510
+ "video_column": args.video_column,
511
+ "max_num_frames": args.max_num_frames,
512
+ "id_token": args.id_token,
513
+ "height_buckets": args.height_buckets,
514
+ "width_buckets": args.width_buckets,
515
+ "frame_buckets": args.frame_buckets,
516
+ "load_tensors": args.load_tensors,
517
+ "random_flip": args.random_flip,
518
+ "image_to_video": True,
519
+ }
520
+ if args.video_reshape_mode is None:
521
+ train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
522
+ else:
523
+ train_dataset = VideoDatasetWithResizeAndRectangleCrop(
524
+ video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
525
+ )
526
+
527
+ collate_fn = CollateFunction(weight_dtype, args.load_tensors)
528
+
529
+ train_dataloader = DataLoader(
530
+ train_dataset,
531
+ batch_size=1,
532
+ sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
533
+ collate_fn=collate_fn,
534
+ num_workers=args.dataloader_num_workers,
535
+ pin_memory=args.pin_memory,
536
+ )
537
+
538
+ # Scheduler and math around the number of training steps.
539
+ overrode_max_train_steps = False
540
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
541
+ if args.max_train_steps is None:
542
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
543
+ overrode_max_train_steps = True
544
+
545
+ if args.use_cpu_offload_optimizer:
546
+ lr_scheduler = None
547
+ accelerator.print(
548
+ "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
549
+ "you are training with those settings, they will be ignored."
550
+ )
551
+ else:
552
+ if use_deepspeed_scheduler:
553
+ from accelerate.utils import DummyScheduler
554
+
555
+ lr_scheduler = DummyScheduler(
556
+ name=args.lr_scheduler,
557
+ optimizer=optimizer,
558
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
559
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
560
+ )
561
+ else:
562
+ lr_scheduler = get_scheduler(
563
+ args.lr_scheduler,
564
+ optimizer=optimizer,
565
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
566
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
567
+ num_cycles=args.lr_num_cycles,
568
+ power=args.lr_power,
569
+ )
570
+
571
+ # Prepare everything with our `accelerator`.
572
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
573
+ transformer, optimizer, train_dataloader, lr_scheduler
574
+ )
575
+
576
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
577
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
578
+ if overrode_max_train_steps:
579
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
580
+ # Afterwards we recalculate our number of training epochs
581
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
582
+
583
+ # We need to initialize the trackers we use, and also store our configuration.
584
+ # The trackers initialize automatically on the main process.
585
+ if accelerator.is_main_process:
586
+ tracker_name = args.tracker_name or "cogvideox-sft"
587
+ accelerator.init_trackers(tracker_name, config=vars(args))
588
+
589
+ accelerator.print("===== Memory before training =====")
590
+ reset_memory(accelerator.device)
591
+ print_memory(accelerator.device)
592
+
593
+ # Train!
594
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
595
+
596
+ accelerator.print("***** Running training *****")
597
+ accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
598
+ accelerator.print(f" Num examples = {len(train_dataset)}")
599
+ accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
600
+ accelerator.print(f" Num epochs = {args.num_train_epochs}")
601
+ accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
602
+ accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
603
+ accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
604
+ accelerator.print(f" Total optimization steps = {args.max_train_steps}")
605
+ global_step = 0
606
+ first_epoch = 0
607
+
608
+ # Potentially load in the weights and states from a previous save
609
+ if not args.resume_from_checkpoint:
610
+ initial_global_step = 0
611
+ else:
612
+ if args.resume_from_checkpoint != "latest":
613
+ path = os.path.basename(args.resume_from_checkpoint)
614
+ else:
615
+ # Get the most recent checkpoint
616
+ dirs = os.listdir(args.output_dir)
617
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
618
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
619
+ path = dirs[-1] if len(dirs) > 0 else None
620
+
621
+ if path is None:
622
+ accelerator.print(
623
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
624
+ )
625
+ args.resume_from_checkpoint = None
626
+ initial_global_step = 0
627
+ else:
628
+ accelerator.print(f"Resuming from checkpoint {path}")
629
+ accelerator.load_state(os.path.join(args.output_dir, path))
630
+ global_step = int(path.split("-")[1])
631
+
632
+ initial_global_step = global_step
633
+ first_epoch = global_step // num_update_steps_per_epoch
634
+
635
+ progress_bar = tqdm(
636
+ range(0, args.max_train_steps),
637
+ initial=initial_global_step,
638
+ desc="Steps",
639
+ # Only show the progress bar once on each machine.
640
+ disable=not accelerator.is_local_main_process,
641
+ )
642
+
643
+ # For DeepSpeed training
644
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
645
+
646
+ if args.load_tensors:
647
+ del vae, text_encoder
648
+ gc.collect()
649
+ torch.cuda.empty_cache()
650
+ torch.cuda.synchronize(accelerator.device)
651
+
652
+ alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
653
+
654
+ for epoch in range(first_epoch, args.num_train_epochs):
655
+ transformer.train()
656
+ for step, batch in enumerate(train_dataloader):
657
+ models_to_accumulate = [transformer]
658
+ logs = {}
659
+
660
+ with accelerator.accumulate(models_to_accumulate):
661
+ images = batch["images"].to(accelerator.device, non_blocking=True)
662
+ videos = batch["videos"].to(accelerator.device, non_blocking=True)
663
+ prompts = batch["prompts"]
664
+
665
+ # Encode videos
666
+ if not args.load_tensors:
667
+ images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
668
+ image_noise_sigma = torch.normal(
669
+ mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
670
+ )
671
+ image_noise_sigma = torch.exp(image_noise_sigma)
672
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
673
+ image_latent_dist = vae.encode(noisy_images).latent_dist
674
+
675
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
676
+ latent_dist = vae.encode(videos).latent_dist
677
+ else:
678
+ image_latent_dist = DiagonalGaussianDistribution(images)
679
+ latent_dist = DiagonalGaussianDistribution(videos)
680
+
681
+ image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
682
+ image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
683
+ image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
684
+
685
+ video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
686
+ video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
687
+ video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
688
+
689
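+ # Zero-pad the image latents along the frame axis so they line up with the video latents
+ # before being concatenated as conditioning channels further below.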
+ padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
690
+ latent_padding = image_latents.new_zeros(padding_shape)
691
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
692
+
693
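+ # Randomly drop the image conditioning so the model also learns to denoise without it
+ # (similar in spirit to classifier-free guidance dropout).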
+ if random.random() < args.noised_image_dropout:
694
+ image_latents = torch.zeros_like(image_latents)
695
+
696
+ # Encode prompts
697
+ if not args.load_tensors:
698
+ prompt_embeds = compute_prompt_embeddings(
699
+ tokenizer,
700
+ text_encoder,
701
+ prompts,
702
+ model_config.max_text_seq_length,
703
+ accelerator.device,
704
+ weight_dtype,
705
+ requires_grad=False,
706
+ )
707
+ else:
708
+ prompt_embeds = prompts.to(dtype=weight_dtype)
709
+
710
+ # Sample noise that will be added to the latents
711
+ noise = torch.randn_like(video_latents)
712
+ batch_size, num_frames, num_channels, height, width = video_latents.shape
713
+
714
+ # Sample a random timestep for each image
715
+ timesteps = torch.randint(
716
+ 0,
717
+ scheduler.config.num_train_timesteps,
718
+ (batch_size,),
719
+ dtype=torch.int64,
720
+ device=accelerator.device,
721
+ )
722
+
723
+ # Prepare rotary embeds
724
+ image_rotary_emb = (
725
+ prepare_rotary_positional_embeddings(
726
+ height=height * VAE_SCALE_FACTOR_SPATIAL,
727
+ width=width * VAE_SCALE_FACTOR_SPATIAL,
728
+ num_frames=num_frames,
729
+ vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
730
+ patch_size=model_config.patch_size,
731
+ patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
732
+ attention_head_dim=model_config.attention_head_dim,
733
+ device=accelerator.device,
734
+ base_height=RoPE_BASE_HEIGHT,
735
+ base_width=RoPE_BASE_WIDTH,
736
+ )
737
+ if model_config.use_rotary_positional_embeddings
738
+ else None
739
+ )
740
+
741
+ # Add noise to the model input according to the noise magnitude at each timestep
742
+ # (this is the forward diffusion process)
743
+ noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
744
+ noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
745
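+ # Not every CogVideoX config defines `ofs_embed_dim`; only build the ofs embedding when it exists.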
+ ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None
747
+ ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
748
+ # Predict the noise residual
749
+ model_output = transformer(
750
+ hidden_states=noisy_model_input,
751
+ encoder_hidden_states=prompt_embeds,
752
+ timestep=timesteps,
753
+ ofs=ofs_emb,
754
+ image_rotary_emb=image_rotary_emb,
755
+ return_dict=False,
756
+ )[0]
757
+
758
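+ # get_velocity() effectively converts the v-prediction output back into a prediction of the
+ # clean latents, which is compared against `video_latents` below.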
+ model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
759
+
760
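+ # Weight the per-sample loss by 1 / (1 - alpha_bar_t): low-noise timesteps get a larger weight.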
+ weights = 1 / (1 - alphas_cumprod[timesteps])
761
+ while len(weights.shape) < len(model_pred.shape):
762
+ weights = weights.unsqueeze(-1)
763
+
764
+ target = video_latents
765
+
766
+ loss = torch.mean(
767
+ (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
768
+ dim=1,
769
+ )
770
+ loss = loss.mean()
771
+ accelerator.backward(loss)
772
+
773
+ if accelerator.sync_gradients:
774
+ gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
775
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
776
+ gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
777
+ logs.update(
778
+ {
779
+ "gradient_norm_before_clip": gradient_norm_before_clip,
780
+ "gradient_norm_after_clip": gradient_norm_after_clip,
781
+ }
782
+ )
783
+ if accelerator.state.deepspeed_plugin is None:
784
+ optimizer.step()
785
+ optimizer.zero_grad()
786
+
787
+ if not args.use_cpu_offload_optimizer:
788
+ lr_scheduler.step()
789
+
790
+ # Checks if the accelerator has performed an optimization step behind the scenes
791
+ if accelerator.sync_gradients:
792
+ progress_bar.update(1)
793
+ global_step += 1
794
+
795
+ # Checkpointing
796
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
797
+ if global_step % args.checkpointing_steps == 0:
798
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
799
+ if args.checkpoints_total_limit is not None:
800
+ checkpoints = os.listdir(args.output_dir)
801
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
802
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
803
+
804
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
805
+ if len(checkpoints) >= args.checkpoints_total_limit:
806
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
807
+ removing_checkpoints = checkpoints[0:num_to_remove]
808
+
809
+ logger.info(
810
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
811
+ )
812
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
813
+
814
+ for removing_checkpoint in removing_checkpoints:
815
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
816
+ shutil.rmtree(removing_checkpoint)
817
+
818
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
819
+ accelerator.save_state(save_path)
820
+ logger.info(f"Saved state to {save_path}")
821
+
822
+ # Validation
823
+ should_run_validation = args.validation_prompt is not None and (
824
+ args.validation_steps is not None and global_step % args.validation_steps == 0
825
+ )
826
+ if should_run_validation:
827
+ run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
828
+
829
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
830
+ logs.update(
831
+ {
832
+ "loss": loss.detach().item(),
833
+ "lr": last_lr,
834
+ }
835
+ )
836
+ progress_bar.set_postfix(**logs)
837
+ accelerator.log(logs, step=global_step)
838
+
839
+ if global_step >= args.max_train_steps:
840
+ break
841
+
842
+ if accelerator.is_main_process:
843
+ should_run_validation = args.validation_prompt is not None and (
844
+ args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
845
+ )
846
+ if should_run_validation:
847
+ run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
848
+ accelerator.wait_for_everyone()
849
+
850
+ if accelerator.is_main_process:
851
+ transformer = unwrap_model(accelerator, transformer)
852
+ dtype = (
853
+ torch.float16
854
+ if args.mixed_precision == "fp16"
855
+ else torch.bfloat16
856
+ if args.mixed_precision == "bf16"
857
+ else torch.float32
858
+ )
859
+ transformer = transformer.to(dtype)
860
+
861
+ transformer.save_pretrained(
862
+ os.path.join(args.output_dir, "transformer"),
863
+ safe_serialization=True,
864
+ max_shard_size="5GB",
865
+ )
866
+
867
+ # Cleanup trained models to save memory
868
+ if args.load_tensors:
869
+ del transformer
870
+ else:
871
+ del transformer, text_encoder, vae
872
+
873
+ gc.collect()
874
+ torch.cuda.empty_cache()
875
+ torch.cuda.synchronize(accelerator.device)
876
+
877
+ accelerator.print("===== Memory before testing =====")
878
+ print_memory(accelerator.device)
879
+ reset_memory(accelerator.device)
880
+
881
+ # Final test inference
882
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
883
+ args.pretrained_model_name_or_path,
884
+ revision=args.revision,
885
+ variant=args.variant,
886
+ torch_dtype=weight_dtype,
887
+ )
888
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
889
+
890
+ if args.enable_slicing:
891
+ pipe.vae.enable_slicing()
892
+ if args.enable_tiling:
893
+ pipe.vae.enable_tiling()
894
+ if args.enable_model_cpu_offload:
895
+ pipe.enable_model_cpu_offload()
896
+
897
+ # Run inference
898
+ validation_outputs = []
899
+ if args.validation_prompt and args.num_validation_videos > 0:
900
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
901
+ validation_images = args.validation_images.split(args.validation_prompt_separator)
902
+ for validation_image, validation_prompt in zip(validation_images, validation_prompts):
903
+ pipeline_args = {
904
+ "image": load_image(validation_image),
905
+ "prompt": validation_prompt,
906
+ "guidance_scale": args.guidance_scale,
907
+ "use_dynamic_cfg": args.use_dynamic_cfg,
908
+ "height": args.height,
909
+ "width": args.width,
910
+ }
911
+
912
+ video = log_validation(
913
+ accelerator=accelerator,
914
+ pipe=pipe,
915
+ args=args,
916
+ pipeline_args=pipeline_args,
917
+ is_final_validation=True,
918
+ )
919
+ validation_outputs.extend(video)
920
+
921
+ accelerator.print("===== Memory after testing =====")
922
+ print_memory(accelerator.device)
923
+ reset_memory(accelerator.device)
924
+ torch.cuda.synchronize(accelerator.device)
925
+
926
+ if args.push_to_hub:
927
+ save_model_card(
928
+ repo_id,
929
+ videos=validation_outputs,
930
+ base_model=args.pretrained_model_name_or_path,
931
+ validation_prompt=args.validation_prompt,
932
+ repo_folder=args.output_dir,
933
+ fps=args.fps,
934
+ )
935
+ upload_folder(
936
+ repo_id=repo_id,
937
+ folder_path=args.output_dir,
938
+ commit_message="End of training",
939
+ ignore_patterns=["step_*", "epoch_*"],
940
+ )
941
+
942
+ accelerator.end_training()
943
+
944
+
945
+ if __name__ == "__main__":
946
+ args = get_args()
947
+ main(args)
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py ADDED
@@ -0,0 +1,955 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import logging
18
+ import math
19
+ import os
20
+ import shutil
21
+ from datetime import timedelta
22
+ from pathlib import Path
23
+ from typing import Any, Dict
24
+
25
+ import diffusers
26
+ import torch
27
+ import transformers
28
+ import wandb
29
+ from accelerate import Accelerator, DistributedType
30
+ from accelerate.logging import get_logger
31
+ from accelerate.utils import (
32
+ DistributedDataParallelKwargs,
33
+ InitProcessGroupKwargs,
34
+ ProjectConfiguration,
35
+ set_seed,
36
+ )
37
+ from diffusers import (
38
+ AutoencoderKLCogVideoX,
39
+ CogVideoXDPMScheduler,
40
+ CogVideoXPipeline,
41
+ CogVideoXTransformer3DModel,
42
+ )
43
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import cast_training_params
46
+ from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video
47
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
48
+ from huggingface_hub import create_repo, upload_folder
49
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
50
+ from torch.utils.data import DataLoader
51
+ from tqdm.auto import tqdm
52
+ from transformers import AutoTokenizer, T5EncoderModel
53
+
54
+
55
+ from args import get_args # isort:skip
56
+ from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
57
+ from text_encoder import compute_prompt_embeddings # isort:skip
58
+ from utils import (
59
+ get_gradient_norm,
60
+ get_optimizer,
61
+ prepare_rotary_positional_embeddings,
62
+ print_memory,
63
+ reset_memory,
64
+ unwrap_model,
65
+ ) # isort:skip
66
+
67
+
68
+ logger = get_logger(__name__)
69
+
70
+
71
+ def save_model_card(
72
+ repo_id: str,
73
+ videos=None,
74
+ base_model: str = None,
75
+ validation_prompt=None,
76
+ repo_folder=None,
77
+ fps=8,
78
+ ):
79
+ widget_dict = []
80
+ if videos is not None:
81
+ for i, video in enumerate(videos):
82
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
83
+ widget_dict.append(
84
+ {
85
+ "text": validation_prompt if validation_prompt else " ",
86
+ "output": {"url": f"video_{i}.mp4"},
87
+ }
88
+ )
89
+
90
+ model_description = f"""
91
+ # CogVideoX LoRA Finetune
92
+
93
+ <Gallery />
94
+
95
+ ## Model description
96
+
97
+ This is a lora finetune of the CogVideoX model `{base_model}`.
98
+
99
+ The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adapted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
100
+
101
+ ## Download model
102
+
103
+ [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
104
+
105
+ ## Usage
106
+
107
+ Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
108
+
109
+ ```py
110
+ import torch
111
+ from diffusers import CogVideoXPipeline
112
+ from diffusers.utils import export_to_video
113
+
114
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
115
+ pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
116
+
117
+ # The LoRA adapter weights are determined by what was used for training.
118
+ # In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
119
+ # It can be made lower or higher from what was used in training to decrease or amplify the effect
120
+ # of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.
121
+ pipe.set_adapters(["cogvideox-lora"], [32 / 64])
122
+
123
+ video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
124
+ export_to_video(video, "output.mp4", fps=8)
125
+ ```
126
+
127
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
128
+
129
+ ## License
130
+
131
+ Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
132
+ """
133
+ model_card = load_or_create_model_card(
134
+ repo_id_or_path=repo_id,
135
+ from_training=True,
136
+ license="other",
137
+ base_model=base_model,
138
+ prompt=validation_prompt,
139
+ model_description=model_description,
140
+ widget=widget_dict,
141
+ )
142
+ tags = [
143
+ "text-to-video",
144
+ "diffusers-training",
145
+ "diffusers",
146
+ "lora",
147
+ "cogvideox",
148
+ "cogvideox-diffusers",
149
+ "template:sd-lora",
150
+ ]
151
+
152
+ model_card = populate_model_card(model_card, tags=tags)
153
+ model_card.save(os.path.join(repo_folder, "README.md"))
154
+
155
+
156
+ def log_validation(
157
+ accelerator: Accelerator,
158
+ pipe: CogVideoXPipeline,
159
+ args: Dict[str, Any],
160
+ pipeline_args: Dict[str, Any],
161
+ epoch,
162
+ is_final_validation: bool = False,
163
+ ):
164
+ logger.info(
165
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
166
+ )
167
+
168
+ pipe = pipe.to(accelerator.device)
169
+
170
+ # run inference
171
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
172
+
173
+ videos = []
174
+ for _ in range(args.num_validation_videos):
175
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
176
+ videos.append(video)
177
+
178
+ for tracker in accelerator.trackers:
179
+ phase_name = "test" if is_final_validation else "validation"
180
+ if tracker.name == "wandb":
181
+ video_filenames = []
182
+ for i, video in enumerate(videos):
183
+ prompt = (
184
+ pipeline_args["prompt"][:25]
185
+ .replace(" ", "_")
186
+ .replace(" ", "_")
187
+ .replace("'", "_")
188
+ .replace('"', "_")
189
+ .replace("/", "_")
190
+ )
191
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
192
+ export_to_video(video, filename, fps=8)
193
+ video_filenames.append(filename)
194
+
195
+ tracker.log(
196
+ {
197
+ phase_name: [
198
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
199
+ for i, filename in enumerate(video_filenames)
200
+ ]
201
+ }
202
+ )
203
+
204
+ return videos
205
+
206
+
207
+ class CollateFunction:
208
+ def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
209
+ self.weight_dtype = weight_dtype
210
+ self.load_tensors = load_tensors
211
+
212
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
213
+ prompts = [x["prompt"] for x in data[0]]
214
+
215
+ if self.load_tensors:
216
+ prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
217
+
218
+ videos = [x["video"] for x in data[0]]
219
+ videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
220
+
221
+ return {
222
+ "videos": videos,
223
+ "prompts": prompts,
224
+ }
225
+
226
+
227
+ def main(args):
228
+ if args.report_to == "wandb" and args.hub_token is not None:
229
+ raise ValueError(
230
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
231
+ " Please use `huggingface-cli login` to authenticate with the Hub."
232
+ )
233
+
234
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
235
+ # due to pytorch#99272, MPS does not yet support bfloat16.
236
+ raise ValueError(
237
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
238
+ )
239
+
240
+ logging_dir = Path(args.output_dir, args.logging_dir)
241
+
242
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
243
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
244
+ init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
245
+ accelerator = Accelerator(
246
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
247
+ mixed_precision=args.mixed_precision,
248
+ log_with=args.report_to,
249
+ project_config=accelerator_project_config,
250
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
251
+ )
252
+
253
+ # Disable AMP for MPS.
254
+ if torch.backends.mps.is_available():
255
+ accelerator.native_amp = False
256
+
257
+ # Make one log on every process with the configuration for debugging.
258
+ logging.basicConfig(
259
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
260
+ datefmt="%m/%d/%Y %H:%M:%S",
261
+ level=logging.INFO,
262
+ )
263
+ logger.info(accelerator.state, main_process_only=False)
264
+ if accelerator.is_local_main_process:
265
+ transformers.utils.logging.set_verbosity_warning()
266
+ diffusers.utils.logging.set_verbosity_info()
267
+ else:
268
+ transformers.utils.logging.set_verbosity_error()
269
+ diffusers.utils.logging.set_verbosity_error()
270
+
271
+ # If passed along, set the training seed now.
272
+ if args.seed is not None:
273
+ set_seed(args.seed)
274
+
275
+ # Handle the repository creation
276
+ if accelerator.is_main_process:
277
+ if args.output_dir is not None:
278
+ os.makedirs(args.output_dir, exist_ok=True)
279
+
280
+ if args.push_to_hub:
281
+ repo_id = create_repo(
282
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
283
+ exist_ok=True,
284
+ ).repo_id
285
+
286
+ # Prepare models and scheduler
287
+ tokenizer = AutoTokenizer.from_pretrained(
288
+ args.pretrained_model_name_or_path,
289
+ subfolder="tokenizer",
290
+ revision=args.revision,
291
+ )
292
+
293
+ text_encoder = T5EncoderModel.from_pretrained(
294
+ args.pretrained_model_name_or_path,
295
+ subfolder="text_encoder",
296
+ revision=args.revision,
297
+ )
298
+
299
+ # CogVideoX-2b weights are stored in float16
300
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
301
+ load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
302
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
303
+ args.pretrained_model_name_or_path,
304
+ subfolder="transformer",
305
+ torch_dtype=load_dtype,
306
+ revision=args.revision,
307
+ variant=args.variant,
308
+ )
309
+
310
+ vae = AutoencoderKLCogVideoX.from_pretrained(
311
+ args.pretrained_model_name_or_path,
312
+ subfolder="vae",
313
+ revision=args.revision,
314
+ variant=args.variant,
315
+ )
316
+
317
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
318
+
319
+ if args.enable_slicing:
320
+ vae.enable_slicing()
321
+ if args.enable_tiling:
322
+ vae.enable_tiling()
323
+
324
+ # We only train the additional adapter LoRA layers
325
+ text_encoder.requires_grad_(False)
326
+ transformer.requires_grad_(False)
327
+ vae.requires_grad_(False)
328
+
329
+ VAE_SCALING_FACTOR = vae.config.scaling_factor
330
+ VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
331
+ RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
332
+ RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
333
+
334
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
335
+ # as these weights are only used for inference, keeping weights in full precision is not required.
336
+ weight_dtype = torch.float32
337
+ if accelerator.state.deepspeed_plugin:
338
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
339
+ if (
340
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
341
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
342
+ ):
343
+ weight_dtype = torch.float16
344
+ if (
345
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
346
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
347
+ ):
348
+ weight_dtype = torch.bfloat16
349
+ else:
350
+ if accelerator.mixed_precision == "fp16":
351
+ weight_dtype = torch.float16
352
+ elif accelerator.mixed_precision == "bf16":
353
+ weight_dtype = torch.bfloat16
354
+
355
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
356
+ # due to pytorch#99272, MPS does not yet support bfloat16.
357
+ raise ValueError(
358
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
359
+ )
360
+
361
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
362
+ transformer.to(accelerator.device, dtype=weight_dtype)
363
+ vae.to(accelerator.device, dtype=weight_dtype)
364
+
365
+ if args.gradient_checkpointing:
366
+ transformer.enable_gradient_checkpointing()
367
+
368
+ # now we will add new LoRA weights to the attention layers
369
+ transformer_lora_config = LoraConfig(
370
+ r=args.rank,
371
+ lora_alpha=args.lora_alpha,
372
+ init_lora_weights=True,
373
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
374
+ )
375
+ transformer.add_adapter(transformer_lora_config)
376
+
377
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
378
+ def save_model_hook(models, weights, output_dir):
379
+ if accelerator.is_main_process:
380
+ transformer_lora_layers_to_save = None
381
+
382
+ for model in models:
383
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
384
+ model = unwrap_model(accelerator, model)
385
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
386
+ else:
387
+ raise ValueError(f"unexpected save model: {model.__class__}")
388
+
389
+ # make sure to pop weight so that corresponding model is not saved again
390
+ if weights:
391
+ weights.pop()
392
+
393
+ CogVideoXPipeline.save_lora_weights(
394
+ output_dir,
395
+ transformer_lora_layers=transformer_lora_layers_to_save,
396
+ )
397
+
398
+ def load_model_hook(models, input_dir):
399
+ transformer_ = None
400
+
401
+ # This is a bit of a hack but I don't know any other solution.
402
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
403
+ while len(models) > 0:
404
+ model = models.pop()
405
+
406
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
407
+ transformer_ = unwrap_model(accelerator, model)
408
+ else:
409
+ raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
410
+ else:
411
+ transformer_ = CogVideoXTransformer3DModel.from_pretrained(
412
+ args.pretrained_model_name_or_path, subfolder="transformer"
413
+ )
414
+ transformer_.add_adapter(transformer_lora_config)
415
+
416
+ lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
417
+
418
+ transformer_state_dict = {
419
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
420
+ }
421
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
422
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
423
+ if incompatible_keys is not None:
424
+ # check only for unexpected keys
425
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
426
+ if unexpected_keys:
427
+ logger.warning(
428
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
429
+ f" {unexpected_keys}. "
430
+ )
431
+
432
+ # Make sure the trainable params are in float32. This is again needed since the base models
433
+ # are in `weight_dtype`. More details:
434
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
435
+ if args.mixed_precision == "fp16":
436
+ # only upcast trainable parameters (LoRA) into fp32
437
+ cast_training_params([transformer_])
438
+
439
+ accelerator.register_save_state_pre_hook(save_model_hook)
440
+ accelerator.register_load_state_pre_hook(load_model_hook)
441
+
442
+ # Enable TF32 for faster training on Ampere GPUs,
443
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
444
+ if args.allow_tf32 and torch.cuda.is_available():
445
+ torch.backends.cuda.matmul.allow_tf32 = True
446
+
447
+ if args.scale_lr:
448
+ args.learning_rate = (
449
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
450
+ )
451
+
452
+ # Make sure the trainable params are in float32.
453
+ if args.mixed_precision == "fp16":
454
+ # only upcast trainable parameters (LoRA) into fp32
455
+ cast_training_params([transformer], dtype=torch.float32)
456
+
457
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
458
+
459
+ # Optimization parameters
460
+ transformer_parameters_with_lr = {
461
+ "params": transformer_lora_parameters,
462
+ "lr": args.learning_rate,
463
+ }
464
+ params_to_optimize = [transformer_parameters_with_lr]
465
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
466
+
467
+ use_deepspeed_optimizer = (
468
+ accelerator.state.deepspeed_plugin is not None
469
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
470
+ )
471
+ use_deepspeed_scheduler = (
472
+ accelerator.state.deepspeed_plugin is not None
473
+ and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
474
+ )
475
+
476
+ optimizer = get_optimizer(
477
+ params_to_optimize=params_to_optimize,
478
+ optimizer_name=args.optimizer,
479
+ learning_rate=args.learning_rate,
480
+ beta1=args.beta1,
481
+ beta2=args.beta2,
482
+ beta3=args.beta3,
483
+ epsilon=args.epsilon,
484
+ weight_decay=args.weight_decay,
485
+ prodigy_decouple=args.prodigy_decouple,
486
+ prodigy_use_bias_correction=args.prodigy_use_bias_correction,
487
+ prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
488
+ use_8bit=args.use_8bit,
489
+ use_4bit=args.use_4bit,
490
+ use_torchao=args.use_torchao,
491
+ use_deepspeed=use_deepspeed_optimizer,
492
+ use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
493
+ offload_gradients=args.offload_gradients,
494
+ )
495
+
496
+ # Dataset and DataLoader
497
+ dataset_init_kwargs = {
498
+ "data_root": args.data_root,
499
+ "dataset_file": args.dataset_file,
500
+ "caption_column": args.caption_column,
501
+ "video_column": args.video_column,
502
+ "max_num_frames": args.max_num_frames,
503
+ "id_token": args.id_token,
504
+ "height_buckets": args.height_buckets,
505
+ "width_buckets": args.width_buckets,
506
+ "frame_buckets": args.frame_buckets,
507
+ "load_tensors": args.load_tensors,
508
+ "random_flip": args.random_flip,
509
+ }
510
+ if args.video_reshape_mode is None:
511
+ train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
512
+ else:
513
+ train_dataset = VideoDatasetWithResizeAndRectangleCrop(
514
+ video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
515
+ )
516
+
517
+ collate_fn = CollateFunction(weight_dtype, args.load_tensors)
518
+
519
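+ # The BucketSampler emits full same-resolution batches, so the DataLoader keeps batch_size=1
+ # and the collate function receives one bucket batch at a time.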
+ train_dataloader = DataLoader(
520
+ train_dataset,
521
+ batch_size=1,
522
+ sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
523
+ collate_fn=collate_fn,
524
+ num_workers=args.dataloader_num_workers,
525
+ pin_memory=args.pin_memory,
526
+ )
527
+
528
+ # Scheduler and math around the number of training steps.
529
+ overrode_max_train_steps = False
530
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
531
+ if args.max_train_steps is None:
532
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
533
+ overrode_max_train_steps = True
534
+
535
+ if args.use_cpu_offload_optimizer:
536
+ lr_scheduler = None
537
+ accelerator.print(
538
+ "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
539
+ "you are training with those settings, they will be ignored."
540
+ )
541
+ else:
542
+ if use_deepspeed_scheduler:
543
+ from accelerate.utils import DummyScheduler
544
+
545
+ lr_scheduler = DummyScheduler(
546
+ name=args.lr_scheduler,
547
+ optimizer=optimizer,
548
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
549
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
550
+ )
551
+ else:
552
+ lr_scheduler = get_scheduler(
553
+ args.lr_scheduler,
554
+ optimizer=optimizer,
555
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
556
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
557
+ num_cycles=args.lr_num_cycles,
558
+ power=args.lr_power,
559
+ )
560
+
561
+ # Prepare everything with our `accelerator`.
562
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
563
+ transformer, optimizer, train_dataloader, lr_scheduler
564
+ )
565
+
566
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
567
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
568
+ if overrode_max_train_steps:
569
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
570
+ # Afterwards we recalculate our number of training epochs
571
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
572
+
573
+ # We need to initialize the trackers we use, and also store our configuration.
574
+ # The trackers initialize automatically on the main process.
575
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
576
+ tracker_name = args.tracker_name or "cogvideox-lora"
577
+ accelerator.init_trackers(tracker_name, config=vars(args))
578
+
579
+ accelerator.print("===== Memory before training =====")
580
+ reset_memory(accelerator.device)
581
+ print_memory(accelerator.device)
582
+
583
+ # Train!
584
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
585
+
586
+ accelerator.print("***** Running training *****")
587
+ accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
588
+ accelerator.print(f" Num examples = {len(train_dataset)}")
589
+ accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
590
+ accelerator.print(f" Num epochs = {args.num_train_epochs}")
591
+ accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
592
+ accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
593
+ accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
594
+ accelerator.print(f" Total optimization steps = {args.max_train_steps}")
595
+ global_step = 0
596
+ first_epoch = 0
597
+
598
+ # Potentially load in the weights and states from a previous save
599
+ if not args.resume_from_checkpoint:
600
+ initial_global_step = 0
601
+ else:
602
+ if args.resume_from_checkpoint != "latest":
603
+ path = os.path.basename(args.resume_from_checkpoint)
604
+ else:
605
+ # Get the most recent checkpoint
606
+ dirs = os.listdir(args.output_dir)
607
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
608
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
609
+ path = dirs[-1] if len(dirs) > 0 else None
610
+
611
+ if path is None:
612
+ accelerator.print(
613
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
614
+ )
615
+ args.resume_from_checkpoint = None
616
+ initial_global_step = 0
617
+ else:
618
+ accelerator.print(f"Resuming from checkpoint {path}")
619
+ accelerator.load_state(os.path.join(args.output_dir, path))
620
+ global_step = int(path.split("-")[1])
621
+
622
+ initial_global_step = global_step
623
+ first_epoch = global_step // num_update_steps_per_epoch
624
+
625
+ progress_bar = tqdm(
626
+ range(0, args.max_train_steps),
627
+ initial=initial_global_step,
628
+ desc="Steps",
629
+ # Only show the progress bar once on each machine.
630
+ disable=not accelerator.is_local_main_process,
631
+ )
632
+
633
+ # For DeepSpeed training
634
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
635
+
636
+ if args.load_tensors:
637
+ del vae, text_encoder
638
+ gc.collect()
639
+ torch.cuda.empty_cache()
640
+ torch.cuda.synchronize(accelerator.device)
641
+
642
+ alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
643
+
644
+ for epoch in range(first_epoch, args.num_train_epochs):
645
+ transformer.train()
646
+
647
+ for step, batch in enumerate(train_dataloader):
648
+ models_to_accumulate = [transformer]
649
+ logs = {}
650
+
651
+ with accelerator.accumulate(models_to_accumulate):
652
+ videos = batch["videos"].to(accelerator.device, non_blocking=True)
653
+ prompts = batch["prompts"]
654
+
655
+ # Encode videos
656
+ if not args.load_tensors:
657
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
658
+ latent_dist = vae.encode(videos).latent_dist
659
+ else:
660
+ latent_dist = DiagonalGaussianDistribution(videos)
661
+
662
+ videos = latent_dist.sample() * VAE_SCALING_FACTOR
663
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
664
+ videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
665
+ model_input = videos
666
+
667
+ # Encode prompts
668
+ if not args.load_tensors:
669
+ prompt_embeds = compute_prompt_embeddings(
670
+ tokenizer,
671
+ text_encoder,
672
+ prompts,
673
+ model_config.max_text_seq_length,
674
+ accelerator.device,
675
+ weight_dtype,
676
+ requires_grad=False,
677
+ )
678
+ else:
679
+ prompt_embeds = prompts.to(dtype=weight_dtype)
680
+
681
+ # Sample noise that will be added to the latents
682
+ noise = torch.randn_like(model_input)
683
+ batch_size, num_frames, num_channels, height, width = model_input.shape
684
+
685
+ # Sample a random timestep for each image
686
+ timesteps = torch.randint(
687
+ 0,
688
+ scheduler.config.num_train_timesteps,
689
+ (batch_size,),
690
+ dtype=torch.int64,
691
+ device=model_input.device,
692
+ )
693
+
694
+ # Prepare rotary embeds
695
+ image_rotary_emb = (
696
+ prepare_rotary_positional_embeddings(
697
+ height=height * VAE_SCALE_FACTOR_SPATIAL,
698
+ width=width * VAE_SCALE_FACTOR_SPATIAL,
699
+ num_frames=num_frames,
700
+ vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
701
+ patch_size=model_config.patch_size,
702
+ patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
703
+ attention_head_dim=model_config.attention_head_dim,
704
+ device=accelerator.device,
705
+ base_height=RoPE_BASE_HEIGHT,
706
+ base_width=RoPE_BASE_WIDTH,
707
+ )
708
+ if model_config.use_rotary_positional_embeddings
709
+ else None
710
+ )
711
+
712
+ # Add noise to the model input according to the noise magnitude at each timestep
713
+ # (this is the forward diffusion process)
714
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
715
+
716
+ # Predict the noise residual
717
+ model_output = transformer(
718
+ hidden_states=noisy_model_input,
719
+ encoder_hidden_states=prompt_embeds,
720
+ timestep=timesteps,
721
+ image_rotary_emb=image_rotary_emb,
722
+ return_dict=False,
723
+ )[0]
724
+
725
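+ # get_velocity() effectively converts the v-prediction output into a prediction of the clean
+ # latents, which is compared against `model_input` below.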
+ model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
726
+
727
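+ # Per-timestep weighting 1 / (1 - alpha_bar_t) gives low-noise timesteps a larger weight.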
+ weights = 1 / (1 - alphas_cumprod[timesteps])
728
+ while len(weights.shape) < len(model_pred.shape):
729
+ weights = weights.unsqueeze(-1)
730
+
731
+ target = model_input
732
+
733
+ loss = torch.mean(
734
+ (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
735
+ dim=1,
736
+ )
737
+ loss = loss.mean()
738
+ accelerator.backward(loss)
739
+
740
+ if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
741
+ gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
742
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
743
+ gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
744
+ logs.update(
745
+ {
746
+ "gradient_norm_before_clip": gradient_norm_before_clip,
747
+ "gradient_norm_after_clip": gradient_norm_after_clip,
748
+ }
749
+ )
750
+
751
+ if accelerator.state.deepspeed_plugin is None:
752
+ optimizer.step()
753
+ optimizer.zero_grad()
754
+
755
+ if not args.use_cpu_offload_optimizer:
756
+ lr_scheduler.step()
757
+
758
+ # Checks if the accelerator has performed an optimization step behind the scenes
759
+ if accelerator.sync_gradients:
760
+ progress_bar.update(1)
761
+ global_step += 1
762
+
763
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
764
+ if global_step % args.checkpointing_steps == 0:
765
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
766
+ if args.checkpoints_total_limit is not None:
767
+ checkpoints = os.listdir(args.output_dir)
768
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
769
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
770
+
771
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
772
+ if len(checkpoints) >= args.checkpoints_total_limit:
773
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
774
+ removing_checkpoints = checkpoints[0:num_to_remove]
775
+
776
+ logger.info(
777
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
778
+ )
779
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
780
+
781
+ for removing_checkpoint in removing_checkpoints:
782
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
783
+ shutil.rmtree(removing_checkpoint)
784
+
785
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
786
+ accelerator.save_state(save_path)
787
+ logger.info(f"Saved state to {save_path}")
788
+
789
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
790
+ logs.update(
791
+ {
792
+ "loss": loss.detach().item(),
793
+ "lr": last_lr,
794
+ }
795
+ )
796
+ progress_bar.set_postfix(**logs)
797
+ accelerator.log(logs, step=global_step)
798
+
799
+ if global_step >= args.max_train_steps:
800
+ break
801
+
802
+ if accelerator.is_main_process:
803
+ if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
804
+ accelerator.print("===== Memory before validation =====")
805
+ print_memory(accelerator.device)
806
+ torch.cuda.synchronize(accelerator.device)
807
+
808
+ pipe = CogVideoXPipeline.from_pretrained(
809
+ args.pretrained_model_name_or_path,
810
+ transformer=unwrap_model(accelerator, transformer),
811
+ scheduler=scheduler,
812
+ revision=args.revision,
813
+ variant=args.variant,
814
+ torch_dtype=weight_dtype,
815
+ )
816
+
817
+ if args.enable_slicing:
818
+ pipe.vae.enable_slicing()
819
+ if args.enable_tiling:
820
+ pipe.vae.enable_tiling()
821
+ if args.enable_model_cpu_offload:
822
+ pipe.enable_model_cpu_offload()
823
+
824
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
825
+ for validation_prompt in validation_prompts:
826
+ pipeline_args = {
827
+ "prompt": validation_prompt,
828
+ "guidance_scale": args.guidance_scale,
829
+ "use_dynamic_cfg": args.use_dynamic_cfg,
830
+ "height": args.height,
831
+ "width": args.width,
832
+ "max_sequence_length": model_config.max_text_seq_length,
833
+ }
834
+
835
+ log_validation(
836
+ pipe=pipe,
837
+ args=args,
838
+ accelerator=accelerator,
839
+ pipeline_args=pipeline_args,
840
+ epoch=epoch,
841
+ )
842
+
843
+ accelerator.print("===== Memory after validation =====")
844
+ print_memory(accelerator.device)
845
+ reset_memory(accelerator.device)
846
+
847
+ del pipe
848
+ gc.collect()
849
+ torch.cuda.empty_cache()
850
+ torch.cuda.synchronize(accelerator.device)
851
+
852
+ accelerator.wait_for_everyone()
853
+
854
+ if accelerator.is_main_process:
855
+ transformer = unwrap_model(accelerator, transformer)
856
+ dtype = (
857
+ torch.float16
858
+ if args.mixed_precision == "fp16"
859
+ else torch.bfloat16
860
+ if args.mixed_precision == "bf16"
861
+ else torch.float32
862
+ )
863
+ transformer = transformer.to(dtype)
864
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
865
+
866
+ CogVideoXPipeline.save_lora_weights(
867
+ save_directory=args.output_dir,
868
+ transformer_lora_layers=transformer_lora_layers,
869
+ )
870
+
871
+ # Cleanup trained models to save memory
872
+ if args.load_tensors:
873
+ del transformer
874
+ else:
875
+ del transformer, text_encoder, vae
876
+
877
+ gc.collect()
878
+ torch.cuda.empty_cache()
879
+ torch.cuda.synchronize(accelerator.device)
880
+
881
+ accelerator.print("===== Memory before testing =====")
882
+ print_memory(accelerator.device)
883
+ reset_memory(accelerator.device)
884
+
885
+ # Final test inference
886
+ pipe = CogVideoXPipeline.from_pretrained(
887
+ args.pretrained_model_name_or_path,
888
+ revision=args.revision,
889
+ variant=args.variant,
890
+ torch_dtype=weight_dtype,
891
+ )
892
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
893
+
894
+ if args.enable_slicing:
895
+ pipe.vae.enable_slicing()
896
+ if args.enable_tiling:
897
+ pipe.vae.enable_tiling()
898
+ if args.enable_model_cpu_offload:
899
+ pipe.enable_model_cpu_offload()
900
+
901
+ # Load LoRA weights
902
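+ # Mirror the alpha / rank scaling used during training when activating the adapter.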
+ lora_scaling = args.lora_alpha / args.rank
903
+ pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
904
+ pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
905
+
906
+ # Run inference
907
+ validation_outputs = []
908
+ if args.validation_prompt and args.num_validation_videos > 0:
909
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
910
+ for validation_prompt in validation_prompts:
911
+ pipeline_args = {
912
+ "prompt": validation_prompt,
913
+ "guidance_scale": args.guidance_scale,
914
+ "use_dynamic_cfg": args.use_dynamic_cfg,
915
+ "height": args.height,
916
+ "width": args.width,
917
+ }
918
+
919
+ video = log_validation(
920
+ accelerator=accelerator,
921
+ pipe=pipe,
922
+ args=args,
923
+ pipeline_args=pipeline_args,
924
+ epoch=epoch,
925
+ is_final_validation=True,
926
+ )
927
+ validation_outputs.extend(video)
928
+
929
+ accelerator.print("===== Memory after testing =====")
930
+ print_memory(accelerator.device)
931
+ reset_memory(accelerator.device)
932
+ torch.cuda.synchronize(accelerator.device)
933
+
934
+ if args.push_to_hub:
935
+ save_model_card(
936
+ repo_id,
937
+ videos=validation_outputs,
938
+ base_model=args.pretrained_model_name_or_path,
939
+ validation_prompt=args.validation_prompt,
940
+ repo_folder=args.output_dir,
941
+ fps=args.fps,
942
+ )
943
+ upload_folder(
944
+ repo_id=repo_id,
945
+ folder_path=args.output_dir,
946
+ commit_message="End of training",
947
+ ignore_patterns=["step_*", "epoch_*"],
948
+ )
949
+
950
+ accelerator.end_training()
951
+
952
+
953
+ if __name__ == "__main__":
954
+ args = get_args()
955
+ main(args)
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py ADDED
@@ -0,0 +1,917 @@
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import logging
18
+ import math
19
+ import os
20
+ import shutil
21
+ from datetime import timedelta
22
+ from pathlib import Path
23
+ from typing import Any, Dict
24
+
25
+ import diffusers
26
+ import torch
27
+ import transformers
28
+ import wandb
29
+ from accelerate import Accelerator, DistributedType, init_empty_weights
30
+ from accelerate.logging import get_logger
31
+ from accelerate.utils import (
32
+ DistributedDataParallelKwargs,
33
+ InitProcessGroupKwargs,
34
+ ProjectConfiguration,
35
+ set_seed,
36
+ )
37
+ from diffusers import (
38
+ AutoencoderKLCogVideoX,
39
+ CogVideoXDPMScheduler,
40
+ CogVideoXPipeline,
41
+ CogVideoXTransformer3DModel,
42
+ )
43
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import cast_training_params
46
+ from diffusers.utils import export_to_video
47
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
48
+ from huggingface_hub import create_repo, upload_folder
49
+ from torch.utils.data import DataLoader
50
+ from tqdm.auto import tqdm
51
+ from transformers import AutoTokenizer, T5EncoderModel
52
+
53
+
54
+ from args import get_args # isort:skip
55
+ from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
56
+ from text_encoder import compute_prompt_embeddings # isort:skip
57
+ from utils import (
58
+ get_gradient_norm,
59
+ get_optimizer,
60
+ prepare_rotary_positional_embeddings,
61
+ print_memory,
62
+ reset_memory,
63
+ unwrap_model,
64
+ ) # isort:skip
65
+
66
+
67
+ logger = get_logger(__name__)
68
+
69
+
70
+ def save_model_card(
71
+ repo_id: str,
72
+ videos=None,
73
+ base_model: str = None,
74
+ validation_prompt=None,
75
+ repo_folder=None,
76
+ fps=8,
77
+ ):
78
+ widget_dict = []
79
+ if videos is not None:
80
+ for i, video in enumerate(videos):
81
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
82
+ widget_dict.append(
83
+ {
84
+ "text": validation_prompt if validation_prompt else " ",
85
+ "output": {"url": f"final_video_{i}.mp4"},
86
+ }
87
+ )
88
+
89
+ model_description = f"""
90
+ # CogVideoX Full Finetune
91
+
92
+ <Gallery />
93
+
94
+ ## Model description
95
+
96
+ This is a full finetune of the CogVideoX model `{base_model}`.
97
+
98
+ The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adapted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
99
+
100
+ ## Download model
101
+
102
+ [Download]({repo_id}/tree/main) the model weights in the Files & Versions tab.
103
+
104
+ ## Usage
105
+
106
+ Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) to be installed.
107
+
108
+ ```py
109
+ import torch
110
+ from diffusers import CogVideoXPipeline
111
+ from diffusers.utils import export_to_video
112
+
113
+ pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda")
114
+
115
+ video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
116
+ export_to_video(video, "output.mp4", fps=8)
117
+ ```
118
+
119
+ For more details, check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX.
120
+
121
+ ## License
122
+
123
+ Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
124
+ """
125
+ model_card = load_or_create_model_card(
126
+ repo_id_or_path=repo_id,
127
+ from_training=True,
128
+ license="other",
129
+ base_model=base_model,
130
+ prompt=validation_prompt,
131
+ model_description=model_description,
132
+ widget=widget_dict,
133
+ )
134
+ tags = [
135
+ "text-to-video",
136
+ "diffusers-training",
137
+ "diffusers",
138
+ "cogvideox",
139
+ "cogvideox-diffusers",
140
+ ]
141
+
142
+ model_card = populate_model_card(model_card, tags=tags)
143
+ model_card.save(os.path.join(repo_folder, "README.md"))
144
+
145
+
146
+ def log_validation(
147
+ accelerator: Accelerator,
148
+ pipe: CogVideoXPipeline,
149
+ args: Dict[str, Any],
150
+ pipeline_args: Dict[str, Any],
151
+ epoch,
152
+ is_final_validation: bool = False,
153
+ ):
154
+ logger.info(
155
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
156
+ )
157
+
158
+ pipe = pipe.to(accelerator.device)
159
+
160
+ # run inference
161
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
162
+
163
+ videos = []
164
+ for _ in range(args.num_validation_videos):
165
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
166
+ videos.append(video)
167
+
168
+ for tracker in accelerator.trackers:
169
+ phase_name = "test" if is_final_validation else "validation"
170
+ if tracker.name == "wandb":
171
+ video_filenames = []
172
+ for i, video in enumerate(videos):
173
+ prompt = (
174
+ pipeline_args["prompt"][:25]
175
+ .replace(" ", "_")
177
+ .replace("'", "_")
178
+ .replace('"', "_")
179
+ .replace("/", "_")
180
+ )
181
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
182
+ export_to_video(video, filename, fps=8)
183
+ video_filenames.append(filename)
184
+
185
+ tracker.log(
186
+ {
187
+ phase_name: [
188
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
189
+ for i, filename in enumerate(video_filenames)
190
+ ]
191
+ }
192
+ )
193
+
194
+ return videos
195
+
196
+
197
+ class CollateFunction:
198
+ def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
199
+ self.weight_dtype = weight_dtype
200
+ self.load_tensors = load_tensors
201
+
202
+ def __call__(self, data: List[List[Dict[str, Any]]]) -> Dict[str, torch.Tensor]:
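+ # `data` arrives as a one-element list: the BucketSampler yields a whole bucket of samples
+ # at once, which the DataLoader (created with batch_size=1) wraps in a list, so `data[0]` is
+ # the actual batch. When `load_tensors` is set, `prompt` holds precomputed text embeddings
+ # (tensors) rather than raw strings, which is why it can be stacked like the videos.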
203
+ prompts = [x["prompt"] for x in data[0]]
204
+
205
+ if self.load_tensors:
206
+ prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
207
+
208
+ videos = [x["video"] for x in data[0]]
209
+ videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
210
+
211
+ return {
212
+ "videos": videos,
213
+ "prompts": prompts,
214
+ }
215
+
216
+
217
+ def main(args):
218
+ if args.report_to == "wandb" and args.hub_token is not None:
219
+ raise ValueError(
220
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
221
+ " Please use `huggingface-cli login` to authenticate with the Hub."
222
+ )
223
+
224
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
225
+ # due to pytorch#99272, MPS does not yet support bfloat16.
226
+ raise ValueError(
227
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
228
+ )
229
+
230
+ logging_dir = Path(args.output_dir, args.logging_dir)
231
+
232
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
233
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
234
+ init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
235
+ accelerator = Accelerator(
236
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
237
+ mixed_precision=args.mixed_precision,
238
+ log_with=args.report_to,
239
+ project_config=accelerator_project_config,
240
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
241
+ )
242
+
243
+ # Disable AMP for MPS.
244
+ if torch.backends.mps.is_available():
245
+ accelerator.native_amp = False
246
+
247
+ # Make one log on every process with the configuration for debugging.
248
+ logging.basicConfig(
249
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
250
+ datefmt="%m/%d/%Y %H:%M:%S",
251
+ level=logging.INFO,
252
+ )
253
+ logger.info(accelerator.state, main_process_only=False)
254
+ if accelerator.is_local_main_process:
255
+ transformers.utils.logging.set_verbosity_warning()
256
+ diffusers.utils.logging.set_verbosity_info()
257
+ else:
258
+ transformers.utils.logging.set_verbosity_error()
259
+ diffusers.utils.logging.set_verbosity_error()
260
+
261
+ # If passed along, set the training seed now.
262
+ if args.seed is not None:
263
+ set_seed(args.seed)
264
+
265
+ # Handle the repository creation
266
+ if accelerator.is_main_process:
267
+ if args.output_dir is not None:
268
+ os.makedirs(args.output_dir, exist_ok=True)
269
+
270
+ if args.push_to_hub:
271
+ repo_id = create_repo(
272
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
273
+ exist_ok=True,
274
+ ).repo_id
275
+
276
+ # Prepare models and scheduler
277
+ tokenizer = AutoTokenizer.from_pretrained(
278
+ args.pretrained_model_name_or_path,
279
+ subfolder="tokenizer",
280
+ revision=args.revision,
281
+ )
282
+
283
+ text_encoder = T5EncoderModel.from_pretrained(
284
+ args.pretrained_model_name_or_path,
285
+ subfolder="text_encoder",
286
+ revision=args.revision,
287
+ )
288
+
289
+ # CogVideoX-2b weights are stored in float16
290
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
291
+ load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
292
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
293
+ args.pretrained_model_name_or_path,
294
+ subfolder="transformer",
295
+ torch_dtype=load_dtype,
296
+ revision=args.revision,
297
+ variant=args.variant,
298
+ )
299
+
300
+ vae = AutoencoderKLCogVideoX.from_pretrained(
301
+ args.pretrained_model_name_or_path,
302
+ subfolder="vae",
303
+ revision=args.revision,
304
+ variant=args.variant,
305
+ )
306
+
307
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
308
+
309
+ if args.enable_slicing:
310
+ vae.enable_slicing()
311
+ if args.enable_tiling:
312
+ vae.enable_tiling()
313
+
314
+ text_encoder.requires_grad_(False)
315
+ vae.requires_grad_(False)
316
+ transformer.requires_grad_(True)
317
+
318
+ VAE_SCALING_FACTOR = vae.config.scaling_factor
319
+ VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
320
+ RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
321
+ RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
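+ # Note: for the released CogVideoX checkpoints, the VAE has four `block_out_channels` entries,
+ # so the spatial scale factor evaluates to 2 ** 3 = 8. The RoPE base height/width map the
+ # transformer's native latent sample size back to pixel space.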
322
+
323
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
324
+ # as these weights are only used for inference, keeping weights in full precision is not required.
325
+ weight_dtype = torch.float32
326
+ if accelerator.state.deepspeed_plugin:
327
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
328
+ if (
329
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
330
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
331
+ ):
332
+ weight_dtype = torch.float16
333
+ if (
334
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
335
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
336
+ ):
337
+ weight_dtype = torch.bfloat16
338
+ else:
339
+ if accelerator.mixed_precision == "fp16":
340
+ weight_dtype = torch.float16
341
+ elif accelerator.mixed_precision == "bf16":
342
+ weight_dtype = torch.bfloat16
343
+
344
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
345
+ # due to pytorch#99272, MPS does not yet support bfloat16.
346
+ raise ValueError(
347
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
348
+ )
349
+
350
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
351
+ transformer.to(accelerator.device, dtype=weight_dtype)
352
+ vae.to(accelerator.device, dtype=weight_dtype)
353
+
354
+ if args.gradient_checkpointing:
355
+ transformer.enable_gradient_checkpointing()
356
+
357
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
358
+ def save_model_hook(models, weights, output_dir):
359
+ if accelerator.is_main_process:
360
+ for model in models:
361
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
362
+ model: CogVideoXTransformer3DModel
363
+ model = unwrap_model(accelerator, model)
364
+ model.save_pretrained(
365
+ os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
366
+ )
367
+ else:
368
+ raise ValueError(f"Unexpected save model: {model.__class__}")
369
+
370
+ # make sure to pop weight so that corresponding model is not saved again
371
+ if weights:
372
+ weights.pop()
373
+
374
+ def load_model_hook(models, input_dir):
375
+ transformer_ = None
376
+ init_under_meta = False
377
+
378
+ # This is a bit of a hack but I don't know any other solution.
379
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
380
+ while len(models) > 0:
381
+ model = models.pop()
382
+
383
+ if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
384
+ transformer_ = unwrap_model(accelerator, model)
385
+ else:
386
+ raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
387
+ else:
388
+ with init_empty_weights():
389
+ transformer_ = CogVideoXTransformer3DModel.from_config(
390
+ args.pretrained_model_name_or_path, subfolder="transformer"
391
+ )
392
+ init_under_meta = True
393
+
394
+ load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
395
+ transformer_.register_to_config(**load_model.config)
396
+ transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
397
+ del load_model
398
+
399
+ # Make sure the trainable params are in float32. This is again needed since the base models
400
+ # are in `weight_dtype`. More details:
401
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
402
+ if args.mixed_precision == "fp16":
403
+ cast_training_params([transformer_])
404
+
405
+ accelerator.register_save_state_pre_hook(save_model_hook)
406
+ accelerator.register_load_state_pre_hook(load_model_hook)
407
+
408
+ # Enable TF32 for faster training on Ampere GPUs,
409
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
410
+ if args.allow_tf32 and torch.cuda.is_available():
411
+ torch.backends.cuda.matmul.allow_tf32 = True
412
+
413
+ if args.scale_lr:
414
+ args.learning_rate = (
415
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
416
+ )
417
+
418
+ # Make sure the trainable params are in float32.
419
+ if args.mixed_precision == "fp16":
420
+ # only upcast trainable parameters (LoRA) into fp32
421
+ cast_training_params([transformer], dtype=torch.float32)
422
+
423
+ transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
424
+
425
+ # Optimization parameters
426
+ transformer_parameters_with_lr = {
427
+ "params": transformer_parameters,
428
+ "lr": args.learning_rate,
429
+ }
430
+ params_to_optimize = [transformer_parameters_with_lr]
431
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
432
+
433
+ use_deepspeed_optimizer = (
434
+ accelerator.state.deepspeed_plugin is not None
435
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
436
+ )
437
+ use_deepspeed_scheduler = (
438
+ accelerator.state.deepspeed_plugin is not None
439
+ and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
440
+ )
441
+
442
+ optimizer = get_optimizer(
443
+ params_to_optimize=params_to_optimize,
444
+ optimizer_name=args.optimizer,
445
+ learning_rate=args.learning_rate,
446
+ beta1=args.beta1,
447
+ beta2=args.beta2,
448
+ beta3=args.beta3,
449
+ epsilon=args.epsilon,
450
+ weight_decay=args.weight_decay,
451
+ prodigy_decouple=args.prodigy_decouple,
452
+ prodigy_use_bias_correction=args.prodigy_use_bias_correction,
453
+ prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
454
+ use_8bit=args.use_8bit,
455
+ use_4bit=args.use_4bit,
456
+ use_torchao=args.use_torchao,
457
+ use_deepspeed=use_deepspeed_optimizer,
458
+ use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
459
+ offload_gradients=args.offload_gradients,
460
+ )
461
+
462
+ # Dataset and DataLoader
463
+ dataset_init_kwargs = {
464
+ "data_root": args.data_root,
465
+ "dataset_file": args.dataset_file,
466
+ "caption_column": args.caption_column,
467
+ "video_column": args.video_column,
468
+ "max_num_frames": args.max_num_frames,
469
+ "id_token": args.id_token,
470
+ "height_buckets": args.height_buckets,
471
+ "width_buckets": args.width_buckets,
472
+ "frame_buckets": args.frame_buckets,
473
+ "load_tensors": args.load_tensors,
474
+ "random_flip": args.random_flip,
475
+ }
476
+ if args.video_reshape_mode is None:
477
+ train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
478
+ else:
479
+ train_dataset = VideoDatasetWithResizeAndRectangleCrop(
480
+ video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
481
+ )
482
+
483
+ collate_fn = CollateFunction(weight_dtype, args.load_tensors)
484
+
485
+ train_dataloader = DataLoader(
486
+ train_dataset,
487
+ batch_size=1,
488
+ sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
489
+ collate_fn=collate_fn,
490
+ num_workers=args.dataloader_num_workers,
491
+ pin_memory=args.pin_memory,
492
+ )
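+ # `batch_size=1` is deliberate here: each index yielded by the BucketSampler is already a
+ # complete bucket of `args.train_batch_size` samples with matching (frames, height, width).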
493
+
494
+ # Scheduler and math around the number of training steps.
495
+ overrode_max_train_steps = False
496
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
497
+ if args.max_train_steps is None:
498
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
499
+ overrode_max_train_steps = True
500
+
501
+ if args.use_cpu_offload_optimizer:
502
+ lr_scheduler = None
503
+ accelerator.print(
504
+ "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
505
+ "you are training with those settings, they will be ignored."
506
+ )
507
+ else:
508
+ if use_deepspeed_scheduler:
509
+ from accelerate.utils import DummyScheduler
510
+
511
+ lr_scheduler = DummyScheduler(
512
+ name=args.lr_scheduler,
513
+ optimizer=optimizer,
514
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
515
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
516
+ )
517
+ else:
518
+ lr_scheduler = get_scheduler(
519
+ args.lr_scheduler,
520
+ optimizer=optimizer,
521
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
522
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
523
+ num_cycles=args.lr_num_cycles,
524
+ power=args.lr_power,
525
+ )
526
+
527
+ # Prepare everything with our `accelerator`.
528
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
529
+ transformer, optimizer, train_dataloader, lr_scheduler
530
+ )
531
+
532
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
533
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
534
+ if overrode_max_train_steps:
535
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
536
+ # Afterwards we recalculate our number of training epochs
537
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
538
+
539
+ # We need to initialize the trackers we use, and also store our configuration.
540
+ # The trackers initialize automatically on the main process.
541
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
542
+ tracker_name = args.tracker_name or "cogvideox-sft"
543
+ accelerator.init_trackers(tracker_name, config=vars(args))
544
+
545
+ accelerator.print("===== Memory before training =====")
546
+ reset_memory(accelerator.device)
547
+ print_memory(accelerator.device)
548
+
549
+ # Train!
550
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
551
+
552
+ accelerator.print("***** Running training *****")
553
+ accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
554
+ accelerator.print(f" Num examples = {len(train_dataset)}")
555
+ accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
556
+ accelerator.print(f" Num epochs = {args.num_train_epochs}")
557
+ accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
558
+ accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
559
+ accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
560
+ accelerator.print(f" Total optimization steps = {args.max_train_steps}")
561
+ global_step = 0
562
+ first_epoch = 0
563
+
564
+ # Potentially load in the weights and states from a previous save
565
+ if not args.resume_from_checkpoint:
566
+ initial_global_step = 0
567
+ else:
568
+ if args.resume_from_checkpoint != "latest":
569
+ path = os.path.basename(args.resume_from_checkpoint)
570
+ else:
571
+ # Get the most recent checkpoint
572
+ dirs = os.listdir(args.output_dir)
573
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
574
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
575
+ path = dirs[-1] if len(dirs) > 0 else None
576
+
577
+ if path is None:
578
+ accelerator.print(
579
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
580
+ )
581
+ args.resume_from_checkpoint = None
582
+ initial_global_step = 0
583
+ else:
584
+ accelerator.print(f"Resuming from checkpoint {path}")
585
+ accelerator.load_state(os.path.join(args.output_dir, path))
586
+ global_step = int(path.split("-")[1])
587
+
588
+ initial_global_step = global_step
589
+ first_epoch = global_step // num_update_steps_per_epoch
590
+
591
+ progress_bar = tqdm(
592
+ range(0, args.max_train_steps),
593
+ initial=initial_global_step,
594
+ desc="Steps",
595
+ # Only show the progress bar once on each machine.
596
+ disable=not accelerator.is_local_main_process,
597
+ )
598
+
599
+ # For DeepSpeed training
600
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
601
+
602
+ if args.load_tensors:
603
+ del vae, text_encoder
604
+ gc.collect()
605
+ torch.cuda.empty_cache()
606
+ torch.cuda.synchronize(accelerator.device)
607
+
608
+ alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
609
+
610
+ for epoch in range(first_epoch, args.num_train_epochs):
611
+ transformer.train()
612
+
613
+ for step, batch in enumerate(train_dataloader):
614
+ models_to_accumulate = [transformer]
615
+ logs = {}
616
+
617
+ with accelerator.accumulate(models_to_accumulate):
618
+ videos = batch["videos"].to(accelerator.device, non_blocking=True)
619
+ prompts = batch["prompts"]
620
+
621
+ # Encode videos
622
+ if not args.load_tensors:
623
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
624
+ latent_dist = vae.encode(videos).latent_dist
625
+ else:
626
+ latent_dist = DiagonalGaussianDistribution(videos)
627
+
628
+ videos = latent_dist.sample() * VAE_SCALING_FACTOR
629
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
630
+ videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
631
+ model_input = videos
632
+
633
+ # Encode prompts
634
+ if not args.load_tensors:
635
+ prompt_embeds = compute_prompt_embeddings(
636
+ tokenizer,
637
+ text_encoder,
638
+ prompts,
639
+ model_config.max_text_seq_length,
640
+ accelerator.device,
641
+ weight_dtype,
642
+ requires_grad=False,
643
+ )
644
+ else:
645
+ prompt_embeds = prompts.to(dtype=weight_dtype)
646
+
647
+ # Sample noise that will be added to the latents
648
+ noise = torch.randn_like(model_input)
649
+ batch_size, num_frames, num_channels, height, width = model_input.shape
650
+
651
+ # Sample a random timestep for each image
652
+ timesteps = torch.randint(
653
+ 0,
654
+ scheduler.config.num_train_timesteps,
655
+ (batch_size,),
656
+ dtype=torch.int64,
657
+ device=model_input.device,
658
+ )
659
+
660
+ # Prepare rotary embeds
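+ # Rotary embeddings are only built when the transformer config enables them; of the released
+ # checkpoints, CogVideoX-5b uses RoPE while CogVideoX-2b does not (image_rotary_emb stays None).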
661
+ image_rotary_emb = (
662
+ prepare_rotary_positional_embeddings(
663
+ height=height * VAE_SCALE_FACTOR_SPATIAL,
664
+ width=width * VAE_SCALE_FACTOR_SPATIAL,
665
+ num_frames=num_frames,
666
+ vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
667
+ patch_size=model_config.patch_size,
668
+ patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
669
+ attention_head_dim=model_config.attention_head_dim,
670
+ device=accelerator.device,
671
+ base_height=RoPE_BASE_HEIGHT,
672
+ base_width=RoPE_BASE_WIDTH,
673
+ )
674
+ if model_config.use_rotary_positional_embeddings
675
+ else None
676
+ )
677
+
678
+ # Add noise to the model input according to the noise magnitude at each timestep
679
+ # (this is the forward diffusion process)
680
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
681
+
682
+ # Predict the noise residual
683
+ model_output = transformer(
684
+ hidden_states=noisy_model_input,
685
+ encoder_hidden_states=prompt_embeds,
686
+ timestep=timesteps,
687
+ image_rotary_emb=image_rotary_emb,
688
+ return_dict=False,
689
+ )[0]
690
+
691
+ model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
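+ # The transformer is trained with v-prediction: `get_velocity` combines the predicted velocity
+ # with the noisy latents to reconstruct an estimate of the clean latents, which is compared
+ # against `target = model_input` below. The 1 / (1 - alpha_cumprod) factor acts as a
+ # timestep-dependent (SNR-style) weighting of the loss.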
692
+
693
+ weights = 1 / (1 - alphas_cumprod[timesteps])
694
+ while len(weights.shape) < len(model_pred.shape):
695
+ weights = weights.unsqueeze(-1)
696
+
697
+ target = model_input
698
+
699
+ loss = torch.mean(
700
+ (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
701
+ dim=1,
702
+ )
703
+ loss = loss.mean()
704
+ accelerator.backward(loss)
705
+
706
+ if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
707
+ gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
708
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
709
+ gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
710
+ logs.update(
711
+ {
712
+ "gradient_norm_before_clip": gradient_norm_before_clip,
713
+ "gradient_norm_after_clip": gradient_norm_after_clip,
714
+ }
715
+ )
716
+
717
+ if accelerator.state.deepspeed_plugin is None:
718
+ optimizer.step()
719
+ optimizer.zero_grad()
720
+
721
+ if not args.use_cpu_offload_optimizer:
722
+ lr_scheduler.step()
723
+
724
+ # Checks if the accelerator has performed an optimization step behind the scenes
725
+ if accelerator.sync_gradients:
726
+ progress_bar.update(1)
727
+ global_step += 1
728
+
729
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
730
+ if global_step % args.checkpointing_steps == 0:
731
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
732
+ if args.checkpoints_total_limit is not None:
733
+ checkpoints = os.listdir(args.output_dir)
734
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
735
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
736
+
737
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
738
+ if len(checkpoints) >= args.checkpoints_total_limit:
739
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
740
+ removing_checkpoints = checkpoints[0:num_to_remove]
741
+
742
+ logger.info(
743
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
744
+ )
745
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
746
+
747
+ for removing_checkpoint in removing_checkpoints:
748
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
749
+ shutil.rmtree(removing_checkpoint)
750
+
751
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
752
+ accelerator.save_state(save_path)
753
+ logger.info(f"Saved state to {save_path}")
754
+
755
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
756
+ logs.update(
757
+ {
758
+ "loss": loss.detach().item(),
759
+ "lr": last_lr,
760
+ }
761
+ )
762
+ progress_bar.set_postfix(**logs)
763
+ accelerator.log(logs, step=global_step)
764
+
765
+ if global_step >= args.max_train_steps:
766
+ break
767
+
768
+ if accelerator.is_main_process:
769
+ if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
770
+ accelerator.print("===== Memory before validation =====")
771
+ print_memory(accelerator.device)
772
+ torch.cuda.synchronize(accelerator.device)
773
+
774
+ pipe = CogVideoXPipeline.from_pretrained(
775
+ args.pretrained_model_name_or_path,
776
+ transformer=unwrap_model(accelerator, transformer),
777
+ scheduler=scheduler,
778
+ revision=args.revision,
779
+ variant=args.variant,
780
+ torch_dtype=weight_dtype,
781
+ )
782
+
783
+ if args.enable_slicing:
784
+ pipe.vae.enable_slicing()
785
+ if args.enable_tiling:
786
+ pipe.vae.enable_tiling()
787
+ if args.enable_model_cpu_offload:
788
+ pipe.enable_model_cpu_offload()
789
+
790
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
791
+ for validation_prompt in validation_prompts:
792
+ pipeline_args = {
793
+ "prompt": validation_prompt,
794
+ "guidance_scale": args.guidance_scale,
795
+ "use_dynamic_cfg": args.use_dynamic_cfg,
796
+ "height": args.height,
797
+ "width": args.width,
798
+ "max_sequence_length": model_config.max_text_seq_length,
799
+ }
800
+
801
+ log_validation(
802
+ accelerator=accelerator,
803
+ pipe=pipe,
804
+ args=args,
805
+ pipeline_args=pipeline_args,
806
+ epoch=epoch,
807
+ is_final_validation=False,
808
+ )
809
+
810
+ accelerator.print("===== Memory after validation =====")
811
+ print_memory(accelerator.device)
812
+ reset_memory(accelerator.device)
813
+
814
+ del pipe
815
+ gc.collect()
816
+ torch.cuda.empty_cache()
817
+ torch.cuda.synchronize(accelerator.device)
818
+
819
+ accelerator.wait_for_everyone()
820
+
821
+ if accelerator.is_main_process:
822
+ transformer = unwrap_model(accelerator, transformer)
823
+ dtype = (
824
+ torch.float16
825
+ if args.mixed_precision == "fp16"
826
+ else torch.bfloat16
827
+ if args.mixed_precision == "bf16"
828
+ else torch.float32
829
+ )
830
+ transformer = transformer.to(dtype)
831
+
832
+ transformer.save_pretrained(
833
+ os.path.join(args.output_dir, "transformer"),
834
+ safe_serialization=True,
835
+ max_shard_size="5GB",
836
+ )
837
+
838
+ # Cleanup trained models to save memory
839
+ if args.load_tensors:
840
+ del transformer
841
+ else:
842
+ del transformer, text_encoder, vae
843
+
844
+ gc.collect()
845
+ torch.cuda.empty_cache()
846
+ torch.cuda.synchronize(accelerator.device)
847
+
848
+ accelerator.print("===== Memory before testing =====")
849
+ print_memory(accelerator.device)
850
+ reset_memory(accelerator.device)
851
+
852
+ # Final test inference
853
+ pipe = CogVideoXPipeline.from_pretrained(
854
+ args.pretrained_model_name_or_path,
855
+ revision=args.revision,
856
+ variant=args.variant,
857
+ torch_dtype=weight_dtype,
858
+ )
859
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
860
+
861
+ if args.enable_slicing:
862
+ pipe.vae.enable_slicing()
863
+ if args.enable_tiling:
864
+ pipe.vae.enable_tiling()
865
+ if args.enable_model_cpu_offload:
866
+ pipe.enable_model_cpu_offload()
867
+
868
+ # Run inference
869
+ validation_outputs = []
870
+ if args.validation_prompt and args.num_validation_videos > 0:
871
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
872
+ for validation_prompt in validation_prompts:
873
+ pipeline_args = {
874
+ "prompt": validation_prompt,
875
+ "guidance_scale": args.guidance_scale,
876
+ "use_dynamic_cfg": args.use_dynamic_cfg,
877
+ "height": args.height,
878
+ "width": args.width,
879
+ }
880
+
881
+ video = log_validation(
882
+ accelerator=accelerator,
883
+ pipe=pipe,
884
+ args=args,
885
+ pipeline_args=pipeline_args,
886
+ epoch=epoch,
887
+ is_final_validation=True,
888
+ )
889
+ validation_outputs.extend(video)
890
+
891
+ accelerator.print("===== Memory after testing =====")
892
+ print_memory(accelerator.device)
893
+ reset_memory(accelerator.device)
894
+ torch.cuda.synchronize(accelerator.device)
895
+
896
+ if args.push_to_hub:
897
+ save_model_card(
898
+ repo_id,
899
+ videos=validation_outputs,
900
+ base_model=args.pretrained_model_name_or_path,
901
+ validation_prompt=args.validation_prompt,
902
+ repo_folder=args.output_dir,
903
+ fps=args.fps,
904
+ )
905
+ upload_folder(
906
+ repo_id=repo_id,
907
+ folder_path=args.output_dir,
908
+ commit_message="End of training",
909
+ ignore_patterns=["step_*", "epoch_*"],
910
+ )
911
+
912
+ accelerator.end_training()
913
+
914
+
915
+ if __name__ == "__main__":
916
+ args = get_args()
917
+ main(args)
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py ADDED
@@ -0,0 +1,428 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchvision.transforms as TT
9
+ from accelerate.logging import get_logger
10
+ from torch.utils.data import Dataset, Sampler
11
+ from torchvision import transforms
12
+ from torchvision.transforms import InterpolationMode
13
+ from torchvision.transforms.functional import resize
14
+
15
+
16
+ # Must be imported after torch because importing decord earlier can sometimes lead to a nasty segmentation fault or stack smashing error.
17
+ # Very few bug reports, but it happens. See the decord GitHub issues for more relevant information.
18
+ import decord # isort:skip
19
+
20
+ decord.bridge.set_bridge("torch")
21
+
22
+ logger = get_logger(__name__)
23
+
24
+ HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
25
+ WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
26
+ FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]
27
+
28
+
29
+ class VideoDataset(Dataset):
30
+ def __init__(
31
+ self,
32
+ data_root: str,
33
+ dataset_file: Optional[str] = None,
34
+ caption_column: str = "text",
35
+ video_column: str = "video",
36
+ max_num_frames: int = 49,
37
+ id_token: Optional[str] = None,
38
+ height_buckets: List[int] = None,
39
+ width_buckets: List[int] = None,
40
+ frame_buckets: List[int] = None,
41
+ load_tensors: bool = False,
42
+ random_flip: Optional[float] = None,
43
+ image_to_video: bool = False,
44
+ ) -> None:
45
+ super().__init__()
46
+
47
+ self.data_root = Path(data_root)
48
+ self.dataset_file = dataset_file
49
+ self.caption_column = caption_column
50
+ self.video_column = video_column
51
+ self.max_num_frames = max_num_frames
52
+ self.id_token = f"{id_token.strip()} " if id_token else ""
53
+ self.height_buckets = height_buckets or HEIGHT_BUCKETS
54
+ self.width_buckets = width_buckets or WIDTH_BUCKETS
55
+ self.frame_buckets = frame_buckets or FRAME_BUCKETS
56
+ self.load_tensors = load_tensors
57
+ self.random_flip = random_flip
58
+ self.image_to_video = image_to_video
59
+
60
+ self.resolutions = [
61
+ (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
62
+ ]
63
+
64
+ # Two methods of loading data are supported.
65
+ # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
66
+ # make use of other columns too, such as a motion score or aesthetic score, by modifying the
67
+ # logic in CSV processing.
68
+ # - Using two files containing line-separated captions and relative paths to videos.
69
+ # For a more detailed explanation about preparing dataset format, checkout the README.
70
+ if dataset_file is None:
71
+ (
72
+ self.prompts,
73
+ self.video_paths,
74
+ ) = self._load_dataset_from_local_path()
75
+ else:
76
+ (
77
+ self.prompts,
78
+ self.video_paths,
79
+ ) = self._load_dataset_from_csv()
80
+
81
+ if len(self.video_paths) != len(self.prompts):
82
+ raise ValueError(
83
+ f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
84
+ )
85
+
86
+ self.video_transforms = transforms.Compose(
87
+ [
88
+ transforms.RandomHorizontalFlip(random_flip)
89
+ if random_flip
90
+ else transforms.Lambda(self.identity_transform),
91
+ transforms.Lambda(self.scale_transform),
92
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
93
+ ]
94
+ )
95
+
96
+ @staticmethod
97
+ def identity_transform(x):
98
+ return x
99
+
100
+ @staticmethod
101
+ def scale_transform(x):
102
+ return x / 255.0
103
+
104
+ def __len__(self) -> int:
105
+ return len(self.video_paths)
106
+
107
+ def __getitem__(self, index: int) -> Dict[str, Any]:
108
+ if isinstance(index, list):
109
+ # Here, index is actually a list of data objects that we need to return.
110
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
111
+ # to have information about num_frames, height and width. Since this is not stored
112
+ # as metadata, we need to read the video to get this information. You could read this
113
+ # information without loading the full video in memory, but we do it anyway. In order
114
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
115
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
116
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
117
+ # that data is not loaded a second time. PRs are welcome for improvements.
118
+ return index
119
+
120
+ if self.load_tensors:
121
+ image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
122
+
123
+ # This is hardcoded for now.
124
+ # The VAE's temporal compression ratio is 4.
125
+ # The VAE's spatial compression ratio is 8.
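+ # Example: a latent with 13 frames decodes to (13 - 1) * 4 + 1 = 49 video frames, and a
+ # 60x90 latent corresponds to a 480x720 video.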
126
+ latent_num_frames = video_latents.size(1)
127
+ if latent_num_frames % 2 == 0:
128
+ num_frames = latent_num_frames * 4
129
+ else:
130
+ num_frames = (latent_num_frames - 1) * 4 + 1
131
+
132
+ height = video_latents.size(2) * 8
133
+ width = video_latents.size(3) * 8
134
+
135
+ return {
136
+ "prompt": prompt_embeds,
137
+ "image": image_latents,
138
+ "video": video_latents,
139
+ "video_metadata": {
140
+ "num_frames": num_frames,
141
+ "height": height,
142
+ "width": width,
143
+ },
144
+ }
145
+ else:
146
+ image, video, _ = self._preprocess_video(self.video_paths[index])
147
+
148
+ return {
149
+ "prompt": self.id_token + self.prompts[index],
150
+ "image": image,
151
+ "video": video,
152
+ "video_metadata": {
153
+ "num_frames": video.shape[0],
154
+ "height": video.shape[2],
155
+ "width": video.shape[3],
156
+ },
157
+ }
158
+
159
+ def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
160
+ if not self.data_root.exists():
161
+ raise ValueError("Root folder for videos does not exist")
162
+
163
+ prompt_path = self.data_root.joinpath(self.caption_column)
164
+ video_path = self.data_root.joinpath(self.video_column)
165
+
166
+ if not prompt_path.exists() or not prompt_path.is_file():
167
+ raise ValueError(
168
+ "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
169
+ )
170
+ if not video_path.exists() or not video_path.is_file():
171
+ raise ValueError(
172
+ "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
173
+ )
174
+
175
+ with open(prompt_path, "r", encoding="utf-8") as file:
176
+ prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
177
+ with open(video_path, "r", encoding="utf-8") as file:
178
+ video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
179
+
180
+ if not self.load_tensors and any(not path.is_file() for path in video_paths):
181
+ raise ValueError(
182
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
183
+ )
184
+
185
+ return prompts, video_paths
186
+
187
+ def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
188
+ df = pd.read_csv(self.dataset_file)
189
+ prompts = df[self.caption_column].tolist()
190
+ video_paths = df[self.video_column].tolist()
191
+ video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
192
+
193
+ if any(not path.is_file() for path in video_paths):
194
+ raise ValueError(
195
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
196
+ )
197
+
198
+ return prompts, video_paths
199
+
200
+ def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
201
+ r"""
202
+ Loads a single video, or latent and prompt embedding, based on initialization parameters.
203
+
204
+ If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here,
205
+ F, C, H and W are the frames, channels, height and width of the input video.
206
+
207
+ If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D].
208
+ F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length
209
+ and embedding dimension of prompt embeddings.
210
+ """
211
+ if self.load_tensors:
212
+ return self._load_preprocessed_latents_and_embeds(path)
213
+ else:
214
+ video_reader = decord.VideoReader(uri=path.as_posix())
215
+ video_num_frames = len(video_reader)
216
+
217
+ indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
218
+ frames = video_reader.get_batch(indices)
219
+ frames = frames[: self.max_num_frames].float()
220
+ frames = frames.permute(0, 3, 1, 2).contiguous()
221
+ frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
222
+
223
+ image = frames[:1].clone() if self.image_to_video else None
224
+
225
+ return image, frames, None
226
+
227
+ def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
228
+ filename_without_ext = path.name.split(".")[0]
229
+ pt_filename = f"{filename_without_ext}.pt"
230
+
231
+ # The current path is something like: /a/b/c/d/videos/00001.mp4
232
+ # We need to reach: /a/b/c/d/video_latents/00001.pt
233
+ image_latents_path = path.parent.parent.joinpath("image_latents")
234
+ video_latents_path = path.parent.parent.joinpath("video_latents")
235
+ embeds_path = path.parent.parent.joinpath("prompt_embeds")
236
+
237
+ if (
238
+ not video_latents_path.exists()
239
+ or not embeds_path.exists()
240
+ or (self.image_to_video and not image_latents_path.exists())
241
+ ):
242
+ raise ValueError(
243
+ f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
244
+ )
245
+
246
+ if self.image_to_video:
247
+ image_latent_filepath = image_latents_path.joinpath(pt_filename)
248
+ video_latent_filepath = video_latents_path.joinpath(pt_filename)
249
+ embeds_filepath = embeds_path.joinpath(pt_filename)
250
+
251
+ if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
252
+ if self.image_to_video:
253
+ image_latent_filepath = image_latent_filepath.as_posix()
254
+ video_latent_filepath = video_latent_filepath.as_posix()
255
+ embeds_filepath = embeds_filepath.as_posix()
256
+ raise ValueError(
257
+ f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
258
+ )
259
+
260
+ images = (
261
+ torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
262
+ )
263
+ latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
264
+ embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
265
+
266
+ return images, latents, embeds
267
+
268
+
269
+ class VideoDatasetWithResizing(VideoDataset):
270
+ def __init__(self, *args, **kwargs) -> None:
271
+ super().__init__(*args, **kwargs)
272
+
273
+ def _preprocess_video(self, path: Path) -> torch.Tensor:
274
+ if self.load_tensors:
275
+ return self._load_preprocessed_latents_and_embeds(path)
276
+ else:
277
+ video_reader = decord.VideoReader(uri=path.as_posix())
278
+ video_num_frames = len(video_reader)
279
+ nearest_frame_bucket = min(
280
+ self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
281
+ )
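+ # Pick the frame bucket closest to the clip length (capped at `max_num_frames`), then sample
+ # that many evenly spaced frames below.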
282
+
283
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
284
+
285
+ frames = video_reader.get_batch(frame_indices)
286
+ frames = frames[:nearest_frame_bucket].float()
287
+ frames = frames.permute(0, 3, 1, 2).contiguous()
288
+
289
+ nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
290
+ frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
291
+ frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
292
+
293
+ image = frames[:1].clone() if self.image_to_video else None
294
+
295
+ return image, frames, None
296
+
297
+ def _find_nearest_resolution(self, height, width):
298
+ nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
299
+ return nearest_res[1], nearest_res[2]
300
+
301
+
302
+ class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
303
+ def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
304
+ super().__init__(*args, **kwargs)
305
+ self.video_reshape_mode = video_reshape_mode
306
+
307
+ def _resize_for_rectangle_crop(self, arr, image_size):
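+ # Resize so the target bucket resolution fits inside the frames while preserving aspect ratio,
+ # then crop away the overflow on the longer side, either centered or at a random offset.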
308
+ reshape_mode = self.video_reshape_mode
309
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
310
+ arr = resize(
311
+ arr,
312
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
313
+ interpolation=InterpolationMode.BICUBIC,
314
+ )
315
+ else:
316
+ arr = resize(
317
+ arr,
318
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
319
+ interpolation=InterpolationMode.BICUBIC,
320
+ )
321
+
322
+ h, w = arr.shape[2], arr.shape[3]
323
+ arr = arr.squeeze(0)
324
+
325
+ delta_h = h - image_size[0]
326
+ delta_w = w - image_size[1]
327
+
328
+ if reshape_mode == "random" or reshape_mode == "none":
329
+ top = np.random.randint(0, delta_h + 1)
330
+ left = np.random.randint(0, delta_w + 1)
331
+ elif reshape_mode == "center":
332
+ top, left = delta_h // 2, delta_w // 2
333
+ else:
334
+ raise NotImplementedError
335
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
336
+ return arr
337
+
338
+ def _preprocess_video(self, path: Path) -> torch.Tensor:
339
+ if self.load_tensors:
340
+ return self._load_preprocessed_latents_and_embeds(path)
341
+ else:
342
+ video_reader = decord.VideoReader(uri=path.as_posix())
343
+ video_num_frames = len(video_reader)
344
+ nearest_frame_bucket = min(
345
+ self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
346
+ )
347
+
348
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
349
+
350
+ frames = video_reader.get_batch(frame_indices)
351
+ frames = frames[:nearest_frame_bucket].float()
352
+ frames = frames.permute(0, 3, 1, 2).contiguous()
353
+
354
+ nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
355
+ frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
356
+ frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
357
+
358
+ image = frames[:1].clone() if self.image_to_video else None
359
+
360
+ return image, frames, None
361
+
362
+ def _find_nearest_resolution(self, height, width):
363
+ nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
364
+ return nearest_res[1], nearest_res[2]
365
+
366
+
367
+ class BucketSampler(Sampler):
368
+ r"""
369
+ PyTorch Sampler that groups 3D data by height, width and frames.
370
+
371
+ Args:
372
+ data_source (`VideoDataset`):
373
+ A PyTorch dataset object that is an instance of `VideoDataset`.
374
+ batch_size (`int`, defaults to `8`):
375
+ The batch size to use for training.
376
+ shuffle (`bool`, defaults to `True`):
377
+ Whether or not to shuffle the data in each batch before dispatching to dataloader.
378
+ drop_last (`bool`, defaults to `False`):
379
+ Whether or not to drop incomplete buckets of data after completely iterating over all data
380
+ in the dataset. If set to True, only batches that have `batch_size` number of entries will
381
+ be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
382
+ and batches that do not have `batch_size` number of entries will also be yielded.
383
+ """
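+ # A minimal usage sketch (illustrative, mirroring how the trainer wires things up):
+ #   sampler = BucketSampler(dataset, batch_size=4, shuffle=True)
+ #   loader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=collate_fn)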
384
+
385
+ def __init__(
386
+ self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
387
+ ) -> None:
388
+ self.data_source = data_source
389
+ self.batch_size = batch_size
390
+ self.shuffle = shuffle
391
+ self.drop_last = drop_last
392
+
393
+ self.buckets = {resolution: [] for resolution in data_source.resolutions}
394
+
395
+ self._raised_warning_for_drop_last = False
396
+
397
+ def __len__(self):
398
+ if self.drop_last and not self._raised_warning_for_drop_last:
399
+ self._raised_warning_for_drop_last = True
400
+ logger.warning(
401
+ "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
402
+ )
403
+ return (len(self.data_source) + self.batch_size - 1) // self.batch_size
404
+
405
+ def __iter__(self):
406
+ for index, data in enumerate(self.data_source):
407
+ video_metadata = data["video_metadata"]
408
+ f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
409
+
410
+ self.buckets[(f, h, w)].append(data)
411
+ if len(self.buckets[(f, h, w)]) == self.batch_size:
412
+ if self.shuffle:
413
+ random.shuffle(self.buckets[(f, h, w)])
414
+ yield self.buckets[(f, h, w)]
415
+ del self.buckets[(f, h, w)]
416
+ self.buckets[(f, h, w)] = []
417
+
418
+ if self.drop_last:
419
+ return
420
+
421
+ for fhw, bucket in list(self.buckets.items()):
422
+ if len(bucket) == 0:
423
+ continue
424
+ if self.shuffle:
425
+ random.shuffle(bucket)
426
+ yield bucket
427
+ del self.buckets[fhw]
428
+ self.buckets[fhw] = []
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py ADDED
@@ -0,0 +1,669 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ import functools
5
+ import json
6
+ import os
7
+ import pathlib
8
+ import queue
9
+ import traceback
10
+ import uuid
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Any, Dict, List, Optional, Union
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+ from diffusers import AutoencoderKLCogVideoX
17
+ from diffusers.training_utils import set_seed
18
+ from diffusers.utils import export_to_video, get_logger
19
+ from torch.utils.data import DataLoader
20
+ from torchvision import transforms
21
+ from tqdm import tqdm
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+
25
+ import decord # isort:skip
26
+
27
+ from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
28
+
29
+
30
+ decord.bridge.set_bridge("torch")
31
+
32
+ logger = get_logger(__name__)
33
+
34
+ DTYPE_MAPPING = {
35
+ "fp32": torch.float32,
36
+ "fp16": torch.float16,
37
+ "bf16": torch.bfloat16,
38
+ }
39
+
40
+
41
+ def check_height(x: Any) -> int:
42
+ x = int(x)
43
+ if x % 16 != 0:
44
+ raise argparse.ArgumentTypeError(
45
+ f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria."
46
+ )
47
+ return x
48
+
49
+
50
+ def check_width(x: Any) -> int:
51
+ x = int(x)
52
+ if x % 16 != 0:
53
+ raise argparse.ArgumentTypeError(
54
+ f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria."
55
+ )
56
+ return x
57
+
58
+
59
+ def check_frames(x: Any) -> int:
60
+ x = int(x)
61
+ if x % 4 != 0 and x % 4 != 1:
62
+ raise argparse.ArgumentTypeError(
63
+ f"`--frame_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria."
64
+ )
65
+ return x
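+ # CogVideoX's VAE compresses time by a factor of 4, so frame counts of the form 4 * k or
+ # 4 * k + 1 (for example, 48 or 49) map cleanly onto whole latent frames.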
66
+
67
+
68
+ def get_args() -> Dict[str, Any]:
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument(
71
+ "--model_id",
72
+ type=str,
73
+ default="THUDM/CogVideoX-2b",
74
+ help="Hugging Face model ID to use for tokenizer, text encoder and VAE.",
75
+ )
76
+ parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.")
77
+ parser.add_argument(
78
+ "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data."
79
+ )
80
+ parser.add_argument(
81
+ "--caption_column",
82
+ type=str,
83
+ default="caption",
84
+ help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).",
85
+ )
86
+ parser.add_argument(
87
+ "--video_column",
88
+ type=str,
89
+ default="video",
90
+ help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).",
91
+ )
92
+ parser.add_argument(
93
+ "--id_token",
94
+ type=str,
95
+ default=None,
96
+ help="Identifier token appended to the start of each prompt if provided.",
97
+ )
98
+ parser.add_argument(
99
+ "--height_buckets",
100
+ nargs="+",
101
+ type=check_height,
102
+ default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
103
+ )
104
+ parser.add_argument(
105
+ "--width_buckets",
106
+ nargs="+",
107
+ type=check_width,
108
+ default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
109
+ )
110
+ parser.add_argument(
111
+ "--frame_buckets",
112
+ nargs="+",
113
+ type=check_frames,
114
+ default=[49],
115
+ )
116
+ parser.add_argument(
117
+ "--random_flip",
118
+ type=float,
119
+ default=None,
120
+ help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
121
+ )
122
+ parser.add_argument(
123
+ "--dataloader_num_workers",
124
+ type=int,
125
+ default=0,
126
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
127
+ )
128
+ parser.add_argument(
129
+ "--pin_memory",
130
+ action="store_true",
131
+ help="Whether or not to use the pinned memory setting in pytorch dataloader.",
132
+ )
133
+ parser.add_argument(
134
+ "--video_reshape_mode",
135
+ type=str,
136
+ default=None,
137
+ help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
138
+ )
139
+ parser.add_argument(
140
+ "--save_image_latents",
141
+ action="store_true",
142
+ help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.",
143
+ )
144
+ parser.add_argument(
145
+ "--output_dir",
146
+ type=str,
147
+ required=True,
148
+ help="Path to output directory where preprocessed videos/latents/embeddings will be saved.",
149
+ )
150
+ parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.")
151
+ parser.add_argument(
152
+ "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings."
153
+ )
154
+ parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.")
155
+ parser.add_argument(
156
+ "--save_latents_and_embeddings",
157
+ action="store_true",
158
+ help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.",
159
+ )
160
+ parser.add_argument(
161
+ "--use_slicing",
162
+ action="store_true",
163
+ help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
164
+ )
165
+ parser.add_argument(
166
+ "--use_tiling",
167
+ action="store_true",
168
+ help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
169
+ )
170
+ parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.")
171
+ parser.add_argument(
172
+ "--num_decode_threads",
173
+ type=int,
174
+ default=0,
175
+ help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.",
176
+ )
177
+ parser.add_argument(
178
+ "--dtype",
179
+ type=str,
180
+ choices=["fp32", "fp16", "bf16"],
181
+ default="fp32",
182
+ help="Data type to use when generating latents and prompt embeddings.",
183
+ )
184
+ parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
185
+ parser.add_argument(
186
+ "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts."
187
+ )
188
+ return parser.parse_args()
189
+
190
+
191
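The parser above accepts either a CSV manifest via `--dataset_file` or a folder layout in which `--caption_column` and `--video_column` name line-separated files inside `--data_root`. A small sketch of that folder layout, with hypothetical paths and contents:

```python
# Sketch of the folder-structure format described in the help strings above, e.g.
#   --data_root ./my_dataset --caption_column prompts.txt --video_column videos.txt
import pathlib

data_root = pathlib.Path("./my_dataset")  # placeholder location
(data_root / "videos").mkdir(parents=True, exist_ok=True)
(data_root / "prompts.txt").write_text(
    "A cat playing the piano\nA drone shot of a waterfall\n", encoding="utf-8"
)
(data_root / "videos.txt").write_text(
    "videos/cat_piano.mp4\nvideos/waterfall.mp4\n", encoding="utf-8"
)
```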
+ def _get_t5_prompt_embeds(
192
+ tokenizer: T5Tokenizer,
193
+ text_encoder: T5EncoderModel,
194
+ prompt: Union[str, List[str]],
195
+ num_videos_per_prompt: int = 1,
196
+ max_sequence_length: int = 226,
197
+ device: Optional[torch.device] = None,
198
+ dtype: Optional[torch.dtype] = None,
199
+ text_input_ids=None,
200
+ ):
201
+ prompt = [prompt] if isinstance(prompt, str) else prompt
202
+ batch_size = len(prompt)
203
+
204
+ if tokenizer is not None:
205
+ text_inputs = tokenizer(
206
+ prompt,
207
+ padding="max_length",
208
+ max_length=max_sequence_length,
209
+ truncation=True,
210
+ add_special_tokens=True,
211
+ return_tensors="pt",
212
+ )
213
+ text_input_ids = text_inputs.input_ids
214
+ else:
215
+ if text_input_ids is None:
216
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
217
+
218
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
219
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
220
+
221
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
222
+ _, seq_len, _ = prompt_embeds.shape
223
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
224
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
225
+
226
+ return prompt_embeds
227
+
228
+
229
+ def encode_prompt(
230
+ tokenizer: T5Tokenizer,
231
+ text_encoder: T5EncoderModel,
232
+ prompt: Union[str, List[str]],
233
+ num_videos_per_prompt: int = 1,
234
+ max_sequence_length: int = 226,
235
+ device: Optional[torch.device] = None,
236
+ dtype: Optional[torch.dtype] = None,
237
+ text_input_ids=None,
238
+ ):
239
+ prompt = [prompt] if isinstance(prompt, str) else prompt
240
+ prompt_embeds = _get_t5_prompt_embeds(
241
+ tokenizer,
242
+ text_encoder,
243
+ prompt=prompt,
244
+ num_videos_per_prompt=num_videos_per_prompt,
245
+ max_sequence_length=max_sequence_length,
246
+ device=device,
247
+ dtype=dtype,
248
+ text_input_ids=text_input_ids,
249
+ )
250
+ return prompt_embeds
251
+
252
+
253
+ def compute_prompt_embeddings(
254
+ tokenizer: T5Tokenizer,
255
+ text_encoder: T5EncoderModel,
256
+ prompts: List[str],
257
+ max_sequence_length: int,
258
+ device: torch.device,
259
+ dtype: torch.dtype,
260
+ requires_grad: bool = False,
261
+ ):
262
+ if requires_grad:
263
+ prompt_embeds = encode_prompt(
264
+ tokenizer,
265
+ text_encoder,
266
+ prompts,
267
+ num_videos_per_prompt=1,
268
+ max_sequence_length=max_sequence_length,
269
+ device=device,
270
+ dtype=dtype,
271
+ )
272
+ else:
273
+ with torch.no_grad():
274
+ prompt_embeds = encode_prompt(
275
+ tokenizer,
276
+ text_encoder,
277
+ prompts,
278
+ num_videos_per_prompt=1,
279
+ max_sequence_length=max_sequence_length,
280
+ device=device,
281
+ dtype=dtype,
282
+ )
283
+ return prompt_embeds
284
+
285
+
286
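`compute_prompt_embeddings` simply routes through `encode_prompt`, with or without gradient tracking. A hedged usage sketch, assuming a CUDA device with enough memory for the T5 text encoder and the same loading pattern used later in `main()`:

```python
# Sketch only: embed a prompt with the helpers defined above.
import torch
from transformers import T5EncoderModel, T5Tokenizer

model_id = "THUDM/CogVideoX-2b"
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
).to("cuda")

embeds = compute_prompt_embeddings(
    tokenizer,
    text_encoder,
    ["A cat playing the piano"],
    max_sequence_length=226,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    requires_grad=False,
)
print(embeds.shape)  # (1, 226, hidden_size) -- padded/truncated to 226 tokens
```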
+ to_pil_image = transforms.ToPILImage(mode="RGB")
287
+
288
+
289
+ def save_image(image: torch.Tensor, path: pathlib.Path) -> None:
290
+ image = image.to(dtype=torch.float32).clamp(-1, 1)
291
+ image = to_pil_image(image)
292
+ image.save(path)
293
+
294
+
295
+ def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None:
296
+ video = video.to(dtype=torch.float32).clamp(-1, 1)
297
+ video = [to_pil_image(frame) for frame in video]
298
+ export_to_video(video, path, fps=fps)
299
+
300
+
301
+ def save_prompt(prompt: str, path: pathlib.Path) -> None:
302
+ with open(path, "w", encoding="utf-8") as file:
303
+ file.write(prompt)
304
+
305
+
306
+ def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None:
307
+ with open(path, "w", encoding="utf-8") as file:
308
+ file.write(json.dumps(metadata))
309
+
310
+
311
+ @torch.no_grad()
312
+ def serialize_artifacts(
313
+ batch_size: int,
314
+ fps: int,
315
+ images_dir: Optional[pathlib.Path] = None,
316
+ image_latents_dir: Optional[pathlib.Path] = None,
317
+ videos_dir: Optional[pathlib.Path] = None,
318
+ video_latents_dir: Optional[pathlib.Path] = None,
319
+ prompts_dir: Optional[pathlib.Path] = None,
320
+ prompt_embeds_dir: Optional[pathlib.Path] = None,
321
+ images: Optional[torch.Tensor] = None,
322
+ image_latents: Optional[torch.Tensor] = None,
323
+ videos: Optional[torch.Tensor] = None,
324
+ video_latents: Optional[torch.Tensor] = None,
325
+ prompts: Optional[List[str]] = None,
326
+ prompt_embeds: Optional[torch.Tensor] = None,
327
+ ) -> None:
328
+ num_frames, height, width = videos.size(1), videos.size(3), videos.size(4)
329
+ metadata = [{"num_frames": num_frames, "height": height, "width": width}]
330
+
331
+ data_folder_mapper_list = [
332
+ (images, images_dir, lambda img, path: save_image(img[0], path), "png"),
333
+ (image_latents, image_latents_dir, torch.save, "pt"),
334
+ (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"),
335
+ (video_latents, video_latents_dir, torch.save, "pt"),
336
+ (prompts, prompts_dir, save_prompt, "txt"),
337
+ (prompt_embeds, prompt_embeds_dir, torch.save, "pt"),
338
+ (metadata, videos_dir, save_metadata, "txt"),
339
+ ]
340
+ filenames = [uuid.uuid4() for _ in range(batch_size)]
341
+
342
+ for data, folder, save_fn, extension in data_folder_mapper_list:
343
+ if data is None:
344
+ continue
345
+ for slice, filename in zip(data, filenames):
346
+ if isinstance(slice, torch.Tensor):
347
+ slice = slice.clone().to("cpu")
348
+ path = folder.joinpath(f"{filename}.{extension}")
349
+ save_fn(slice, path)
350
+
351
+
352
+ def save_intermediates(output_queue: queue.Queue) -> None:
353
+ while True:
354
+ try:
355
+ item = output_queue.get(timeout=30)
356
+ if item is None:
357
+ break
358
+ serialize_artifacts(**item)
359
+
360
+ except queue.Empty:
361
+ continue
362
+
363
+
364
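`save_intermediates` drains the queue on a worker thread so that the GPU loop in `main()` never blocks on disk I/O, and a `None` sentinel ends the loop. A stripped-down sketch of the same producer/consumer pattern (blocking `get` instead of the 30-second timeout used above):

```python
# Minimal sketch of the queue + ThreadPoolExecutor pattern used below.
import queue
from concurrent.futures import ThreadPoolExecutor

def consumer(q: queue.Queue) -> None:
    while True:
        item = q.get()
        if item is None:  # sentinel: producer is done
            break
        print("saving", item)  # stands in for serialize_artifacts(**item)

q: queue.Queue = queue.Queue()
pool = ThreadPoolExecutor(max_workers=1)
future = pool.submit(consumer, q)
for i in range(3):
    q.put({"step": i})  # producer enqueues work without waiting for the save
q.put(None)
pool.shutdown(wait=True)
future.result()  # surface any exception raised inside the consumer
```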
+ @torch.no_grad()
365
+ def main():
366
+ args = get_args()
367
+ set_seed(args.seed)
368
+
369
+ output_dir = pathlib.Path(args.output_dir)
370
+ tmp_dir = output_dir.joinpath("tmp")
371
+
372
+ output_dir.mkdir(parents=True, exist_ok=True)
373
+ tmp_dir.mkdir(parents=True, exist_ok=True)
374
+
375
+ # Create task queue for non-blocking serialization of artifacts
376
+ output_queue = queue.Queue()
377
+ save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers)
378
+ save_future = save_thread.submit(save_intermediates, output_queue)
379
+
380
+ # Initialize distributed processing
381
+ if "LOCAL_RANK" in os.environ:
382
+ local_rank = int(os.environ["LOCAL_RANK"])
383
+ torch.cuda.set_device(local_rank)
384
+ dist.init_process_group(backend="nccl")
385
+ world_size = dist.get_world_size()
386
+ rank = dist.get_rank()
387
+ else:
388
+ # Single GPU
389
+ local_rank = 0
390
+ world_size = 1
391
+ rank = 0
392
+ torch.cuda.set_device(rank)
393
+
394
+ # Create folders where intermediate tensors from each rank will be saved
395
+ images_dir = tmp_dir.joinpath(f"images/{rank}")
396
+ image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}")
397
+ videos_dir = tmp_dir.joinpath(f"videos/{rank}")
398
+ video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}")
399
+ prompts_dir = tmp_dir.joinpath(f"prompts/{rank}")
400
+ prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}")
401
+
402
+ images_dir.mkdir(parents=True, exist_ok=True)
403
+ image_latents_dir.mkdir(parents=True, exist_ok=True)
404
+ videos_dir.mkdir(parents=True, exist_ok=True)
405
+ video_latents_dir.mkdir(parents=True, exist_ok=True)
406
+ prompts_dir.mkdir(parents=True, exist_ok=True)
407
+ prompt_embeds_dir.mkdir(parents=True, exist_ok=True)
408
+
409
+ weight_dtype = DTYPE_MAPPING[args.dtype]
410
+ target_fps = args.target_fps
411
+
412
+ # 1. Dataset
413
+ dataset_init_kwargs = {
414
+ "data_root": args.data_root,
415
+ "dataset_file": args.dataset_file,
416
+ "caption_column": args.caption_column,
417
+ "video_column": args.video_column,
418
+ "max_num_frames": args.max_num_frames,
419
+ "id_token": args.id_token,
420
+ "height_buckets": args.height_buckets,
421
+ "width_buckets": args.width_buckets,
422
+ "frame_buckets": args.frame_buckets,
423
+ "load_tensors": False,
424
+ "random_flip": args.random_flip,
425
+ "image_to_video": args.save_image_latents,
426
+ }
427
+ if args.video_reshape_mode is None:
428
+ dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
429
+ else:
430
+ dataset = VideoDatasetWithResizeAndRectangleCrop(
431
+ video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
432
+ )
433
+
434
+ original_dataset_size = len(dataset)
435
+
436
+ # Split data among GPUs
437
+ if world_size > 1:
438
+ samples_per_gpu = original_dataset_size // world_size
439
+ start_index = rank * samples_per_gpu
440
+ end_index = start_index + samples_per_gpu
441
+ if rank == world_size - 1:
442
+ end_index = original_dataset_size # Make sure the last GPU gets the remaining data
443
+
444
+ # Slice the data
445
+ dataset.prompts = dataset.prompts[start_index:end_index]
446
+ dataset.video_paths = dataset.video_paths[start_index:end_index]
447
+ else:
448
+ pass
449
+
450
+ rank_dataset_size = len(dataset)
451
+
452
+ # 2. Dataloader
453
+ def collate_fn(data):
454
+ prompts = [x["prompt"] for x in data[0]]
455
+
456
+ images = None
457
+ if args.save_image_latents:
458
+ images = [x["image"] for x in data[0]]
459
+ images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True)
460
+
461
+ videos = [x["video"] for x in data[0]]
462
+ videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
463
+
464
+ return {
465
+ "images": images,
466
+ "videos": videos,
467
+ "prompts": prompts,
468
+ }
469
+
470
+ dataloader = DataLoader(
471
+ dataset,
472
+ batch_size=1,
473
+ sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False),
474
+ collate_fn=collate_fn,
475
+ num_workers=args.dataloader_num_workers,
476
+ pin_memory=args.pin_memory,
477
+ )
478
+
479
+ # 3. Prepare models
480
+ device = f"cuda:{rank}"
481
+
482
+ if args.save_latents_and_embeddings:
483
+ tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
484
+ text_encoder = T5EncoderModel.from_pretrained(
485
+ args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype
486
+ )
487
+ text_encoder = text_encoder.to(device)
488
+
489
+ vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype)
490
+ vae = vae.to(device)
491
+
492
+ if args.use_slicing:
493
+ vae.enable_slicing()
494
+ if args.use_tiling:
495
+ vae.enable_tiling()
496
+
497
+ # 4. Compute latents and embeddings and save
498
+ if rank == 0:
499
+ iterator = tqdm(
500
+ dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size
501
+ )
502
+ else:
503
+ iterator = dataloader
504
+
505
+ for step, batch in enumerate(iterator):
506
+ try:
507
+ images = None
508
+ image_latents = None
509
+ video_latents = None
510
+ prompt_embeds = None
511
+
512
+ if args.save_image_latents:
513
+ images = batch["images"].to(device, non_blocking=True)
514
+ images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
515
+
516
+ videos = batch["videos"].to(device, non_blocking=True)
517
+ videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
518
+
519
+ prompts = batch["prompts"]
520
+
521
+ # Encode videos & images
522
+ if args.save_latents_and_embeddings:
523
+ if args.use_slicing:
524
+ if args.save_image_latents:
525
+ encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)]
526
+ image_latents = torch.cat(encoded_slices)
527
+ image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
528
+
529
+ encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)]
530
+ video_latents = torch.cat(encoded_slices)
531
+
532
+ else:
533
+ if args.save_image_latents:
534
+ image_latents = vae._encode(images)
535
+ image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
536
+
537
+ video_latents = vae._encode(videos)
538
+
539
+ video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
540
+
541
+ # Encode prompts
542
+ prompt_embeds = compute_prompt_embeddings(
543
+ tokenizer,
544
+ text_encoder,
545
+ prompts,
546
+ args.max_sequence_length,
547
+ device,
548
+ weight_dtype,
549
+ requires_grad=False,
550
+ )
551
+
552
+ if images is not None:
553
+ images = (images.permute(0, 2, 1, 3, 4) + 1) / 2
554
+
555
+ videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2
556
+
557
+ output_queue.put(
558
+ {
559
+ "batch_size": len(prompts),
560
+ "fps": target_fps,
561
+ "images_dir": images_dir,
562
+ "image_latents_dir": image_latents_dir,
563
+ "videos_dir": videos_dir,
564
+ "video_latents_dir": video_latents_dir,
565
+ "prompts_dir": prompts_dir,
566
+ "prompt_embeds_dir": prompt_embeds_dir,
567
+ "images": images,
568
+ "image_latents": image_latents,
569
+ "videos": videos,
570
+ "video_latents": video_latents,
571
+ "prompts": prompts,
572
+ "prompt_embeds": prompt_embeds,
573
+ }
574
+ )
575
+
576
+ except Exception:
577
+ print("-------------------------")
578
+ print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}")
579
+ traceback.print_exc()
580
+ print("-------------------------")
581
+
582
+ # 5. Complete distributed processing
583
+ if world_size > 1:
584
+ dist.barrier()
585
+ dist.destroy_process_group()
586
+
587
+ output_queue.put(None)
588
+ save_thread.shutdown(wait=True)
589
+ save_future.result()
590
+
591
+ # 6. Combine results from each rank
592
+ if rank == 0:
593
+ print(
594
+ f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`"
595
+ )
596
+
597
+ # Move files from each rank to common directory
598
+ for subfolder, extension in [
599
+ ("images", "png"),
600
+ ("image_latents", "pt"),
601
+ ("videos", "mp4"),
602
+ ("video_latents", "pt"),
603
+ ("prompts", "txt"),
604
+ ("prompt_embeds", "pt"),
605
+ ("videos", "txt"),
606
+ ]:
607
+ tmp_subfolder = tmp_dir.joinpath(subfolder)
608
+ combined_subfolder = output_dir.joinpath(subfolder)
609
+ combined_subfolder.mkdir(parents=True, exist_ok=True)
610
+ pattern = f"*.{extension}"
611
+
612
+ for file in tmp_subfolder.rglob(pattern):
613
+ file.replace(combined_subfolder / file.name)
614
+
615
+ # Remove temporary directories
616
+ def rmdir_recursive(dir: pathlib.Path) -> None:
617
+ for child in dir.iterdir():
618
+ if child.is_file():
619
+ child.unlink()
620
+ else:
621
+ rmdir_recursive(child)
622
+ dir.rmdir()
623
+
624
+ rmdir_recursive(tmp_dir)
625
+
626
+ # Combine prompts and videos into individual text files and a single jsonl
627
+ prompts_folder = output_dir.joinpath("prompts")
628
+ prompts = []
629
+ stems = []
630
+
631
+ for filename in prompts_folder.rglob("*.txt"):
632
+ with open(filename, "r") as file:
633
+ prompts.append(file.read().strip())
634
+ stems.append(filename.stem)
635
+
636
+ prompts_txt = output_dir.joinpath("prompts.txt")
637
+ videos_txt = output_dir.joinpath("videos.txt")
638
+ data_jsonl = output_dir.joinpath("data.jsonl")
639
+
640
+ with open(prompts_txt, "w") as file:
641
+ for prompt in prompts:
642
+ file.write(f"{prompt}\n")
643
+
644
+ with open(videos_txt, "w") as file:
645
+ for stem in stems:
646
+ file.write(f"videos/{stem}.mp4\n")
647
+
648
+ with open(data_jsonl, "w") as file:
649
+ for prompt, stem in zip(prompts, stems):
650
+ video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt")
651
+ with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file:
652
+ metadata = json.loads(metadata_file.read())
653
+
654
+ data = {
655
+ "prompt": prompt,
656
+ "prompt_embed": f"prompt_embeds/{stem}.pt",
657
+ "image": f"images/{stem}.png",
658
+ "image_latent": f"image_latents/{stem}.pt",
659
+ "video": f"videos/{stem}.mp4",
660
+ "video_latent": f"video_latents/{stem}.pt",
661
+ "metadata": metadata,
662
+ }
663
+ file.write(json.dumps(data) + "\n")
664
+
665
+ print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`")
666
+
667
+
668
+ if __name__ == "__main__":
669
+ main()
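After `main()` finishes, `--output_dir` contains the per-sample artifacts (`images/`, `image_latents/`, `videos/`, `video_latents/`, `prompts/`, `prompt_embeds/`) plus `prompts.txt`, `videos.txt`, and `data.jsonl`. A small sketch of reading the manifest back; the output path is a placeholder:

```python
# Sketch: iterate over the data.jsonl manifest produced above and load one precomputed latent.
import json
import pathlib
import torch

output_dir = pathlib.Path("/path/to/output_dir")  # placeholder
with open(output_dir / "data.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)
        video_latent = torch.load(output_dir / entry["video_latent"], map_location="cpu")
        print(entry["prompt"][:40], tuple(video_latent.shape), entry["metadata"])
        break
```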
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .text_encoder import compute_prompt_embeddings
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py ADDED
@@ -0,0 +1,99 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ from transformers import T5EncoderModel, T5Tokenizer
5
+
6
+
7
+ def _get_t5_prompt_embeds(
8
+ tokenizer: T5Tokenizer,
9
+ text_encoder: T5EncoderModel,
10
+ prompt: Union[str, List[str]],
11
+ num_videos_per_prompt: int = 1,
12
+ max_sequence_length: int = 226,
13
+ device: Optional[torch.device] = None,
14
+ dtype: Optional[torch.dtype] = None,
15
+ text_input_ids=None,
16
+ ):
17
+ prompt = [prompt] if isinstance(prompt, str) else prompt
18
+ batch_size = len(prompt)
19
+
20
+ if tokenizer is not None:
21
+ text_inputs = tokenizer(
22
+ prompt,
23
+ padding="max_length",
24
+ max_length=max_sequence_length,
25
+ truncation=True,
26
+ add_special_tokens=True,
27
+ return_tensors="pt",
28
+ )
29
+ text_input_ids = text_inputs.input_ids
30
+ else:
31
+ if text_input_ids is None:
32
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
33
+
34
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
35
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
36
+
37
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
38
+ _, seq_len, _ = prompt_embeds.shape
39
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
40
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
41
+
42
+ return prompt_embeds
43
+
44
+
45
+ def encode_prompt(
46
+ tokenizer: T5Tokenizer,
47
+ text_encoder: T5EncoderModel,
48
+ prompt: Union[str, List[str]],
49
+ num_videos_per_prompt: int = 1,
50
+ max_sequence_length: int = 226,
51
+ device: Optional[torch.device] = None,
52
+ dtype: Optional[torch.dtype] = None,
53
+ text_input_ids=None,
54
+ ):
55
+ prompt = [prompt] if isinstance(prompt, str) else prompt
56
+ prompt_embeds = _get_t5_prompt_embeds(
57
+ tokenizer,
58
+ text_encoder,
59
+ prompt=prompt,
60
+ num_videos_per_prompt=num_videos_per_prompt,
61
+ max_sequence_length=max_sequence_length,
62
+ device=device,
63
+ dtype=dtype,
64
+ text_input_ids=text_input_ids,
65
+ )
66
+ return prompt_embeds
67
+
68
+
69
+ def compute_prompt_embeddings(
70
+ tokenizer: T5Tokenizer,
71
+ text_encoder: T5EncoderModel,
72
+ prompt: str,
73
+ max_sequence_length: int,
74
+ device: torch.device,
75
+ dtype: torch.dtype,
76
+ requires_grad: bool = False,
77
+ ):
78
+ if requires_grad:
79
+ prompt_embeds = encode_prompt(
80
+ tokenizer,
81
+ text_encoder,
82
+ prompt,
83
+ num_videos_per_prompt=1,
84
+ max_sequence_length=max_sequence_length,
85
+ device=device,
86
+ dtype=dtype,
87
+ )
88
+ else:
89
+ with torch.no_grad():
90
+ prompt_embeds = encode_prompt(
91
+ tokenizer,
92
+ text_encoder,
93
+ prompt,
94
+ num_videos_per_prompt=1,
95
+ max_sequence_length=max_sequence_length,
96
+ device=device,
97
+ dtype=dtype,
98
+ )
99
+ return prompt_embeds
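The `repeat` + `view` sequence in `_get_t5_prompt_embeds` duplicates each prompt's embedding `num_videos_per_prompt` times while keeping the batch dimension first. A small shape check illustrating the same operation on dummy data (the hidden size here is arbitrary):

```python
# Sketch of the duplication done in _get_t5_prompt_embeds, on random tensors.
import torch

batch_size, seq_len, hidden = 2, 226, 64
num_videos_per_prompt = 3
prompt_embeds = torch.randn(batch_size, seq_len, hidden)

out = prompt_embeds.repeat(1, num_videos_per_prompt, 1)            # [2, 678, 64]
out = out.view(batch_size * num_videos_per_prompt, seq_len, -1)    # [6, 226, 64]
print(out.shape)  # torch.Size([6, 226, 64]); rows 0-2 are copies of prompt 0, rows 3-5 of prompt 1
```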