Spaces:
Running
Running
Commit
·
9fd1204
1
Parent(s):
76a0a50
we are going to hack into finetrainers
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml +51 -0
- docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml +34 -0
- docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml +30 -0
- docs/finetrainers-src-codebase/.gitignore +179 -0
- docs/finetrainers-src-codebase/CONTRIBUTING.md +41 -0
- docs/finetrainers-src-codebase/LICENSE +201 -0
- docs/finetrainers-src-codebase/Makefile +11 -0
- docs/{finetrainers/documentation_README.md → finetrainers-src-codebase/README.md} +6 -3
- docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml +22 -0
- docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml +23 -0
- docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml +17 -0
- docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml +17 -0
- docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml +17 -0
- docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml +17 -0
- docs/finetrainers-src-codebase/assets/contribute.md +16 -0
- docs/finetrainers-src-codebase/assets/contribute_zh.md +16 -0
- docs/finetrainers-src-codebase/assets/dataset_zh.md +72 -0
- docs/finetrainers-src-codebase/assets/sft_2b.png +0 -0
- docs/finetrainers-src-codebase/assets/sft_5b.png +0 -0
- docs/finetrainers-src-codebase/assets/tests/metadata.csv +2 -0
- docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md +20 -0
- docs/{finetrainers/documentation_args.md → finetrainers-src-codebase/docs/args.md} +44 -5
- docs/{finetrainers/documentation_dataset_README.md → finetrainers-src-codebase/docs/dataset/README.md} +11 -4
- docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md +44 -0
- docs/{finetrainers/documentation_environment.md → finetrainers-src-codebase/docs/environment.md} +11 -0
- docs/{finetrainers/documentation_models_README.md → finetrainers-src-codebase/docs/models/README.md} +0 -0
- docs/finetrainers-src-codebase/docs/models/attention.md +263 -0
- docs/{finetrainers/documentation_models_cogvideox.md → finetrainers-src-codebase/docs/models/cogvideox.md} +6 -6
- docs/finetrainers-src-codebase/docs/models/cogview4.md +94 -0
- docs/finetrainers-src-codebase/docs/models/flux.md +53 -0
- docs/{finetrainers/documentation_models_hunyuan_video.md → finetrainers-src-codebase/docs/models/hunyuan_video.md} +3 -3
- docs/{finetrainers/documentation_models_ltx_video.md → finetrainers-src-codebase/docs/models/ltx_video.md} +3 -3
- docs/{finetrainers/documentation_models_optimization.md → finetrainers-src-codebase/docs/models/optimization.md} +0 -0
- docs/{finetrainers/documentation_models_wan.md → finetrainers-src-codebase/docs/models/wan.md} +11 -1
- docs/{finetrainers/documentation_optimizers.md → finetrainers-src-codebase/docs/optimizer.md} +0 -0
- docs/{finetrainers/documentation_parallel_processing_README.md → finetrainers-src-codebase/docs/parallel/README.md} +8 -3
- docs/{finetrainers/documentation_trainers_control_trainer.md → finetrainers-src-codebase/docs/trainer/control_trainer.md} +0 -0
- docs/{finetrainers/documentation_trainers_sft_trainer.md → finetrainers-src-codebase/docs/trainer/sft_trainer.md} +0 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/README.md +459 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md +455 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py +0 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py +484 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py +1016 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py +947 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py +955 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py +917 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py +428 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py +669 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py +1 -0
- docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py +99 -0
docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/bug_report.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F41B Bug Report"
|
2 |
+
description: Submit a bug report to help us improve CogVideoX-Factory / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX-Factory 开源框架
|
3 |
+
body:
|
4 |
+
- type: textarea
|
5 |
+
id: system-info
|
6 |
+
attributes:
|
7 |
+
label: System Info / 系統信息
|
8 |
+
description: Your operating environment / 您的运行环境信息
|
9 |
+
placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
|
10 |
+
validations:
|
11 |
+
required: true
|
12 |
+
|
13 |
+
- type: checkboxes
|
14 |
+
id: information-scripts-examples
|
15 |
+
attributes:
|
16 |
+
label: Information / 问题信息
|
17 |
+
description: 'The problem arises when using: / 问题出现在'
|
18 |
+
options:
|
19 |
+
- label: "The official example scripts / 官方的示例脚本"
|
20 |
+
- label: "My own modified scripts / 我自己修改的脚本和任务"
|
21 |
+
|
22 |
+
- type: textarea
|
23 |
+
id: reproduction
|
24 |
+
validations:
|
25 |
+
required: true
|
26 |
+
attributes:
|
27 |
+
label: Reproduction / 复现过程
|
28 |
+
description: |
|
29 |
+
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
|
30 |
+
If you have code snippets, error messages, stack traces, please provide them here as well.
|
31 |
+
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
32 |
+
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
|
33 |
+
|
34 |
+
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
|
35 |
+
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
|
36 |
+
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
37 |
+
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
|
38 |
+
placeholder: |
|
39 |
+
Steps to reproduce the behavior/复现Bug的步骤:
|
40 |
+
|
41 |
+
1.
|
42 |
+
2.
|
43 |
+
3.
|
44 |
+
|
45 |
+
- type: textarea
|
46 |
+
id: expected-behavior
|
47 |
+
validations:
|
48 |
+
required: true
|
49 |
+
attributes:
|
50 |
+
label: Expected behavior / 期待表现
|
51 |
+
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
|
docs/finetrainers-src-codebase/.github/ISSUE_TEMPLATE/feature-request.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F680 Feature request"
|
2 |
+
description: Submit a request for a new CogVideoX-Factory feature / 提交一个新的 CogVideoX-Factory 开源项目的功能建议
|
3 |
+
labels: [ "feature" ]
|
4 |
+
body:
|
5 |
+
- type: textarea
|
6 |
+
id: feature-request
|
7 |
+
validations:
|
8 |
+
required: true
|
9 |
+
attributes:
|
10 |
+
label: Feature request / 功能建议
|
11 |
+
description: |
|
12 |
+
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
|
13 |
+
对功能建议的简述。最好提供对应的论文和代码链接。
|
14 |
+
|
15 |
+
- type: textarea
|
16 |
+
id: motivation
|
17 |
+
validations:
|
18 |
+
required: true
|
19 |
+
attributes:
|
20 |
+
label: Motivation / 动机
|
21 |
+
description: |
|
22 |
+
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
|
23 |
+
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
|
24 |
+
|
25 |
+
- type: textarea
|
26 |
+
id: contribution
|
27 |
+
validations:
|
28 |
+
required: true
|
29 |
+
attributes:
|
30 |
+
label: Your contribution / 您的贡献
|
31 |
+
description: |
|
32 |
+
|
33 |
+
Your PR link or any other link you can help with.
|
34 |
+
您的PR链接或者其他您能提供帮助的链接。
|
docs/finetrainers-src-codebase/.github/workflows/pr_tests.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Fast tests for PRs
|
2 |
+
|
3 |
+
on:
|
4 |
+
pull_request:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
|
8 |
+
concurrency:
|
9 |
+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
10 |
+
cancel-in-progress: true
|
11 |
+
|
12 |
+
jobs:
|
13 |
+
check_code_quality:
|
14 |
+
runs-on: ubuntu-22.04
|
15 |
+
steps:
|
16 |
+
- uses: actions/checkout@v3
|
17 |
+
- name: Set up Python
|
18 |
+
uses: actions/setup-python@v4
|
19 |
+
with:
|
20 |
+
python-version: "3.8"
|
21 |
+
- name: Install dependencies
|
22 |
+
run: |
|
23 |
+
python -m pip install --upgrade pip
|
24 |
+
pip install ruff==0.9.10
|
25 |
+
- name: Check quality
|
26 |
+
run: make quality
|
27 |
+
- name: Check if failure
|
28 |
+
if: ${{ failure() }}
|
29 |
+
run: |
|
30 |
+
echo "Quality check failed. Please install ruff: `pip install ruff` and then run `make style && make quality` from the root of the repository." >> $GITHUB_STEP_SUMMARY
|
docs/finetrainers-src-codebase/.gitignore
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# JetBrains
|
7 |
+
.idea
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
build/
|
15 |
+
develop-eggs/
|
16 |
+
dist/
|
17 |
+
downloads/
|
18 |
+
eggs/
|
19 |
+
.eggs/
|
20 |
+
lib/
|
21 |
+
lib64/
|
22 |
+
parts/
|
23 |
+
sdist/
|
24 |
+
var/
|
25 |
+
wheels/
|
26 |
+
share/python-wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
MANIFEST
|
31 |
+
|
32 |
+
# PyInstaller
|
33 |
+
# Usually these files are written by a python script from a template
|
34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35 |
+
*.manifest
|
36 |
+
*.spec
|
37 |
+
|
38 |
+
# Installer logs
|
39 |
+
pip-log.txt
|
40 |
+
pip-delete-this-directory.txt
|
41 |
+
|
42 |
+
# Unit test / coverage reports
|
43 |
+
htmlcov/
|
44 |
+
.tox/
|
45 |
+
.nox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
*.py,cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
cover/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
.pybuilder/
|
79 |
+
target/
|
80 |
+
|
81 |
+
# Jupyter Notebook
|
82 |
+
.ipynb_checkpoints
|
83 |
+
|
84 |
+
# IPython
|
85 |
+
profile_default/
|
86 |
+
ipython_config.py
|
87 |
+
|
88 |
+
# pyenv
|
89 |
+
# For a library or package, you might want to ignore these files since the code is
|
90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
91 |
+
# .python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# poetry
|
101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103 |
+
# commonly ignored for libraries.
|
104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
105 |
+
#poetry.lock
|
106 |
+
|
107 |
+
# pdm
|
108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
109 |
+
#pdm.lock
|
110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
111 |
+
# in version control.
|
112 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
113 |
+
.pdm.toml
|
114 |
+
.pdm-python
|
115 |
+
.pdm-build/
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
166 |
+
|
167 |
+
# manually added
|
168 |
+
wandb/
|
169 |
+
*.txt
|
170 |
+
dump*
|
171 |
+
outputs*
|
172 |
+
*.slurm
|
173 |
+
.vscode/
|
174 |
+
*dummy*
|
175 |
+
*curated*
|
176 |
+
validation_dataset/
|
177 |
+
wan-framepack/
|
178 |
+
|
179 |
+
!requirements.txt
|
docs/finetrainers-src-codebase/CONTRIBUTING.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to contribute to Finetrainers
|
2 |
+
|
3 |
+
Finetrainers is an early-stage library for training diffusion models. Everyone is welcome to contribute - models, algorithms, refactors, docs, etc. - but due to the early stage of the project, we recommend bigger contributions be discussed in an issue before submitting a PR. Eventually, we will have a better process for this!
|
4 |
+
|
5 |
+
## How to contribute
|
6 |
+
|
7 |
+
### Adding a new model
|
8 |
+
|
9 |
+
If you would like to add a new model, please follow these steps:
|
10 |
+
|
11 |
+
- Create a new file in the `finetrainers/models` directory with the model name (if it's new), or use the same directory if it's a variant of an existing model.
|
12 |
+
- Implement the model specification in the file. For more details on what a model specification should look like, see the [ModelSpecification](TODO(aryan): add link) documentation.
|
13 |
+
- Update the supported configs in `finetrainers/config.py` to include the new model and the training types supported.
|
14 |
+
- Add a dummy model specification in the `tests/models` directory.
|
15 |
+
- Make sure to test training with the following settings:
|
16 |
+
- Single GPU
|
17 |
+
- 2x GPU with `--dp_degree 2 --dp_shards 1`
|
18 |
+
- 2x GPU with `--dp_degree 1 --dp_shards 2`
|
19 |
+
|
20 |
+
For `SFTTrainer` additions, please make sure to train with atleast 1000 steps (atleast 2000 data points) to ensure the model training is working as expected.
|
21 |
+
- Open a PR with your changes. Please make sure to share your wandb logs for the above training settings in the PR description. This will help us verify the training is working as expected.
|
22 |
+
|
23 |
+
### Adding a new algorithm
|
24 |
+
|
25 |
+
Currently, we are not accepting algorithm contributions. We will update this section once we are better ready 🤗
|
26 |
+
|
27 |
+
### Refactors
|
28 |
+
|
29 |
+
The library is in a very early stage. There are many instances of dead code, poorly written abstractions, and other issues. If you would like to refactor/clean-up a part of the codebase, please open an issue to discuss the changes before submitting a PR.
|
30 |
+
|
31 |
+
### Dataset improvements
|
32 |
+
|
33 |
+
Any changes to dataset/dataloader implementations can be submitted directly. The improvements and reasons for the changes should be conveyed appropriately for us to move quickly 🤗
|
34 |
+
|
35 |
+
### Documentation
|
36 |
+
|
37 |
+
Due to the early stage of the project, the documentation is not as comprehensive as we would like. Any improvements/refactors are welcome directly!
|
38 |
+
|
39 |
+
## Asking for help
|
40 |
+
|
41 |
+
If you have any questions, feel free to open an issue and we will be sure to help you out asap! Please make sure to describe your issues in either English (preferable) or Chinese. Any other language will make it hard for us to help you, so we will most likely close such issues without explanation/answer.
|
docs/finetrainers-src-codebase/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
docs/finetrainers-src-codebase/Makefile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: quality style
|
2 |
+
|
3 |
+
check_dirs := finetrainers tests examples train.py setup.py
|
4 |
+
|
5 |
+
quality:
|
6 |
+
ruff check $(check_dirs) --exclude examples/_legacy
|
7 |
+
ruff format --check $(check_dirs) --exclude examples/_legacy
|
8 |
+
|
9 |
+
style:
|
10 |
+
ruff check $(check_dirs) --fix --exclude examples/_legacy
|
11 |
+
ruff format $(check_dirs) --exclude examples/_legacy
|
docs/{finetrainers/documentation_README.md → finetrainers-src-codebase/README.md}
RENAMED
@@ -30,10 +30,10 @@ Checkout to the latest stable release tag:
|
|
30 |
|
31 |
```bash
|
32 |
git fetch --all --tags
|
33 |
-
git checkout tags/v0.
|
34 |
```
|
35 |
|
36 |
-
Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.
|
37 |
|
38 |
#### Using the main branch
|
39 |
|
@@ -54,9 +54,10 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
|
|
54 |
|
55 |
## Features
|
56 |
|
57 |
-
- DDP, FSDP-2 & HSDP support
|
58 |
- LoRA and full-rank finetuning; Conditional Control training
|
59 |
- Memory-efficient single-GPU training
|
|
|
60 |
- Auto-detection of commonly used dataset formats
|
61 |
- Combined image/video datasets, multiple chainable local/remote datasets, multi-resolution bucketing & more
|
62 |
- Memory-efficient precomputation support with/without on-the-fly precomputation for large scale datasets
|
@@ -65,6 +66,8 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
|
|
65 |
|
66 |
## News
|
67 |
|
|
|
|
|
68 |
- 🔥 **2025-04-12**: Channel-concatenated control conditioning support added for CogView4 and Wan!
|
69 |
- 🔥 **2025-04-08**: `torch.compile` support added!
|
70 |
- 🔥 **2025-04-06**: Flux support added!
|
|
|
30 |
|
31 |
```bash
|
32 |
git fetch --all --tags
|
33 |
+
git checkout tags/v0.2.0
|
34 |
```
|
35 |
|
36 |
+
Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.2.0-release) for the latest stable release.
|
37 |
|
38 |
#### Using the main branch
|
39 |
|
|
|
54 |
|
55 |
## Features
|
56 |
|
57 |
+
- DDP, FSDP-2 & HSDP, CP support
|
58 |
- LoRA and full-rank finetuning; Conditional Control training
|
59 |
- Memory-efficient single-GPU training
|
60 |
+
- Multiple attention backends supported - `flash`, `flex`, `sage`, `xformers` (see [attention](./docs/models/attention.md) docs)
|
61 |
- Auto-detection of commonly used dataset formats
|
62 |
- Combined image/video datasets, multiple chainable local/remote datasets, multi-resolution bucketing & more
|
63 |
- Memory-efficient precomputation support with/without on-the-fly precomputation for large scale datasets
|
|
|
66 |
|
67 |
## News
|
68 |
|
69 |
+
- 🔥 **2025-04-25**: Support for different attention providers added!
|
70 |
+
- 🔥 **2025-04-21**: Wan I2V supported added!
|
71 |
- 🔥 **2025-04-12**: Channel-concatenated control conditioning support added for CogView4 and Wan!
|
72 |
- 🔥 **2025-04-08**: `torch.compile` support added!
|
73 |
- 🔥 **2025-04-06**: Flux support added!
|
docs/finetrainers-src-codebase/accelerate_configs/compiled_1.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
dynamo_config:
|
6 |
+
dynamo_backend: INDUCTOR
|
7 |
+
dynamo_mode: max-autotune
|
8 |
+
dynamo_use_dynamic: true
|
9 |
+
dynamo_use_fullgraph: false
|
10 |
+
enable_cpu_affinity: false
|
11 |
+
gpu_ids: '3'
|
12 |
+
machine_rank: 0
|
13 |
+
main_training_function: main
|
14 |
+
mixed_precision: bf16
|
15 |
+
num_machines: 1
|
16 |
+
num_processes: 1
|
17 |
+
rdzv_backend: static
|
18 |
+
same_network: true
|
19 |
+
tpu_env: []
|
20 |
+
tpu_use_cluster: false
|
21 |
+
tpu_use_sudo: false
|
22 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/accelerate_configs/deepspeed.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
gradient_accumulation_steps: 1
|
5 |
+
gradient_clipping: 1.0
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: false
|
9 |
+
zero_stage: 2
|
10 |
+
distributed_type: DEEPSPEED
|
11 |
+
downcast_bf16: 'no'
|
12 |
+
enable_cpu_affinity: false
|
13 |
+
machine_rank: 0
|
14 |
+
main_training_function: main
|
15 |
+
mixed_precision: bf16
|
16 |
+
num_machines: 1
|
17 |
+
num_processes: 2
|
18 |
+
rdzv_backend: static
|
19 |
+
same_network: true
|
20 |
+
tpu_env: []
|
21 |
+
tpu_use_cluster: false
|
22 |
+
tpu_use_sudo: false
|
23 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_1.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: '3'
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 1
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_2.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: 0,1
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 2
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_4.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: 0,1,2,3
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 4
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/accelerate_configs/uncompiled_8.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: all
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 8
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
docs/finetrainers-src-codebase/assets/contribute.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributions Welcome
|
2 |
+
|
3 |
+
This project is in a very early stage, and we welcome contributions from everyone. We hope to receive contributions and support in the following areas:
|
4 |
+
|
5 |
+
1. Support for more models. In addition to CogVideoX models, we also highly encourage contributions supporting other models.
|
6 |
+
2. Support for richer datasets. In our example, we used a Disney video generation dataset, but we hope to support more datasets as the current one is too limited for deeper fine-tuning exploration.
|
7 |
+
3. Anything in `TODO` we mention in our README.md
|
8 |
+
|
9 |
+
## How to Submit
|
10 |
+
|
11 |
+
We welcome you to create a new PR and describe the corresponding contribution. We will review it as soon as possible.
|
12 |
+
|
13 |
+
## Naming Conventions
|
14 |
+
|
15 |
+
- Please use English for naming, avoid using pinyin or other languages. All comments should be in English.
|
16 |
+
- Strictly follow PEP8 conventions, and use underscores to separate words. Please avoid using names like a, b, c.
|
docs/finetrainers-src-codebase/assets/contribute_zh.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 欢迎你们的贡献
|
2 |
+
|
3 |
+
本项目属于非常初级的阶段,欢迎大家进行贡献。我们希望在以下方面得到贡献和支持:
|
4 |
+
|
5 |
+
1. 支持更多的模型,除了 CogVideoX 模型之外的模型,我们也非常支持。
|
6 |
+
2. 更丰富的数据集支持。在我们的例子中,我们使用了一个 Disney 视频生成数据集,但是我们希望能够支持更多的数据集,这个数据集太少了,并不足以进行更深的微调探索。
|
7 |
+
3. 任何我们在README中`TODO`提到的内容。
|
8 |
+
|
9 |
+
## 提交方式
|
10 |
+
|
11 |
+
我们欢迎您直接创建一个新的PR,并说明对应的贡献,我们将第一时间查看。
|
12 |
+
|
13 |
+
## 命名规范
|
14 |
+
|
15 |
+
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
|
16 |
+
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
|
docs/finetrainers-src-codebase/assets/dataset_zh.md
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## 数据集格式
|
2 |
+
|
3 |
+
### 提示词数据集要求
|
4 |
+
|
5 |
+
创建 `prompt.txt` 文件,文件应包含逐行分隔的提示。请注意,提示必须是英文,并且建议使用 [提示润色脚本](https://github.com/THUDM/CogVideo/blob/main/inference/convert_demo.py) 进行润色。或者可以使用 [CogVideo-caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) 进行数据标注:
|
6 |
+
|
7 |
+
```
|
8 |
+
A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
|
9 |
+
A black and white animated sequence on a ship’s deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language...
|
10 |
+
...
|
11 |
+
```
|
12 |
+
|
13 |
+
### 视频数据集要求
|
14 |
+
|
15 |
+
该框架支持的分辨率和帧数需要满足以下条件:
|
16 |
+
|
17 |
+
- **支持的分辨率(宽 * 高)**:
|
18 |
+
- 任意分辨率且必须能被32整除。例如,`720 * 480`, `1920 * 1020` 等分辨率。
|
19 |
+
|
20 |
+
- **支持的帧数(Frames)**:
|
21 |
+
- 必须是 `4 * k` 或 `4 * k + 1`(例如:16, 32, 49, 81)
|
22 |
+
|
23 |
+
所有的视频建议放在一个文件夹中。
|
24 |
+
|
25 |
+
|
26 |
+
接着,创建 `videos.txt` 文件。 `videos.txt` 文件应包含逐行分隔的视频文件路径。请注意,路径必须相对于 `--data_root` 目录。格式如下:
|
27 |
+
|
28 |
+
```
|
29 |
+
videos/00000.mp4
|
30 |
+
videos/00001.mp4
|
31 |
+
...
|
32 |
+
```
|
33 |
+
|
34 |
+
对于有兴趣了解更多细节的开发者,您可以查看相关的 `BucketSampler` 代码。
|
35 |
+
|
36 |
+
### 数据集结构
|
37 |
+
|
38 |
+
您的数据集结构应如下所示,通过运行`tree`命令,你能看到:
|
39 |
+
|
40 |
+
```
|
41 |
+
dataset
|
42 |
+
├── prompt.txt
|
43 |
+
├── videos.txt
|
44 |
+
├── videos
|
45 |
+
├── videos/00000.mp4
|
46 |
+
├── videos/00001.mp4
|
47 |
+
├── ...
|
48 |
+
```
|
49 |
+
|
50 |
+
### 使用数据集
|
51 |
+
|
52 |
+
当使用此格式时,`--caption_column` 应为 `prompt.txt`,`--video_column` 应为 `videos.txt`。如果您的数据存储在 CSV
|
53 |
+
文件中,也可以指定 `--dataset_file` 为 CSV 文件的路径,`--caption_column` 和 `--video_column` 为 CSV
|
54 |
+
文件中的实际列名。请参考 [test_dataset](../tests/test_dataset.py) 文件中的一些简单示例。
|
55 |
+
|
56 |
+
例如,使用 [这个](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) Disney 数据集进行微调。下载可通过🤗
|
57 |
+
Hugging Face CLI 完成:
|
58 |
+
|
59 |
+
```
|
60 |
+
huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
|
61 |
+
```
|
62 |
+
|
63 |
+
该数据集已按照预期格式准备好,可直接使用。但是,直接使用视频数据集可能会导致较小 VRAM 的 GPU 出现
|
64 |
+
OOM(内存不足),因为它需要加载 [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae)
|
65 |
+
(将视频编码为潜在空间)和大型 [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/)
|
66 |
+
|
67 |
+
文本编码器。为了降低内存需求,您可以使用 `training/prepare_dataset.py` 脚本预先计算潜在变量和嵌入。
|
68 |
+
|
69 |
+
填写或修改 `prepare_dataset.sh` 中的参数并执行它以获得预先计算的潜在变量和嵌入(请确保指定 `--save_latents_and_embeddings`
|
70 |
+
以保存预计算的工件)。如果准备图像到视频的训练,请确保传递 `--save_image_latents`,它对沙子进行编码,将图像潜在值与视频一起保存。
|
71 |
+
在训练期间使用这些工件时,确保指定 `--load_tensors` 标志,否则将直接使用视频并需要加载文本编码器和
|
72 |
+
VAE。该脚本还支持 PyTorch DDP,以便可以使用多个 GPU 并行编码大型数据集(修改 `NUM_GPUS` 参数)。
|
docs/finetrainers-src-codebase/assets/sft_2b.png
ADDED
![]() |
docs/finetrainers-src-codebase/assets/sft_5b.png
ADDED
![]() |
docs/finetrainers-src-codebase/assets/tests/metadata.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
video,caption
|
2 |
+
"videos/hiker.mp4","""A hiker standing at the top of a mountain, triumphantly, high quality"""
|
docs/finetrainers-src-codebase/docs/_NOTES_FOR_FUTURE_ME.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Notes for Future Me
|
2 |
+
|
3 |
+
>![NOTE]
|
4 |
+
> This doc page is intended for developers and contributors.
|
5 |
+
|
6 |
+
FSDP dump:
|
7 |
+
- https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
|
8 |
+
- https://github.com/pytorch/pytorch/issues/114299
|
9 |
+
- Using FSDP1 requires that all FSDP flat parameters are of the same dtype. For LoRA training, we default lora parameters to fp32 and transformer parameters to dtype chosen by user. There seems to be no easy workaround than performing lora training in same dtype.
|
10 |
+
- https://github.com/pytorch/pytorch/issues/100945
|
11 |
+
- https://github.com/pytorch/torchtune/blob/9b3836028fd0b48f593ea43474b86880c49a4d74/recipes/lora_finetune_distributed.py
|
12 |
+
- https://github.com/KellerJordan/modded-nanogpt/pull/68
|
13 |
+
- https://github.com/pytorch/pytorch/pull/125394: monkey-patch method for FSDP pre/post-hooks to be triggered for method other than `forward`
|
14 |
+
- https://github.com/pytorch/pytorch/pull/127786:
|
15 |
+
- https://github.com/pytorch/pytorch/pull/130949:
|
16 |
+
- Sanity saver: create optimizers after parallelizing/activation-checkpointing models
|
17 |
+
|
18 |
+
DTensor:
|
19 |
+
- https://github.com/pytorch/pytorch/issues/88838
|
20 |
+
- https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/parallel/test_parallelize_api.py
|
docs/{finetrainers/documentation_args.md → finetrainers-src-codebase/docs/args.md}
RENAMED
@@ -75,7 +75,10 @@ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embe
|
|
75 |
naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
|
76 |
by default, and recommend adding more layers to the default list based on the model architecture.
|
77 |
compile_modules (`List[str]`, defaults to `[]`):
|
78 |
-
Modules that should be regionally compiled with `torch.compile`.
|
|
|
|
|
|
|
79 |
|
80 |
DATASET ARGUMENTS
|
81 |
-----------------
|
@@ -109,6 +112,9 @@ dataset_config (`str`):
|
|
109 |
dataset_shuffle_buffer_size (`int`, defaults to `1`):
|
110 |
The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
|
111 |
value of `1` means that the dataset will not be shuffled.
|
|
|
|
|
|
|
112 |
precomputation_items (`int`, defaults to `512`):
|
113 |
Number of data samples to precompute at once for memory-efficient training. The higher this value,
|
114 |
the more disk memory will be used to save the precomputed samples (conditions and latents).
|
@@ -118,8 +124,16 @@ precomputation_dir (`str`, defaults to `None`):
|
|
118 |
precomputation_once (`bool`, defaults to `False`):
|
119 |
Precompute embeddings from all datasets at once before training. This is useful to save time during training
|
120 |
with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during
|
121 |
-
training when required
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
DATALOADER_ARGUMENTS
|
125 |
--------------------
|
@@ -248,8 +262,6 @@ logging_dir (`str`, defaults to `logs`):
|
|
248 |
The directory where the logs will be stored.
|
249 |
logging_steps (`int`, defaults to `1`):
|
250 |
Training logs will be tracked every `logging_steps` steps.
|
251 |
-
allow_tf32 (`bool`, defaults to `False`):
|
252 |
-
Whether or not to allow the use of TF32 matmul on compatible hardware.
|
253 |
nccl_timeout (`int`, defaults to `1800`):
|
254 |
Timeout for the NCCL communication.
|
255 |
report_to (`str`, defaults to `wandb`):
|
@@ -260,6 +272,33 @@ verbose (`int`, defaults to `1`):
|
|
260 |
- 1: Diffusers/Transformers info logging on local main process only
|
261 |
- 2: Diffusers/Transformers debug logging on local main process only
|
262 |
- 3: Diffusers/Transformers debug logging on all processes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
```
|
264 |
|
265 |
## SFT training
|
|
|
75 |
naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
|
76 |
by default, and recommend adding more layers to the default list based on the model architecture.
|
77 |
compile_modules (`List[str]`, defaults to `[]`):
|
78 |
+
Modules that should be regionally compiled with `torch.compile`.
|
79 |
+
compile_scopes (`str`, defaults to `None`):
|
80 |
+
The scope of compilation for each `--compile_modules`. Choose between ['regional', 'full']. Must have the same length as
|
81 |
+
`--compile_modules`. If `None`, will default to `regional` for all modules.
|
82 |
|
83 |
DATASET ARGUMENTS
|
84 |
-----------------
|
|
|
112 |
dataset_shuffle_buffer_size (`int`, defaults to `1`):
|
113 |
The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
|
114 |
value of `1` means that the dataset will not be shuffled.
|
115 |
+
enable_precomputation (`bool`, defaults to `False`):
|
116 |
+
Whether or not to precompute the embeddings for the dataset. This is useful for faster training. If set to `True`,
|
117 |
+
the embeddings will be precomputed and saved to disk and loaded as required.
|
118 |
precomputation_items (`int`, defaults to `512`):
|
119 |
Number of data samples to precompute at once for memory-efficient training. The higher this value,
|
120 |
the more disk memory will be used to save the precomputed samples (conditions and latents).
|
|
|
124 |
precomputation_once (`bool`, defaults to `False`):
|
125 |
Precompute embeddings from all datasets at once before training. This is useful to save time during training
|
126 |
with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during
|
127 |
+
training when required (that is, computing embeddings of more data samples once `precomputation_items` of them
|
128 |
+
have been exhausted across all distributed ranks). Make sure to set `precomputation_items` to a reasonable value
|
129 |
+
in line with the size of your dataset(s).
|
130 |
+
precomputation_reuse (`bool`, defaults to `False`):
|
131 |
+
Reuse precomputed embeddings from previous training runs. This is useful to save time during training
|
132 |
+
with medium/large datasets. By default, old precomputed embeddings that exist in the specified precomputation
|
133 |
+
directory, or default precomputation dir `{output_dir}/precomputed` will be deleted if this is not set to `True`.
|
134 |
+
This flag is ignored if `enable_precomputation` is `False`. The topology of the distributed training run must be
|
135 |
+
the same as the one used to precompute the embeddings for this to work correctly (this limitation will be
|
136 |
+
addressed in the future).
|
137 |
|
138 |
DATALOADER_ARGUMENTS
|
139 |
--------------------
|
|
|
262 |
The directory where the logs will be stored.
|
263 |
logging_steps (`int`, defaults to `1`):
|
264 |
Training logs will be tracked every `logging_steps` steps.
|
|
|
|
|
265 |
nccl_timeout (`int`, defaults to `1800`):
|
266 |
Timeout for the NCCL communication.
|
267 |
report_to (`str`, defaults to `wandb`):
|
|
|
272 |
- 1: Diffusers/Transformers info logging on local main process only
|
273 |
- 2: Diffusers/Transformers debug logging on local main process only
|
274 |
- 3: Diffusers/Transformers debug logging on all processes
|
275 |
+
|
276 |
+
TORCH CONFIG ARGUMENTS
|
277 |
+
----------------------
|
278 |
+
allow_tf32 (`bool`, defaults to `False`):
|
279 |
+
Whether or not to allow the use of TF32 matmul on compatible hardware.
|
280 |
+
float32_matmul_precision (`str`, defaults to `highest`):
|
281 |
+
The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].
|
282 |
+
```
|
283 |
+
|
284 |
+
### Attention Provider
|
285 |
+
|
286 |
+
These arguments are relevant to setting attention provider for different modeling components. The attention providers may be set differently for training and validation/inference.
|
287 |
+
|
288 |
+
```
|
289 |
+
attn_provider_training (`str`, defaults to "native"):
|
290 |
+
The attention provider to use for training. Choose between
|
291 |
+
[
|
292 |
+
'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
|
293 |
+
'_native_math'
|
294 |
+
]
|
295 |
+
attn_provider_inference (`str`, defaults to "native"):
|
296 |
+
The attention provider to use for validation. Choose between
|
297 |
+
[
|
298 |
+
'flash', 'flash_varlen', 'flex', 'native', '_native_cudnn', '_native_efficient', '_native_flash',
|
299 |
+
'_native_math', 'sage', 'sage_varlen', '_sage_qk_int8_pv_fp8_cuda', '_sage_qk_int8_pv_fp8_cuda_sm90',
|
300 |
+
'_sage_qk_int8_pv_fp16_cuda', '_sage_qk_int8_pv_fp16_triton', 'xformers'
|
301 |
+
]
|
302 |
```
|
303 |
|
304 |
## SFT training
|
docs/{finetrainers/documentation_dataset_README.md → finetrainers-src-codebase/docs/dataset/README.md}
RENAMED
@@ -57,6 +57,8 @@ dataset
|
|
57 |
|
58 |
#### CSV/JSON/JSONL format
|
59 |
|
|
|
|
|
60 |
> [!NOTE]
|
61 |
> Relevant classes to look for implementation:
|
62 |
> - ImageFolderDataset
|
@@ -75,6 +77,8 @@ Any dataset loadable via the [🤗 HF datasets] directly should work (not widely
|
|
75 |
|
76 |
Any dataset loadable via the [🤗 HF datasets] directly should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.
|
77 |
|
|
|
|
|
78 |
## Validation Dataset Format
|
79 |
|
80 |
Arguments related to validation are:
|
@@ -148,18 +152,21 @@ For memory efficient training, it is important to precompute conditional and lat
|
|
148 |
|
149 |
The following is a high-level overview of how datasets are loaded and preprocessed:
|
150 |
|
151 |
-
- Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed).
|
|
|
152 |
- The dataset is split across data replicas (GPUs groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
|
153 |
-
- If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way
|
154 |
- The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1.
|
155 |
-
- The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--
|
156 |
- When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually]
|
157 |
|
158 |
## Understanding how datasets are precomputed
|
159 |
|
160 |
-
There are
|
|
|
161 |
- `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion.
|
162 |
- `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in the `--output_dir/precomputed`.
|
163 |
- `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard coded logic (can be removed manually by users for now).
|
|
|
164 |
|
165 |
Batching is not yet supported for precomputation. This will be added in the future.
|
|
|
57 |
|
58 |
#### CSV/JSON/JSONL format
|
59 |
|
60 |
+
- Supported names are: `metadata.json`, `metadata.jsonl`, `metadata.csv`
|
61 |
+
|
62 |
> [!NOTE]
|
63 |
> Relevant classes to look for implementation:
|
64 |
> - ImageFolderDataset
|
|
|
77 |
|
78 |
Any dataset loadable via the [🤗 HF datasets] directly should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.
|
79 |
|
80 |
+
|
81 |
+
|
82 |
## Validation Dataset Format
|
83 |
|
84 |
Arguments related to validation are:
|
|
|
152 |
|
153 |
The following is a high-level overview of how datasets are loaded and preprocessed:
|
154 |
|
155 |
+
- Initially, the dataset is lazy loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end conditions (e.g. user-configured training steps is completed). Multiple datasets can be chained together. For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and create a new copy of the dataset containing multiresolution data. Finetrainers makes this easier by allowing you to specify multiple different, or same, datasets with different resolutions.
|
156 |
+
- When chaining multiple different datasets, make sure they are roughly the same size to avoid having smaller datasets repeatedly being used in the training loop. This is because the datasets are loaded in a round-robin fashion. For example, if you have 2 datasets of size 1000 and 2000, the first dataset will be fully seen twice before the second dataset is fully seen once by the model.
|
157 |
- The dataset is split across data replicas (GPUs groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
|
158 |
+
- If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, there is a significant overhead the higher this value is set to, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way].
|
159 |
- The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1.
|
160 |
+
- The dataset is precomputed for embeddings and stored to disk. This is done in batches of user-configured `--precomputation_items` to avoid exhausting disk space. The smaller this value, the more number of times conditioning models will be loaded upon precomputation exhaustion. The larger this value, the more disk space will be used.
|
161 |
- When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually]
|
162 |
|
163 |
## Understanding how datasets are precomputed
|
164 |
|
165 |
+
There are 4 arguments related to precomputation:
|
166 |
+
- `--enable_precomputation`: If set, precomputation will be enabled. The parameters that follow are only relevant if this flag is set. If this flag is not set, all models will be loaded in memory and training will take place without first precomputing embeddings.
|
167 |
- `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion.
|
168 |
- `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in the `--output_dir/precomputed`.
|
169 |
- `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard coded logic (can be removed manually by users for now).
|
170 |
+
- `--precomputation_reuse`: If you're working with medium/large-size datasets and want to precompute all embeddings and re-use them across different training runs, make sure to set this flag.
|
171 |
|
172 |
Batching is not yet supported for precomputation. This will be added in the future.
|
docs/finetrainers-src-codebase/docs/dataset/_DEBUG.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Distributed dataset debugging
|
2 |
+
|
3 |
+
>![NOTE]
|
4 |
+
> This doc page is intended for developers and contributors.
|
5 |
+
|
6 |
+
If the number of samples in the dataset is lower than the number of processes per node, the training will hand indefinitely. I haven't been able to pin down on how this could be fixed due to limited time, but basically:
|
7 |
+
- Start training with `--dp_degree 2` and `torchrun --standalone --nnodes=1 --nproc_per_node=2`. This launches training with DDP across 2 ranks.
|
8 |
+
- The dataset has `< dp_degree` samples
|
9 |
+
- When `datasets.distributed.split_dataset_by_node` is called, the data is distributed correctly to one rank, but the other rank hangs indefinitely. Due to this edge case, fast tests seem to fail.
|
10 |
+
- For now, we should just use `>= dp_degree` samples in the test dataset. However, should be fixed in the future.
|
11 |
+
|
12 |
+
Minimal reproducer:
|
13 |
+
|
14 |
+
```python
|
15 |
+
import torch
|
16 |
+
import torch.distributed as dist
|
17 |
+
from datasets import Dataset
|
18 |
+
from datasets.distributed import split_dataset_by_node
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
|
21 |
+
ds = Dataset.from_dict({"x": [1]}).to_iterable_dataset()
|
22 |
+
|
23 |
+
dist.init_process_group()
|
24 |
+
rank, world_size = dist.get_rank(), dist.get_world_size()
|
25 |
+
ds = split_dataset_by_node(ds, rank=rank,world_size=world_size)
|
26 |
+
dl = DataLoader(ds)
|
27 |
+
|
28 |
+
exhausted = torch.zeros(world_size, dtype=torch.bool)
|
29 |
+
|
30 |
+
def loop():
|
31 |
+
while True:
|
32 |
+
print(rank, "hello", flush=True)
|
33 |
+
yield from dl
|
34 |
+
yield "end"
|
35 |
+
|
36 |
+
for x in loop():
|
37 |
+
if x == "end":
|
38 |
+
exhausted[rank] = True
|
39 |
+
continue
|
40 |
+
dist.all_reduce(exhausted)
|
41 |
+
if torch.all(exhausted):
|
42 |
+
break
|
43 |
+
print(f"{rank} {x}", flush=True)
|
44 |
+
```
|
docs/{finetrainers/documentation_environment.md → finetrainers-src-codebase/docs/environment.md}
RENAMED
@@ -26,3 +26,14 @@ NVIDIA A100-SXM4-80GB, 81920 MiB
|
|
26 |
```
|
27 |
|
28 |
Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
```
|
27 |
|
28 |
Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at the early stages of development, we are unable to do so. The long term goals include compatibility with most pytorch versions on CUDA, MPS, ROCm and XLA devices.
|
29 |
+
|
30 |
+
> [!IMPORTANT]
|
31 |
+
>
|
32 |
+
> For context parallelism, PyTorch 2.6+ is required.
|
33 |
+
|
34 |
+
## Configuration
|
35 |
+
|
36 |
+
The following environment variables may be configured to change the default behaviour of finetrainers:
|
37 |
+
|
38 |
+
`FINETRAINERS_ATTN_PROVIDER`: Sets the default attention provider for training/validation. Defaults to `native`, as in native PyTorch SDPA. See [attention docs](./models/attention.md) for more information.
|
39 |
+
`FINETRAINERS_ATTN_CHECKS`: Whether or not to run basic sanity checks when using different attention providers. This is useful for debugging but you should leave it disabled for longer training runs. Defaults to `"0"`. Can be set to a truthy env value.
|
docs/{finetrainers/documentation_models_README.md → finetrainers-src-codebase/docs/models/README.md}
RENAMED
File without changes
|
docs/finetrainers-src-codebase/docs/models/attention.md
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Attention backends
|
2 |
+
|
3 |
+
Finetrainers supports multiple attention backends to support different hardware and tradeoff between speed and memory usage. The following attention implementations are supported:
|
4 |
+
- Training:
|
5 |
+
- If model uses attention masks: `flash_varlen`, `flex`, `native`
|
6 |
+
- If model does not use attention masks: `flash`, `flex`, `native`, `xformers`
|
7 |
+
- Inference:
|
8 |
+
- If model uses attention masks: `flash_varlen`, `flex`, `native`, `sage_varlen`
|
9 |
+
- If model does not use attention masks: `flash`, `flash_varlen`, `flex`, `native`, `sage`, `sage_varlen`, `xformers`
|
10 |
+
|
11 |
+
Additionally, some specialized methods are available for debugging-specific purposes: `_native_cudnn`, `_native_efficient`, `_native_flash`, `_native_math`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`. With time, more attention-specific optimizations and custom implementations will be supported. Contributions are welcome!
|
12 |
+
|
13 |
+
Unfortunately, due to limited time for testing, only specific versions of packages that provide these implementations are supported. Other versions may work. The supported versions will be gradually made lower for more flexibility, but for now, please use the following versions:
|
14 |
+
- `flash-attn>=2.6.3`
|
15 |
+
- `sageattention>=2.1.1`
|
16 |
+
- `xformers>=0.0.29.post3`
|
17 |
+
|
18 |
+
This guide will help you quickly install flash-attn, sageattention, and xformers to make your models run faster and use less memory for training/inference. We'll cover installation on Linux (Ubuntu 22.04) and Windows (using WSL).
|
19 |
+
|
20 |
+
Before you start, make sure to use a clean python virtual environment to not mess up your system seriously, or to avoid conflicting dependencies leading to failed installations which might leave the environment in hard-to-recover state.
|
21 |
+
|
22 |
+
### Flash attention
|
23 |
+
|
24 |
+
Providers covered: `flash`, `flash_varlen`
|
25 |
+
|
26 |
+
The installation steps have only been tested with Ubuntu 22.04; CUDA version higher than 12.2 and 12.6.
|
27 |
+
- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
|
28 |
+
- You might need the following packages: `pip install packaging ninja`
|
29 |
+
- Linux: Run: `pip install flash-attn --no-build-isolation`. Verify the version with `pip show flash-attn`
|
30 |
+
- WSL: Same instruction as above should work. Native Windows might require building from source - check community guiders and follow the instruction [here](https://github.com/Dao-AILab/flash-attention).
|
31 |
+
|
32 |
+
### Sage attention
|
33 |
+
|
34 |
+
Providers covered: `sage`, `sage_varlen`, `_sage_qk_int8_pv_fp8_cuda`, `_sage_qk_int8_pv_fp8_cuda_sm90`, `_sage_qk_int8_pv_fp16_cuda`, `_sage_qk_int8_pv_fp16_triton`
|
35 |
+
|
36 |
+
FP8 implementations will require CUDA compute capability of 90 or higher (H100, RTX 5090, etc.). Some may work on compute capability 89 as well (RTX 4090, for example). For FP16 implementations, compute capability of atleast 80 is required (A100, RTX 3090, etc.). For other GPUs, FP16 implementations may or may not work (this is untested by me).
|
37 |
+
|
38 |
+
- Check your compute capability with the following command:
|
39 |
+
```bash
|
40 |
+
python -c "import torch; print(torch.cuda.get_device_capability())"
|
41 |
+
```
|
42 |
+
- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
|
43 |
+
- You might need the following packages: `pip install triton`. For Windows, check out the [triton-windows](https://github.com/woct0rdho/triton-windows) project.
|
44 |
+
- Linux/WSL: Run: `pip install git+https://github.com/thu-ml/SageAttention`. Verify the version with `pip show sageattention`.
|
45 |
+
- Make sure to look at the official installation guide in [SageAttention](https://github.com/thu-ml/SageAttention) too!
|
46 |
+
|
47 |
+
### xformers
|
48 |
+
|
49 |
+
Providers covered: `xformers`
|
50 |
+
|
51 |
+
- Check your CUDA version: look at the output of `nvidia-smi` or run `nvcc --version`.
|
52 |
+
- Linux/WSL: Run: `pip install -U xformers --index-url https://download.pytorch.org/whl/cu126` (assuming CUDA 12.6). Verify the version with `pip show xformers`.
|
53 |
+
- Make sure to look at the official installation guide in [xformers](https://github.com/facebookresearch/xformers) too!
|
54 |
+
|
55 |
+
----------
|
56 |
+
|
57 |
+
All other providers are either native PyTorch implementations or require a specific PyTorch version (for example, Flex Attention requires torch version of atleast 2.5.0).
|
58 |
+
|
59 |
+
----------
|
60 |
+
|
61 |
+
## Usage
|
62 |
+
|
63 |
+
There are two ways to use the attention dispatcher mechanism:
|
64 |
+
- Replace `scaled_dot_product_attention` globally:
|
65 |
+
```python
|
66 |
+
import torch.nn.functional as F
|
67 |
+
from finetrainers.models.attention_dispatch import attention_dispatch
|
68 |
+
|
69 |
+
F.scaled_dot_product_attention = attention_dispatch
|
70 |
+
```
|
71 |
+
- Replace all occurrences of `scaled_dot_product_attention` in your code with `attention_dispatch`.
|
72 |
+
|
73 |
+
```python
|
74 |
+
# Use dispatcher directly
|
75 |
+
from finetrainers.models.attention_dispatch import attention_provider, AttentionProvider
|
76 |
+
|
77 |
+
with attention_provider(AttentionProvider.FLASH_VARLEN):
|
78 |
+
model(...)
|
79 |
+
|
80 |
+
# or,
|
81 |
+
with attention_provider("sage_varlen"):
|
82 |
+
model(...)
|
83 |
+
```
|
84 |
+
|
85 |
+
## Context Parallel
|
86 |
+
|
87 |
+
References and reading material:
|
88 |
+
- https://docs.pytorch.org/tutorials/prototype/context_parallel.html
|
89 |
+
- https://insujang.github.io/2024-09-20/introducing-context-parallelism/
|
90 |
+
- https://www.youtube.com/watch?v=ws7angQYIxI
|
91 |
+
- https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
|
92 |
+
- https://arxiv.org/abs/2309.14509
|
93 |
+
|
94 |
+
There are three steps to enabling context parallelism with any model:
|
95 |
+
- Defining the context parallel plan: This is a dictionary that mentions what tensors to split and gather across CP region at different layers in the model
|
96 |
+
- Applying the CP plan with `apply_context_parallel` function: This registers the necessary hooks to split and gather tensors at the right places in the model without having to manually modify the model code.
|
97 |
+
- Running model under the `attention_provider` context manager
|
98 |
+
|
99 |
+
For a quick example, refer to the [inference example](#inference) below.
|
100 |
+
|
101 |
+
The CP plan is a dictionary that maps the name of the module to a list of `CPInput` or `CPOutput` objects. The keys in the dictionary are the names of the internal modules in the model, and the values are dictionaries that map a parameter identifier (either as an argument index or keyword argument as used in the forward method) to a `CPInput` or `CPOutput` object. The `CPInput` object specifies the input tensor to be split, and the `CPOutput` object specifies the output tensor to be gathered.
|
102 |
+
|
103 |
+
```python
|
104 |
+
class ParamId:
|
105 |
+
name: Optional[str] = None
|
106 |
+
index: Optional[int] = None
|
107 |
+
|
108 |
+
class CPInput:
|
109 |
+
split_dim: int
|
110 |
+
expected_dims: Optional[int] = None
|
111 |
+
split_output: bool = False
|
112 |
+
|
113 |
+
class CPOutput:
|
114 |
+
gather_dim: int
|
115 |
+
expected_dims: Optional[int] = None
|
116 |
+
```
|
117 |
+
|
118 |
+
- The `split_dim` and `gather_dim` parameters specify the dimension along which to split or gather the tensor. When using CP with native scaled dot product attention from pytorch, the tensor shape is `[B, N, S, D]`, so the `split_dim` and `gather_dim` parameters should be set to `2` as it is the sequence dimension.
|
119 |
+
- The `expected_dims` parameter is an optional parameter that is used for sanity checking if the tensor contains the expected number of dimensions.
|
120 |
+
- By default, `CPInput`'s are split in a pre-forward hook and `CPOutput`'s are gathered in a post-forward hook. If you want to split the output of a module, you can set the `split_output` parameter to `True`. This will split the output tensor in the post-forward hook instead of the pre-forward hook.
|
121 |
+
|
122 |
+
- Attention providers supported for training with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`
|
123 |
+
- Attention providers supported for inference with CP: `flash`, `_native_cudnn`, `_native_efficient`, `_native_flash`
|
124 |
+
|
125 |
+
### Training
|
126 |
+
|
127 |
+
To enable training with context parallelism, you need to make sure a suitable CP plan is registered for the model you are using and launch training with `--cp_degree 2`. For models supported in finetrainers, this is internally done in the [transformer metadata](https://github.com/a-r-r-o-w/finetrainers/tree/main/finetrainers/models/_metadata/transformer.py) file. For custom models, make sure to pass the `plan` argument to the `apply_context_parallel` function.
|
128 |
+
|
129 |
+
Currently supported models include: CogVideoX, CogView4, Flux, Wan 2.1. Support for more models and attention providers is in progress.
|
130 |
+
|
131 |
+
### Inference
|
132 |
+
|
133 |
+
The following example shows how to run context parallel inference. For more examples and ready-to-use inference scripts, check out the [examples/inference](https://github.com/a-r-r-o-w/finetrainers/tree/main/examples/inference/) folder.
|
134 |
+
|
135 |
+
<details>
|
136 |
+
<summary> Example </summary>
|
137 |
+
|
138 |
+
```python
|
139 |
+
import torch
|
140 |
+
import torch.distributed as dist
|
141 |
+
from diffusers import AutoencoderKLWan, WanPipeline
|
142 |
+
from diffusers.utils import export_to_video
|
143 |
+
|
144 |
+
from finetrainers._metadata import ParamId, CPInput, CPOutput
|
145 |
+
from finetrainers.parallel.ptd import apply_context_parallel
|
146 |
+
from finetrainers.models.attention_dispatch import attention_provider, attention_dispatch
|
147 |
+
|
148 |
+
torch.nn.functional.scaled_dot_product_attention = attention_dispatch
|
149 |
+
|
150 |
+
|
151 |
+
def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
|
152 |
+
r"""Apply torch.compile to a model or its submodules if not already compiled."""
|
153 |
+
if getattr(model, "_torch_compiled", False):
|
154 |
+
return model # Already compiled
|
155 |
+
|
156 |
+
if compile_scope == "full":
|
157 |
+
model = torch.compile(model)
|
158 |
+
setattr(model, "_torch_compiled", True)
|
159 |
+
elif compile_scope == "regional":
|
160 |
+
if isinstance(model, torch.nn.ModuleList):
|
161 |
+
for name, module in model.named_children():
|
162 |
+
if not getattr(module, "_torch_compiled", False):
|
163 |
+
compiled_module = torch.compile(module, mode="max-autotune-no-cudagraphs", fullgraph=False, dynamic=False)
|
164 |
+
setattr(compiled_module, "_torch_compiled", True)
|
165 |
+
model.register_module(name, compiled_module)
|
166 |
+
else:
|
167 |
+
for name, module in model.named_children():
|
168 |
+
apply_compile(module, compile_scope)
|
169 |
+
else:
|
170 |
+
raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")
|
171 |
+
|
172 |
+
return model
|
173 |
+
|
174 |
+
|
175 |
+
torch.manual_seed(0)
|
176 |
+
dist.init_process_group("nccl")
|
177 |
+
rank, world_size = dist.get_rank(), dist.get_world_size()
|
178 |
+
torch.cuda.set_device(rank)
|
179 |
+
cp_mesh = dist.device_mesh.init_device_mesh("cuda", [world_size], mesh_dim_names=["cp"])
|
180 |
+
|
181 |
+
cp_plan = {
|
182 |
+
"rope": {
|
183 |
+
ParamId(index=0): CPInput(2, 4, split_output=True),
|
184 |
+
},
|
185 |
+
"blocks.*": {
|
186 |
+
ParamId("encoder_hidden_states", 1): CPInput(1, 3),
|
187 |
+
},
|
188 |
+
"blocks.0": {
|
189 |
+
ParamId("hidden_states", 0): CPInput(1, 3),
|
190 |
+
},
|
191 |
+
"proj_out": [CPOutput(1, 3)],
|
192 |
+
}
|
193 |
+
|
194 |
+
try:
|
195 |
+
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
196 |
+
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
197 |
+
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
198 |
+
pipe.to("cuda")
|
199 |
+
|
200 |
+
apply_context_parallel(pipe.transformer, mesh=cp_mesh, plan=cp_plan)
|
201 |
+
|
202 |
+
apply_compile(pipe.transformer, compile_scope="regional")
|
203 |
+
|
204 |
+
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
205 |
+
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
|
209 |
+
prompt=prompt, negative_prompt=negative_prompt, device="cuda",
|
210 |
+
)
|
211 |
+
|
212 |
+
attention_backend = "_native_flash"
|
213 |
+
generator = torch.Generator().manual_seed(0)
|
214 |
+
|
215 |
+
# Warmup for compilation
|
216 |
+
with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="alltoall"):
|
217 |
+
latents = pipe(
|
218 |
+
prompt_embeds=prompt_embeds,
|
219 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
220 |
+
height=480,
|
221 |
+
width=832,
|
222 |
+
num_frames=81,
|
223 |
+
num_inference_steps=2,
|
224 |
+
guidance_scale=5.0,
|
225 |
+
output_type="latent",
|
226 |
+
generator=generator,
|
227 |
+
).frames[0]
|
228 |
+
|
229 |
+
# Inference
|
230 |
+
with attention_provider(attention_backend, mesh=cp_mesh, convert_to_fp32=True, rotate_method="allgather"):
|
231 |
+
latents = pipe(
|
232 |
+
prompt_embeds=prompt_embeds,
|
233 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
234 |
+
height=480,
|
235 |
+
width=832,
|
236 |
+
num_frames=81,
|
237 |
+
guidance_scale=5.0,
|
238 |
+
num_inference_steps=30,
|
239 |
+
output_type="latent",
|
240 |
+
generator=generator,
|
241 |
+
).frames[0]
|
242 |
+
|
243 |
+
with torch.no_grad():
|
244 |
+
latents = latents.to(pipe.vae.dtype)
|
245 |
+
latents_mean = (
|
246 |
+
torch.tensor(pipe.vae.config.latents_mean)
|
247 |
+
.view(1, pipe.vae.config.z_dim, 1, 1, 1)
|
248 |
+
.to(latents.device, latents.dtype)
|
249 |
+
)
|
250 |
+
latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
|
251 |
+
latents.device, latents.dtype
|
252 |
+
)
|
253 |
+
latents = latents / latents_std + latents_mean
|
254 |
+
video = pipe.vae.decode(latents, return_dict=False)[0]
|
255 |
+
video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
|
256 |
+
|
257 |
+
if rank == 0:
|
258 |
+
export_to_video(video, "output.mp4", fps=16)
|
259 |
+
finally:
|
260 |
+
dist.destroy_process_group()
|
261 |
+
```
|
262 |
+
|
263 |
+
</details>
|
docs/{finetrainers/documentation_models_cogvideox.md → finetrainers-src-codebase/docs/models/cogvideox.md}
RENAMED
@@ -20,9 +20,9 @@ On Windows, you will have to modify the script to a compatible format to run it.
|
|
20 |
|
21 |
CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working:
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
## Inference
|
28 |
|
@@ -45,6 +45,6 @@ export_to_video(video, "output.mp4")
|
|
45 |
|
46 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
20 |
|
21 |
CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working:
|
22 |
|
23 |
+
- [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b)
|
24 |
+
- [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B)
|
25 |
+
- [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B)
|
26 |
|
27 |
## Inference
|
28 |
|
|
|
45 |
|
46 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
47 |
|
48 |
+
- [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox)
|
49 |
+
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
50 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/finetrainers-src-codebase/docs/models/cogview4.md
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CogView4
|
2 |
+
|
3 |
+
## Training
|
4 |
+
|
5 |
+
For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
|
6 |
+
|
7 |
+
Examples available:
|
8 |
+
- [Raider White Tarot cards style](../../examples/training/sft/cogview4/raider_white_tarot/)
|
9 |
+
- [Omni Edit Control LoRA](../../examples/training/control/cogview4/omni_edit/)
|
10 |
+
- [Canny Control LoRA](../../examples/training/control/cogview4/canny/)
|
11 |
+
|
12 |
+
To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
|
13 |
+
|
14 |
+
```bash
|
15 |
+
chmod +x ./examples/training/sft/cogview4/raider_white_tarot/train.sh
|
16 |
+
./examples/training/sft/cogview4/raider_white_tarot/train.sh
|
17 |
+
```
|
18 |
+
|
19 |
+
On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
|
20 |
+
|
21 |
+
## Supported checkpoints
|
22 |
+
|
23 |
+
The following checkpoints were tested with `finetrainers` and are known to be working:
|
24 |
+
|
25 |
+
- [THUDM/CogView4-6B](https://huggingface.co/THUDM/CogView4-6B)
|
26 |
+
|
27 |
+
## Inference
|
28 |
+
|
29 |
+
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
|
30 |
+
|
31 |
+
```diff
|
32 |
+
import torch
|
33 |
+
from diffusers import CogView4Pipeline
|
34 |
+
from diffusers.utils import export_to_video
|
35 |
+
|
36 |
+
pipe = CogView4Pipeline.from_pretrained(
|
37 |
+
"THUDM/CogView4-6B", torch_dtype=torch.bfloat16
|
38 |
+
).to("cuda")
|
39 |
+
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogview4-lora")
|
40 |
+
+ pipe.set_adapters(["cogview4-lora"], [0.9])
|
41 |
+
|
42 |
+
video = pipe("<my-awesome-prompt>").frames[0]
|
43 |
+
export_to_video(video, "output.mp4")
|
44 |
+
```
|
45 |
+
|
46 |
+
To use trained Control LoRAs, the following can be used for inference (ideally, you should raise a support request in Diffusers):
|
47 |
+
|
48 |
+
<details>
|
49 |
+
<summary> Control Lora inference </summary>
|
50 |
+
|
51 |
+
```python
|
52 |
+
import torch
|
53 |
+
from diffusers import CogView4Pipeline
|
54 |
+
from diffusers.utils import load_image
|
55 |
+
from finetrainers.models.utils import _expand_linear_with_zeroed_weights
|
56 |
+
from finetrainers.patches import load_lora_weights
|
57 |
+
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
|
58 |
+
|
59 |
+
dtype = torch.bfloat16
|
60 |
+
device = torch.device("cuda")
|
61 |
+
generator = torch.Generator().manual_seed(0)
|
62 |
+
|
63 |
+
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype)
|
64 |
+
|
65 |
+
in_channels = pipe.transformer.config.in_channels
|
66 |
+
patch_channels = pipe.transformer.patch_embed.proj.in_features
|
67 |
+
pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels)
|
68 |
+
|
69 |
+
load_lora_weights(pipe, "/raid/aryan/cogview4-control-lora", "cogview4-lora")
|
70 |
+
pipe.to(device)
|
71 |
+
|
72 |
+
prompt = "Make the image look like it's from an ancient Egyptian mural."
|
73 |
+
control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png")
|
74 |
+
height, width = 1024, 1024
|
75 |
+
|
76 |
+
with torch.no_grad():
|
77 |
+
latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator)
|
78 |
+
control_image = pipe.image_processor.preprocess(control_image, height=height, width=width)
|
79 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
80 |
+
control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator)
|
81 |
+
control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
|
82 |
+
|
83 |
+
with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
|
84 |
+
image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]
|
85 |
+
|
86 |
+
image.save("output.png")
|
87 |
+
```
|
88 |
+
</details>
|
89 |
+
|
90 |
+
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
91 |
+
|
92 |
+
- [CogView4 in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4)
|
93 |
+
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
94 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/finetrainers-src-codebase/docs/models/flux.md
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flux
|
2 |
+
|
3 |
+
## Training
|
4 |
+
|
5 |
+
For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
|
6 |
+
|
7 |
+
Examples available:
|
8 |
+
- [Raider White Tarot cards style](../../examples/training/sft/flux_dev/raider_white_tarot/)
|
9 |
+
|
10 |
+
To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
|
11 |
+
|
12 |
+
```bash
|
13 |
+
chmod +x ./examples/training/sft/flux_dev/raider_white_tarot/train.sh
|
14 |
+
./examples/training/sft/flux_dev/raider_white_tarot/train.sh
|
15 |
+
```
|
16 |
+
|
17 |
+
On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
|
18 |
+
|
19 |
+
> [!NOTE]
|
20 |
+
> Currently, only FLUX.1-dev is supported. It is a guidance-distilled model which directly predicts the outputs of its teacher model when the teacher is run with CFG. To match the output distribution of the distilled model with that of the teacher model, a guidance scale of 1.0 is hardcoded into the codebase. However, other values may work too but it is experimental.
|
21 |
+
> FLUX.1-schnell is not supported for training yet. It is a timestep-distilled model. Matching its output distribution for training is significantly more difficult.
|
22 |
+
|
23 |
+
## Supported checkpoints
|
24 |
+
|
25 |
+
The following checkpoints were tested with `finetrainers` and are known to be working:
|
26 |
+
|
27 |
+
- [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
28 |
+
- [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
|
29 |
+
|
30 |
+
## Inference
|
31 |
+
|
32 |
+
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
|
33 |
+
|
34 |
+
```diff
|
35 |
+
import torch
|
36 |
+
from diffusers import FluxPipeline
|
37 |
+
|
38 |
+
pipe = FluxPipeline.from_pretrained(
|
39 |
+
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
40 |
+
).to("cuda")
|
41 |
+
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="flux-lora")
|
42 |
+
+ pipe.set_adapters(["flux-lora"], [0.9])
|
43 |
+
|
44 |
+
# Make sure to set guidance_scale to 0.0 when inferencing with FLUX.1-schnell or derivative models
|
45 |
+
image = pipe("<my-awesome-prompt>").images[0]
|
46 |
+
image.save("output.png")
|
47 |
+
```
|
48 |
+
|
49 |
+
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
50 |
+
|
51 |
+
- [Flux in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux)
|
52 |
+
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
53 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/{finetrainers/documentation_models_hunyuan_video.md → finetrainers-src-codebase/docs/models/hunyuan_video.md}
RENAMED
@@ -50,6 +50,6 @@ export_to_video(output, "output.mp4", fps=15)
|
|
50 |
|
51 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
50 |
|
51 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
52 |
|
53 |
+
- [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video)
|
54 |
+
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
55 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/{finetrainers/documentation_models_ltx_video.md → finetrainers-src-codebase/docs/models/ltx_video.md}
RENAMED
@@ -37,6 +37,6 @@ export_to_video(video, "output.mp4", fps=8)
|
|
37 |
|
38 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
37 |
|
38 |
You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
|
39 |
|
40 |
+
- [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video)
|
41 |
+
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
42 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/{finetrainers/documentation_models_optimization.md → finetrainers-src-codebase/docs/models/optimization.md}
RENAMED
File without changes
|
docs/{finetrainers/documentation_models_wan.md → finetrainers-src-codebase/docs/models/wan.md}
RENAMED
@@ -18,6 +18,16 @@ chmod +x ./examples/training/sft/wan/crush_smol_lora/train.sh
|
|
18 |
|
19 |
On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
## Inference
|
22 |
|
23 |
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
|
@@ -102,4 +112,4 @@ You can refer to the following guides to know more about the model pipeline and
|
|
102 |
|
103 |
- [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
|
104 |
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
105 |
-
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
|
|
18 |
|
19 |
On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
|
20 |
|
21 |
+
## Supported checkpoints
|
22 |
+
|
23 |
+
Wan has multiple checkpoints as one can find [here](https://huggingface.co/Wan-AI). The following checkpoints were tested with `finetrainers` and are known to be working:
|
24 |
+
|
25 |
+
- [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
|
26 |
+
- [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
|
27 |
+
- [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
|
28 |
+
- [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
|
29 |
+
- [Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
|
30 |
+
|
31 |
## Inference
|
32 |
|
33 |
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
|
|
|
112 |
|
113 |
- [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
|
114 |
- [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
|
115 |
+
- [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
|
docs/{finetrainers/documentation_optimizers.md → finetrainers-src-codebase/docs/optimizer.md}
RENAMED
File without changes
|
docs/{finetrainers/documentation_parallel_processing_README.md → finetrainers-src-codebase/docs/parallel/README.md}
RENAMED
@@ -14,11 +14,12 @@ As an experiment for comparing performance of different training backends, Finet
|
|
14 |
|
15 |
## Support matrix
|
16 |
|
17 |
-
|
18 |
- [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
|
19 |
- [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
|
20 |
- [HSDP](https://pytorch.org/docs/stable/fsdp.html)
|
21 |
-
- [
|
|
|
22 |
|
23 |
## Training
|
24 |
|
@@ -28,7 +29,7 @@ The following parameters are relevant for launching training:
|
|
28 |
- `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
|
29 |
- `dp_degree`: The degree of data parallelis/replicas. Defaults to `1`.
|
30 |
- `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
|
31 |
-
- `cp_degree`: The degree of context parallelism.
|
32 |
- `tp_degree`: The degree of tensor parallelism.
|
33 |
|
34 |
For launching training with the Pytorch DTensor backend, use the following:
|
@@ -57,3 +58,7 @@ accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0
|
|
57 |
# Multi-node - Nx8 GPUs available
|
58 |
# TODO(aryan): Add slurm script
|
59 |
```
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
## Support matrix
|
16 |
|
17 |
+
Currently supported parallelizations include:
|
18 |
- [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
|
19 |
- [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
|
20 |
- [HSDP](https://pytorch.org/docs/stable/fsdp.html)
|
21 |
+
- [CP](https://docs.pytorch.org/tutorials/prototype/context_parallel.html)
|
22 |
+
<!-- - [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) -->
|
23 |
|
24 |
## Training
|
25 |
|
|
|
29 |
- `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
|
30 |
- `dp_degree`: The degree of data parallelis/replicas. Defaults to `1`.
|
31 |
- `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
|
32 |
+
- `cp_degree`: The degree of context parallelism.
|
33 |
- `tp_degree`: The degree of tensor parallelism.
|
34 |
|
35 |
For launching training with the Pytorch DTensor backend, use the following:
|
|
|
58 |
# Multi-node - Nx8 GPUs available
|
59 |
# TODO(aryan): Add slurm script
|
60 |
```
|
61 |
+
|
62 |
+
## Inference
|
63 |
+
|
64 |
+
For inference-only purposes, the example implementation can be found in the [examples/inference/](../../examples/inference/) directory.
|
docs/{finetrainers/documentation_trainers_control_trainer.md → finetrainers-src-codebase/docs/trainer/control_trainer.md}
RENAMED
File without changes
|
docs/{finetrainers/documentation_trainers_sft_trainer.md → finetrainers-src-codebase/docs/trainer/sft_trainer.md}
RENAMED
File without changes
|
docs/finetrainers-src-codebase/examples/_legacy/training/README.md
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CogVideoX Factory 🧪
|
2 |
+
|
3 |
+
[中文阅读](./README_zh.md)
|
4 |
+
|
5 |
+
Fine-tune Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼
|
6 |
+
|
7 |
+
<table align="center">
|
8 |
+
<tr>
|
9 |
+
<td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
|
10 |
+
</tr>
|
11 |
+
</table>
|
12 |
+
|
13 |
+
**Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)!
|
14 |
+
|
15 |
+
## Quickstart
|
16 |
+
|
17 |
+
Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`.
|
18 |
+
|
19 |
+
Then download a dataset:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
# install `huggingface_hub`
|
23 |
+
huggingface-cli download \
|
24 |
+
--repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
|
25 |
+
--local-dir video-dataset-disney
|
26 |
+
```
|
27 |
+
|
28 |
+
Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice):
|
29 |
+
|
30 |
+
```bash
|
31 |
+
# For LoRA finetuning of the text-to-video CogVideoX models
|
32 |
+
./train_text_to_video_lora.sh
|
33 |
+
|
34 |
+
# For full finetuning of the text-to-video CogVideoX models
|
35 |
+
./train_text_to_video_sft.sh
|
36 |
+
|
37 |
+
# For LoRA finetuning of the image-to-video CogVideoX models
|
38 |
+
./train_image_to_video_lora.sh
|
39 |
+
```
|
40 |
+
|
41 |
+
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
|
42 |
+
|
43 |
+
```diff
|
44 |
+
import torch
|
45 |
+
from diffusers import CogVideoXPipeline
|
46 |
+
from diffusers.utils import export_to_video
|
47 |
+
|
48 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
49 |
+
"THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
|
50 |
+
).to("cuda")
|
51 |
+
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
|
52 |
+
+ pipe.set_adapters(["cogvideox-lora"], [1.0])
|
53 |
+
|
54 |
+
video = pipe("<my-awesome-prompt>").frames[0]
|
55 |
+
export_to_video(video, "output.mp4", fps=8)
|
56 |
+
```
|
57 |
+
|
58 |
+
For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):
|
59 |
+
|
60 |
+
```python
|
61 |
+
from diffusers import CogVideoXImageToVideoPipeline
|
62 |
+
|
63 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
64 |
+
"THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
|
65 |
+
).to("cuda")
|
66 |
+
|
67 |
+
# ...
|
68 |
+
|
69 |
+
del pipe.transformer.patch_embed.pos_embedding
|
70 |
+
pipe.transformer.patch_embed.use_learned_positional_embeddings = False
|
71 |
+
pipe.transformer.config.use_learned_positional_embeddings = False
|
72 |
+
```
|
73 |
+
|
74 |
+
You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
|
75 |
+
|
76 |
+
Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.
|
77 |
+
|
78 |
+
## Prepare Dataset and Training
|
79 |
+
|
80 |
+
Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example.
|
81 |
+
|
82 |
+
- Configure environment variables as per your choice:
|
83 |
+
|
84 |
+
```bash
|
85 |
+
export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
|
86 |
+
export TORCHDYNAMO_VERBOSE=1
|
87 |
+
export WANDB_MODE="offline"
|
88 |
+
export NCCL_P2P_DISABLE=1
|
89 |
+
export TORCH_NCCL_ENABLE_MONITORING=0
|
90 |
+
```
|
91 |
+
|
92 |
+
- Configure which GPUs to use for training: `GPU_IDS="0,1"`
|
93 |
+
|
94 |
+
- Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example:
|
95 |
+
|
96 |
+
```bash
|
97 |
+
LEARNING_RATES=("1e-4" "1e-3")
|
98 |
+
LR_SCHEDULES=("cosine_with_restarts")
|
99 |
+
OPTIMIZERS=("adamw" "adam")
|
100 |
+
MAX_TRAIN_STEPS=("3000")
|
101 |
+
```
|
102 |
+
|
103 |
+
- Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`.
|
104 |
+
|
105 |
+
- Specify the absolute paths and columns/files for captions and videos.
|
106 |
+
|
107 |
+
```bash
|
108 |
+
DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
|
109 |
+
CAPTION_COLUMN="prompt.txt"
|
110 |
+
VIDEO_COLUMN="videos.txt"
|
111 |
+
```
|
112 |
+
|
113 |
+
- Launch experiments sweeping different hyperparameters:
|
114 |
+
```
|
115 |
+
for learning_rate in "${LEARNING_RATES[@]}"; do
|
116 |
+
for lr_schedule in "${LR_SCHEDULES[@]}"; do
|
117 |
+
for optimizer in "${OPTIMIZERS[@]}"; do
|
118 |
+
for steps in "${MAX_TRAIN_STEPS[@]}"; do
|
119 |
+
output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
|
120 |
+
|
121 |
+
cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
|
122 |
+
--pretrained_model_name_or_path THUDM/CogVideoX-5b \
|
123 |
+
--data_root $DATA_ROOT \
|
124 |
+
--caption_column $CAPTION_COLUMN \
|
125 |
+
--video_column $VIDEO_COLUMN \
|
126 |
+
--id_token BW_STYLE \
|
127 |
+
--height_buckets 480 \
|
128 |
+
--width_buckets 720 \
|
129 |
+
--frame_buckets 49 \
|
130 |
+
--dataloader_num_workers 8 \
|
131 |
+
--pin_memory \
|
132 |
+
--validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
|
133 |
+
--validation_prompt_separator ::: \
|
134 |
+
--num_validation_videos 1 \
|
135 |
+
--validation_epochs 10 \
|
136 |
+
--seed 42 \
|
137 |
+
--rank 128 \
|
138 |
+
--lora_alpha 128 \
|
139 |
+
--mixed_precision bf16 \
|
140 |
+
--output_dir $output_dir \
|
141 |
+
--max_num_frames 49 \
|
142 |
+
--train_batch_size 1 \
|
143 |
+
--max_train_steps $steps \
|
144 |
+
--checkpointing_steps 1000 \
|
145 |
+
--gradient_accumulation_steps 1 \
|
146 |
+
--gradient_checkpointing \
|
147 |
+
--learning_rate $learning_rate \
|
148 |
+
--lr_scheduler $lr_schedule \
|
149 |
+
--lr_warmup_steps 400 \
|
150 |
+
--lr_num_cycles 1 \
|
151 |
+
--enable_slicing \
|
152 |
+
--enable_tiling \
|
153 |
+
--optimizer $optimizer \
|
154 |
+
--beta1 0.9 \
|
155 |
+
--beta2 0.95 \
|
156 |
+
--weight_decay 0.001 \
|
157 |
+
--max_grad_norm 1.0 \
|
158 |
+
--allow_tf32 \
|
159 |
+
--report_to wandb \
|
160 |
+
--nccl_timeout 1800"
|
161 |
+
|
162 |
+
echo "Running command: $cmd"
|
163 |
+
eval $cmd
|
164 |
+
echo -ne "-------------------- Finished executing script --------------------\n\n"
|
165 |
+
done
|
166 |
+
done
|
167 |
+
done
|
168 |
+
done
|
169 |
+
```
|
170 |
+
|
171 |
+
To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`.
|
172 |
+
|
173 |
+
Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below.
|
174 |
+
|
175 |
+
## Memory requirements
|
176 |
+
|
177 |
+
<table align="center">
|
178 |
+
<tr>
|
179 |
+
<td align="center" colspan="2"><b>CogVideoX LoRA Finetuning</b></td>
|
180 |
+
</tr>
|
181 |
+
<tr>
|
182 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
183 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
184 |
+
</tr>
|
185 |
+
<tr>
|
186 |
+
<td align="center"><img src="../assets/lora_2b.png" /></td>
|
187 |
+
<td align="center"><img src="../assets/lora_5b.png" /></td>
|
188 |
+
</tr>
|
189 |
+
|
190 |
+
<tr>
|
191 |
+
<td align="center" colspan="2"><b>CogVideoX Full Finetuning</b></td>
|
192 |
+
</tr>
|
193 |
+
<tr>
|
194 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
195 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
196 |
+
</tr>
|
197 |
+
<tr>
|
198 |
+
<td align="center"><img src="../assets/sft_2b.png" /></td>
|
199 |
+
<td align="center"><img src="../assets/sft_5b.png" /></td>
|
200 |
+
</tr>
|
201 |
+
</table>
|
202 |
+
|
203 |
+
Supported and verified memory optimizations for training include:
|
204 |
+
|
205 |
+
- `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation is not supported yet either.
|
206 |
+
- Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: to test and make [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work
|
207 |
+
- DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations.
|
208 |
+
|
209 |
+
> [!IMPORTANT]
|
210 |
+
> The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
|
211 |
+
>
|
212 |
+
> If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`.
|
213 |
+
|
214 |
+
### LoRA finetuning
|
215 |
+
|
216 |
+
> [!NOTE]
|
217 |
+
> The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
|
218 |
+
>
|
219 |
+
> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
|
220 |
+
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
221 |
+
> or provide a URL to a valid and accessible image.
|
222 |
+
|
223 |
+
<details>
|
224 |
+
<summary> AdamW </summary>
|
225 |
+
|
226 |
+
**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
|
227 |
+
|
228 |
+
With `train_batch_size = 1`:
|
229 |
+
|
230 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
231 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
232 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
|
233 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
|
234 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
|
235 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
|
236 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
|
237 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
|
238 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
|
239 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
|
240 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
|
241 |
+
|
242 |
+
With `train_batch_size = 4`:
|
243 |
+
|
244 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
245 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
246 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
|
247 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
|
248 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
|
249 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
|
250 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
|
251 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
|
252 |
+
|
253 |
+
</details>
|
254 |
+
|
255 |
+
<details>
|
256 |
+
<summary> AdamW (8-bit bitsandbytes) </summary>
|
257 |
+
|
258 |
+
**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
|
259 |
+
|
260 |
+
With `train_batch_size = 1`:
|
261 |
+
|
262 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
263 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
264 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
|
265 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
|
266 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
|
267 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
|
268 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
|
269 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
|
270 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
|
271 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
|
272 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
|
273 |
+
|
274 |
+
With `train_batch_size = 4`:
|
275 |
+
|
276 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
277 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
278 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
|
279 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
|
280 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
|
281 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
|
282 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
|
283 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
|
284 |
+
|
285 |
+
</details>
|
286 |
+
|
287 |
+
<details>
|
288 |
+
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
|
289 |
+
|
290 |
+
**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
|
291 |
+
|
292 |
+
With `train_batch_size = 1`:
|
293 |
+
|
294 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
295 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
296 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
|
297 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
|
298 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
|
299 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
|
300 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
|
301 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
|
302 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
|
303 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
|
304 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
|
305 |
+
|
306 |
+
With `train_batch_size = 4`:
|
307 |
+
|
308 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
309 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
310 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
|
311 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
|
312 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
|
313 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
|
314 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
|
315 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
|
316 |
+
|
317 |
+
</details>
|
318 |
+
|
319 |
+
<details>
|
320 |
+
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
|
321 |
+
|
322 |
+
**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
|
323 |
+
|
324 |
+
With `train_batch_size = 1`:
|
325 |
+
|
326 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
327 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
328 |
+
| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
|
329 |
+
| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
|
330 |
+
|
331 |
+
With `train_batch_size = 4`:
|
332 |
+
|
333 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
334 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
335 |
+
| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
|
336 |
+
| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
|
337 |
+
|
338 |
+
</details>
|
339 |
+
|
340 |
+
### Full finetuning
|
341 |
+
|
342 |
+
> [!NOTE]
|
343 |
+
> The memory requirements for image-to-video full finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
|
344 |
+
>
|
345 |
+
> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
|
346 |
+
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
347 |
+
> or provide a URL to a valid and accessible image.
|
348 |
+
|
349 |
+
> [!NOTE]
|
350 |
+
> Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
|
351 |
+
|
352 |
+
<details>
|
353 |
+
<summary> AdamW </summary>
|
354 |
+
|
355 |
+
With `train_batch_size = 1`:
|
356 |
+
|
357 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
358 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
359 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
|
360 |
+
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
361 |
+
|
362 |
+
With `train_batch_size = 4`:
|
363 |
+
|
364 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
365 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
366 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
|
367 |
+
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
368 |
+
|
369 |
+
</details>
|
370 |
+
|
371 |
+
<details>
|
372 |
+
<summary> AdamW (8-bit bitsandbytes) </summary>
|
373 |
+
|
374 |
+
With `train_batch_size = 1`:
|
375 |
+
|
376 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
377 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
378 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
|
379 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
|
380 |
+
|
381 |
+
With `train_batch_size = 4`:
|
382 |
+
|
383 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
384 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
385 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
|
386 |
+
| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
|
387 |
+
|
388 |
+
</details>
|
389 |
+
|
390 |
+
<details>
|
391 |
+
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
|
392 |
+
|
393 |
+
With `train_batch_size = 1`:
|
394 |
+
|
395 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
396 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
397 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
|
398 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
|
399 |
+
|
400 |
+
With `train_batch_size = 4`:
|
401 |
+
|
402 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
403 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
404 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
|
405 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
|
406 |
+
|
407 |
+
</details>
|
408 |
+
|
409 |
+
<details>
|
410 |
+
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
|
411 |
+
|
412 |
+
**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
|
413 |
+
|
414 |
+
With `train_batch_size = 1`:
|
415 |
+
|
416 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
417 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
418 |
+
| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
|
419 |
+
| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
|
420 |
+
|
421 |
+
With `train_batch_size = 4`:
|
422 |
+
|
423 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
424 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
425 |
+
| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
|
426 |
+
| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
|
427 |
+
|
428 |
+
</details>
|
429 |
+
|
430 |
+
> [!NOTE]
|
431 |
+
> - `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose not to perform validation/testing as part of the training script.
|
432 |
+
>
|
433 |
+
> - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing.
|
434 |
+
|
435 |
+
<table align="center">
|
436 |
+
<tr>
|
437 |
+
<td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
|
438 |
+
</tr>
|
439 |
+
<tr>
|
440 |
+
<td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
|
441 |
+
</tr>
|
442 |
+
</table>
|
443 |
+
|
444 |
+
## TODOs
|
445 |
+
|
446 |
+
- [x] Make scripts compatible with DDP
|
447 |
+
- [ ] Make scripts compatible with FSDP
|
448 |
+
- [x] Make scripts compatible with DeepSpeed
|
449 |
+
- [ ] vLLM-powered captioning script
|
450 |
+
- [x] Multi-resolution/frame support in `prepare_dataset.py`
|
451 |
+
- [ ] Analyzing traces for potential speedups and removing as many syncs as possible
|
452 |
+
- [x] Test scripts with memory-efficient optimizer from bitsandbytes
|
453 |
+
- [x] Test scripts with CPUOffloadOptimizer, etc.
|
454 |
+
- [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao))
|
455 |
+
- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out)
|
456 |
+
- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100)
|
457 |
+
|
458 |
+
> [!IMPORTANT]
|
459 |
+
> Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training.
|
docs/finetrainers-src-codebase/examples/_legacy/training/README_zh.md
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CogVideoX Factory 🧪
|
2 |
+
|
3 |
+
[Read in English](./README.md)
|
4 |
+
|
5 |
+
在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼
|
6 |
+
|
7 |
+
<table align="center">
|
8 |
+
<tr>
|
9 |
+
<td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">您的浏览器不支持视频标签。</video></td>
|
10 |
+
</tr>
|
11 |
+
</table>
|
12 |
+
|
13 |
+
## 快速开始
|
14 |
+
|
15 |
+
克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。
|
16 |
+
|
17 |
+
接着下载数据集:
|
18 |
+
|
19 |
+
```
|
20 |
+
# 安装 `huggingface_hub`
|
21 |
+
huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
|
22 |
+
```
|
23 |
+
|
24 |
+
然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项):
|
25 |
+
|
26 |
+
```
|
27 |
+
# 对 CogVideoX 模型进行文本到视频的 LoRA 微调
|
28 |
+
./train_text_to_video_lora.sh
|
29 |
+
|
30 |
+
# 对 CogVideoX 模型进行文本到视频的完整微调
|
31 |
+
./train_text_to_video_sft.sh
|
32 |
+
|
33 |
+
# 对 CogVideoX 模型进行图像到视频的 LoRA 微调
|
34 |
+
./train_image_to_video_lora.sh
|
35 |
+
```
|
36 |
+
|
37 |
+
假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理:
|
38 |
+
|
39 |
+
```
|
40 |
+
import torch
|
41 |
+
from diffusers import CogVideoXPipeline
|
42 |
+
from diffusers import export_to_video
|
43 |
+
|
44 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
45 |
+
"THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
|
46 |
+
).to("cuda")
|
47 |
+
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"])
|
48 |
+
+ pipe.set_adapters(["cogvideox-lora"], [1.0])
|
49 |
+
|
50 |
+
video = pipe("<my-awesome-prompt>").frames[0]
|
51 |
+
export_to_video(video, "output.mp4", fps=8)
|
52 |
+
```
|
53 |
+
|
54 |
+
你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。
|
55 |
+
|
56 |
+
**注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装
|
57 |
+
diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。
|
58 |
+
|
59 |
+
以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。
|
60 |
+
|
61 |
+
## 训练
|
62 |
+
|
63 |
+
在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。
|
64 |
+
|
65 |
+
- 根据你的需求配置环境变量:
|
66 |
+
|
67 |
+
```
|
68 |
+
export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
|
69 |
+
export TORCHDYNAMO_VERBOSE=1
|
70 |
+
export WANDB_MODE="offline"
|
71 |
+
export NCCL_P2P_DISABLE=1
|
72 |
+
export TORCH_NCCL_ENABLE_MONITORING=0
|
73 |
+
```
|
74 |
+
|
75 |
+
- 配置用于训练的 GPU:`GPU_IDS="0,1"`
|
76 |
+
|
77 |
+
- 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例:
|
78 |
+
|
79 |
+
```
|
80 |
+
LEARNING_RATES=("1e-4" "1e-3")
|
81 |
+
LR_SCHEDULES=("cosine_with_restarts")
|
82 |
+
OPTIMIZERS=("adamw" "adam")
|
83 |
+
MAX_TRAIN_STEPS=("3000")
|
84 |
+
```
|
85 |
+
|
86 |
+
- 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`
|
87 |
+
。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed
|
88 |
+
等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。
|
89 |
+
|
90 |
+
- 指定字幕和视频的绝对路径以及列/文件。
|
91 |
+
|
92 |
+
```
|
93 |
+
DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
|
94 |
+
CAPTION_COLUMN="prompt.txt"
|
95 |
+
VIDEO_COLUMN="videos.txt"
|
96 |
+
```
|
97 |
+
|
98 |
+
- 运行实验,遍历不同的超参数:
|
99 |
+
```
|
100 |
+
for learning_rate in "${LEARNING_RATES[@]}"; do
|
101 |
+
for lr_schedule in "${LR_SCHEDULES[@]}"; do
|
102 |
+
for optimizer in "${OPTIMIZERS[@]}"; do
|
103 |
+
for steps in "${MAX_TRAIN_STEPS[@]}"; do
|
104 |
+
output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
|
105 |
+
|
106 |
+
cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
|
107 |
+
--pretrained_model_name_or_path THUDM/CogVideoX-5b \
|
108 |
+
--data_root $DATA_ROOT \
|
109 |
+
--caption_column $CAPTION_COLUMN \
|
110 |
+
--video_column $VIDEO_COLUMN \
|
111 |
+
--id_token BW_STYLE \
|
112 |
+
--height_buckets 480 \
|
113 |
+
--width_buckets 720 \
|
114 |
+
--frame_buckets 49 \
|
115 |
+
--dataloader_num_workers 8 \
|
116 |
+
--pin_memory \
|
117 |
+
--validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
|
118 |
+
--validation_prompt_separator ::: \
|
119 |
+
--num_validation_videos 1 \
|
120 |
+
--validation_epochs 10 \
|
121 |
+
--seed 42 \
|
122 |
+
--rank 128 \
|
123 |
+
--lora_alpha 128 \
|
124 |
+
--mixed_precision bf16 \
|
125 |
+
--output_dir $output_dir \
|
126 |
+
--max_num_frames 49 \
|
127 |
+
--train_batch_size 1 \
|
128 |
+
--max_train_steps $steps \
|
129 |
+
--checkpointing_steps 1000 \
|
130 |
+
--gradient_accumulation_steps 1 \
|
131 |
+
--gradient_checkpointing \
|
132 |
+
--learning_rate $learning_rate \
|
133 |
+
--lr_scheduler $lr_schedule \
|
134 |
+
--lr_warmup_steps 400 \
|
135 |
+
--lr_num_cycles 1 \
|
136 |
+
--enable_slicing \
|
137 |
+
--enable_tiling \
|
138 |
+
--optimizer $optimizer \
|
139 |
+
--beta1 0.9 \
|
140 |
+
--beta2 0.95 \
|
141 |
+
--weight_decay 0.001 \
|
142 |
+
--max_grad_norm 1.0 \
|
143 |
+
--allow_tf32 \
|
144 |
+
--report_to wandb \
|
145 |
+
--nccl_timeout 1800"
|
146 |
+
|
147 |
+
echo "Running command: $cmd"
|
148 |
+
eval $cmd
|
149 |
+
echo -ne "-------------------- Finished executing script --------------------\n\n"
|
150 |
+
done
|
151 |
+
done
|
152 |
+
done
|
153 |
+
done
|
154 |
+
```
|
155 |
+
|
156 |
+
要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。
|
157 |
+
|
158 |
+
注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。
|
159 |
+
|
160 |
+
## 内存需求
|
161 |
+
|
162 |
+
<table align="center">
|
163 |
+
<tr>
|
164 |
+
<td align="center" colspan="2"><b>CogVideoX LoRA 微调</b></td>
|
165 |
+
</tr>
|
166 |
+
<tr>
|
167 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
168 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
169 |
+
</tr>
|
170 |
+
<tr>
|
171 |
+
<td align="center"><img src="assets/lora_2b.png" /></td>
|
172 |
+
<td align="center"><img src="assets/lora_5b.png" /></td>
|
173 |
+
</tr>
|
174 |
+
|
175 |
+
<tr>
|
176 |
+
<td align="center" colspan="2"><b>CogVideoX 全量微调</b></td>
|
177 |
+
</tr>
|
178 |
+
<tr>
|
179 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
180 |
+
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
181 |
+
</tr>
|
182 |
+
<tr>
|
183 |
+
<td align="center"><img src="assets/sft_2b.png" /></td>
|
184 |
+
<td align="center"><img src="assets/sft_5b.png" /></td>
|
185 |
+
</tr>
|
186 |
+
</table>
|
187 |
+
|
188 |
+
支持和验证的训练内存优化包括:
|
189 |
+
|
190 |
+
- `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao)
|
191 |
+
。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload)
|
192 |
+
阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU
|
193 |
+
优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile`
|
194 |
+
。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。
|
195 |
+
- 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers)
|
196 |
+
的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。
|
197 |
+
- DeepSpeed Zero2:由于我们依赖 `accelerate`
|
198 |
+
,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed
|
199 |
+
Zero2 优化训练。
|
200 |
+
|
201 |
+
> [!重要提示]
|
202 |
+
> 内存需求是运行 `training/prepare_dataset.py`
|
203 |
+
>
|
204 |
+
后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。
|
205 |
+
>
|
206 |
+
> 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。
|
207 |
+
|
208 |
+
### LoRA微调
|
209 |
+
|
210 |
+
> [!重要提示]
|
211 |
+
> 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。
|
212 |
+
>
|
213 |
+
> 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧:
|
214 |
+
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
215 |
+
> 或提供一个有效且可访问的图像URL。
|
216 |
+
|
217 |
+
<details>
|
218 |
+
<summary> AdamW </summary>
|
219 |
+
|
220 |
+
**注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。
|
221 |
+
|
222 |
+
当 `train_batch_size = 1` 时:
|
223 |
+
|
224 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
225 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
226 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
|
227 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
|
228 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
|
229 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
|
230 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
|
231 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
|
232 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
|
233 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
|
234 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
|
235 |
+
|
236 |
+
当 `train_batch_size = 4` 时:
|
237 |
+
|
238 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
239 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
240 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
|
241 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
|
242 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
|
243 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
|
244 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
|
245 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
|
246 |
+
|
247 |
+
</details>
|
248 |
+
|
249 |
+
<details>
|
250 |
+
<summary> AdamW (8-bit bitsandbytes) </summary>
|
251 |
+
|
252 |
+
**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
|
253 |
+
|
254 |
+
当 `train_batch_size = 1` 时:
|
255 |
+
|
256 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
257 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
258 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
|
259 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
|
260 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
|
261 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
|
262 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
|
263 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
|
264 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
|
265 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
|
266 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
|
267 |
+
|
268 |
+
当 `train_batch_size = 4` 时:
|
269 |
+
|
270 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
271 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
272 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
|
273 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
|
274 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
|
275 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
|
276 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
|
277 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
|
278 |
+
|
279 |
+
</details>
|
280 |
+
|
281 |
+
<details>
|
282 |
+
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
|
283 |
+
|
284 |
+
**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
|
285 |
+
|
286 |
+
当 `train_batch_size = 1` 时:
|
287 |
+
|
288 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
289 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
290 |
+
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
|
291 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
|
292 |
+
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
|
293 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
|
294 |
+
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
|
295 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
|
296 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
|
297 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
|
298 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
|
299 |
+
|
300 |
+
当 `train_batch_size = 4` 时:
|
301 |
+
|
302 |
+
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
303 |
+
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
304 |
+
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
|
305 |
+
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
|
306 |
+
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
|
307 |
+
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
|
308 |
+
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
|
309 |
+
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
|
310 |
+
|
311 |
+
</details>
|
312 |
+
|
313 |
+
<details>
|
314 |
+
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
|
315 |
+
|
316 |
+
**注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。
|
317 |
+
|
318 |
+
当 `train_batch_size = 1` 时:
|
319 |
+
|
320 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
321 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
322 |
+
| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
|
323 |
+
| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
|
324 |
+
|
325 |
+
当 `train_batch_size = 4` 时:
|
326 |
+
|
327 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
328 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
329 |
+
| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
|
330 |
+
| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
|
331 |
+
|
332 |
+
</details>
|
333 |
+
|
334 |
+
### Full finetuning
|
335 |
+
|
336 |
+
> [!注意]
|
337 |
+
> 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。
|
338 |
+
>
|
339 |
+
> 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧:
|
340 |
+
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
341 |
+
> 或提供一个有效且可访问的图像 URL。
|
342 |
+
|
343 |
+
> [!注意]
|
344 |
+
> 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。
|
345 |
+
|
346 |
+
<details>
|
347 |
+
<summary> AdamW </summary>
|
348 |
+
|
349 |
+
当 `train_batch_size = 1` 时:
|
350 |
+
|
351 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
352 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
353 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
|
354 |
+
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
355 |
+
|
356 |
+
当 `train_batch_size = 4` 时:
|
357 |
+
|
358 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
359 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
360 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
|
361 |
+
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
362 |
+
|
363 |
+
</details>
|
364 |
+
|
365 |
+
<details>
|
366 |
+
<summary> AdamW (8-bit 量化) </summary>
|
367 |
+
|
368 |
+
当 `train_batch_size = 1` 时:
|
369 |
+
|
370 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
371 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
372 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
|
373 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
|
374 |
+
|
375 |
+
当 `train_batch_size = 4` 时:
|
376 |
+
|
377 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
378 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
379 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
|
380 |
+
| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
|
381 |
+
|
382 |
+
</details>
|
383 |
+
|
384 |
+
<details>
|
385 |
+
<summary> AdamW + CPUOffloadOptimizer(带有梯度卸载)</summary>
|
386 |
+
|
387 |
+
当 `train_batch_size = 1` 时:
|
388 |
+
|
389 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
390 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
391 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
|
392 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
|
393 |
+
|
394 |
+
当 `train_batch_size = 4` 时:
|
395 |
+
|
396 |
+
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
397 |
+
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
398 |
+
| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
|
399 |
+
| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
|
400 |
+
|
401 |
+
</details>
|
402 |
+
|
403 |
+
<details>
|
404 |
+
<summary> DeepSpeed(AdamW + CPU/参数卸载) </summary>
|
405 |
+
|
406 |
+
**注意:** 结果是在启用 `gradient_checkpointing`(梯度检查点)功能,并在 2 台 A100 显卡上运行时报告的。
|
407 |
+
|
408 |
+
当 `train_batch_size = 1` 时:
|
409 |
+
|
410 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
411 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
412 |
+
| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
|
413 |
+
| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
|
414 |
+
|
415 |
+
当 `train_batch_size = 4` 时:
|
416 |
+
|
417 |
+
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
418 |
+
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
419 |
+
| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
|
420 |
+
| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
|
421 |
+
|
422 |
+
</details>
|
423 |
+
|
424 |
+
> [!注意]
|
425 |
+
> - `memory_after_validation`(验证后内存) 表示训练所需的峰值内存。这是因为除了存储训练过程中需要的激活、参数和梯度之外,还需要加载
|
426 |
+
VAE 和文本编码器到内存中,并且执行推理操作也会消耗一定内存。为了减少训练所需的总内存,您可以选择在训练脚本中不执行验证/测试。
|
427 |
+
>
|
428 |
+
> - 如果选择不进行验证/测试,`memory_before_validation`(验证前内存) 才是训练所需内存的真实指示器。
|
429 |
+
|
430 |
+
<table align="center">
|
431 |
+
<tr>
|
432 |
+
<td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
|
433 |
+
</tr>
|
434 |
+
<tr>
|
435 |
+
<td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
|
436 |
+
</tr>
|
437 |
+
</table>
|
438 |
+
|
439 |
+
## 待办事项
|
440 |
+
|
441 |
+
- [x] 使脚本兼容 DDP
|
442 |
+
- [ ] 使脚本兼容 FSDP
|
443 |
+
- [x] 使脚本兼容 DeepSpeed
|
444 |
+
- [ ] 基于 vLLM 的字幕脚本
|
445 |
+
- [x] 在 `prepare_dataset.py` 中支持多分辨率/帧数
|
446 |
+
- [ ] 分析性能瓶颈并尽可能减少同步操作
|
447 |
+
- [ ] 支持 QLoRA(优先),以及其他高使用率的 LoRA 方法
|
448 |
+
- [x] 使用 bitsandbytes 的节省内存优化器测试脚本
|
449 |
+
- [x] 使用 CPUOffloadOptimizer 等测试脚本
|
450 |
+
- [ ] 使用 torchao 量化和低位内存优化器测试脚本(目前在 AdamW(8/4-bit torchao)上报错)
|
451 |
+
- [ ] 使用 AdamW(8-bit bitsandbytes)+ CPUOffloadOptimizer(带有梯度卸载)的测试脚本(目前报错)
|
452 |
+
- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (与作者合作支持反向传播,并针对 A100 进行优化)
|
453 |
+
|
454 |
+
> [!重要]
|
455 |
+
> 由于我们的目标是使脚本尽可能节省内存,因此我们不保证支持多 GPU 训练。
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/__init__.py
ADDED
File without changes
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/args.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
def _get_model_args(parser: argparse.ArgumentParser) -> None:
|
5 |
+
parser.add_argument(
|
6 |
+
"--pretrained_model_name_or_path",
|
7 |
+
type=str,
|
8 |
+
default=None,
|
9 |
+
required=True,
|
10 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
11 |
+
)
|
12 |
+
parser.add_argument(
|
13 |
+
"--revision",
|
14 |
+
type=str,
|
15 |
+
default=None,
|
16 |
+
required=False,
|
17 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--variant",
|
21 |
+
type=str,
|
22 |
+
default=None,
|
23 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--cache_dir",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
|
34 |
+
parser.add_argument(
|
35 |
+
"--data_root",
|
36 |
+
type=str,
|
37 |
+
default=None,
|
38 |
+
help=("A folder containing the training data."),
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--dataset_file",
|
42 |
+
type=str,
|
43 |
+
default=None,
|
44 |
+
help=("Path to a CSV file if loading prompts/video paths using this format."),
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--video_column",
|
48 |
+
type=str,
|
49 |
+
default="video",
|
50 |
+
help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--caption_column",
|
54 |
+
type=str,
|
55 |
+
default="text",
|
56 |
+
help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--id_token",
|
60 |
+
type=str,
|
61 |
+
default=None,
|
62 |
+
help="Identifier token appended to the start of each prompt if provided.",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--height_buckets",
|
66 |
+
nargs="+",
|
67 |
+
type=int,
|
68 |
+
default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--width_buckets",
|
72 |
+
nargs="+",
|
73 |
+
type=int,
|
74 |
+
default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--frame_buckets",
|
78 |
+
nargs="+",
|
79 |
+
type=int,
|
80 |
+
default=[49],
|
81 |
+
help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--load_tensors",
|
85 |
+
action="store_true",
|
86 |
+
help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--random_flip",
|
90 |
+
type=float,
|
91 |
+
default=None,
|
92 |
+
help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--dataloader_num_workers",
|
96 |
+
type=int,
|
97 |
+
default=0,
|
98 |
+
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--pin_memory",
|
102 |
+
action="store_true",
|
103 |
+
help="Whether or not to use the pinned memory setting in pytorch dataloader.",
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
def _get_validation_args(parser: argparse.ArgumentParser) -> None:
|
108 |
+
parser.add_argument(
|
109 |
+
"--validation_prompt",
|
110 |
+
type=str,
|
111 |
+
default=None,
|
112 |
+
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--validation_images",
|
116 |
+
type=str,
|
117 |
+
default=None,
|
118 |
+
help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--validation_prompt_separator",
|
122 |
+
type=str,
|
123 |
+
default=":::",
|
124 |
+
help="String that separates multiple validation prompts",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--num_validation_videos",
|
128 |
+
type=int,
|
129 |
+
default=1,
|
130 |
+
help="Number of videos that should be generated during validation per `validation_prompt`.",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--validation_epochs",
|
134 |
+
type=int,
|
135 |
+
default=None,
|
136 |
+
help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"--validation_steps",
|
140 |
+
type=int,
|
141 |
+
default=None,
|
142 |
+
help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--guidance_scale",
|
146 |
+
type=float,
|
147 |
+
default=6,
|
148 |
+
help="The guidance scale to use while sampling validation videos.",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--use_dynamic_cfg",
|
152 |
+
action="store_true",
|
153 |
+
default=False,
|
154 |
+
help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--enable_model_cpu_offload",
|
158 |
+
action="store_true",
|
159 |
+
default=False,
|
160 |
+
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
def _get_training_args(parser: argparse.ArgumentParser) -> None:
|
165 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
166 |
+
parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
|
167 |
+
parser.add_argument(
|
168 |
+
"--lora_alpha",
|
169 |
+
type=int,
|
170 |
+
default=64,
|
171 |
+
help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--mixed_precision",
|
175 |
+
type=str,
|
176 |
+
default=None,
|
177 |
+
choices=["no", "fp16", "bf16"],
|
178 |
+
help=(
|
179 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. "
|
180 |
+
"Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this "
|
181 |
+
"argument to override the accelerate config."
|
182 |
+
),
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--output_dir",
|
186 |
+
type=str,
|
187 |
+
default="cogvideox-sft",
|
188 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--height",
|
192 |
+
type=int,
|
193 |
+
default=480,
|
194 |
+
help="All input videos are resized to this height.",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--width",
|
198 |
+
type=int,
|
199 |
+
default=720,
|
200 |
+
help="All input videos are resized to this width.",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--video_reshape_mode",
|
204 |
+
type=str,
|
205 |
+
default=None,
|
206 |
+
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
|
207 |
+
)
|
208 |
+
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
|
209 |
+
parser.add_argument(
|
210 |
+
"--max_num_frames",
|
211 |
+
type=int,
|
212 |
+
default=49,
|
213 |
+
help="All input videos will be truncated to these many frames.",
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--skip_frames_start",
|
217 |
+
type=int,
|
218 |
+
default=0,
|
219 |
+
help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--skip_frames_end",
|
223 |
+
type=int,
|
224 |
+
default=0,
|
225 |
+
help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--train_batch_size",
|
229 |
+
type=int,
|
230 |
+
default=4,
|
231 |
+
help="Batch size (per device) for the training dataloader.",
|
232 |
+
)
|
233 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
234 |
+
parser.add_argument(
|
235 |
+
"--max_train_steps",
|
236 |
+
type=int,
|
237 |
+
default=None,
|
238 |
+
help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--checkpointing_steps",
|
242 |
+
type=int,
|
243 |
+
default=500,
|
244 |
+
help=(
|
245 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
246 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
247 |
+
" training using `--resume_from_checkpoint`."
|
248 |
+
),
|
249 |
+
)
|
250 |
+
parser.add_argument(
|
251 |
+
"--checkpoints_total_limit",
|
252 |
+
type=int,
|
253 |
+
default=None,
|
254 |
+
help=("Max number of checkpoints to store."),
|
255 |
+
)
|
256 |
+
parser.add_argument(
|
257 |
+
"--resume_from_checkpoint",
|
258 |
+
type=str,
|
259 |
+
default=None,
|
260 |
+
help=(
|
261 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
262 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
263 |
+
),
|
264 |
+
)
|
265 |
+
parser.add_argument(
|
266 |
+
"--gradient_accumulation_steps",
|
267 |
+
type=int,
|
268 |
+
default=1,
|
269 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
"--gradient_checkpointing",
|
273 |
+
action="store_true",
|
274 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
275 |
+
)
|
276 |
+
parser.add_argument(
|
277 |
+
"--learning_rate",
|
278 |
+
type=float,
|
279 |
+
default=1e-4,
|
280 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--scale_lr",
|
284 |
+
action="store_true",
|
285 |
+
default=False,
|
286 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--lr_scheduler",
|
290 |
+
type=str,
|
291 |
+
default="constant",
|
292 |
+
help=(
|
293 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
294 |
+
' "constant", "constant_with_warmup"]'
|
295 |
+
),
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--lr_warmup_steps",
|
299 |
+
type=int,
|
300 |
+
default=500,
|
301 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--lr_num_cycles",
|
305 |
+
type=int,
|
306 |
+
default=1,
|
307 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--lr_power",
|
311 |
+
type=float,
|
312 |
+
default=1.0,
|
313 |
+
help="Power factor of the polynomial scheduler.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--enable_slicing",
|
317 |
+
action="store_true",
|
318 |
+
default=False,
|
319 |
+
help="Whether or not to use VAE slicing for saving memory.",
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--enable_tiling",
|
323 |
+
action="store_true",
|
324 |
+
default=False,
|
325 |
+
help="Whether or not to use VAE tiling for saving memory.",
|
326 |
+
)
|
327 |
+
parser.add_argument(
|
328 |
+
"--noised_image_dropout",
|
329 |
+
type=float,
|
330 |
+
default=0.05,
|
331 |
+
help="Image condition dropout probability when finetuning image-to-video.",
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--ignore_learned_positional_embeddings",
|
335 |
+
action="store_true",
|
336 |
+
default=False,
|
337 |
+
help=(
|
338 |
+
"Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
|
339 |
+
"should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
|
340 |
+
"otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
|
341 |
+
),
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
|
346 |
+
parser.add_argument(
|
347 |
+
"--optimizer",
|
348 |
+
type=lambda s: s.lower(),
|
349 |
+
default="adam",
|
350 |
+
choices=["adam", "adamw", "prodigy", "came"],
|
351 |
+
help=("The optimizer type to use."),
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--use_8bit",
|
355 |
+
action="store_true",
|
356 |
+
help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.",
|
357 |
+
)
|
358 |
+
parser.add_argument(
|
359 |
+
"--use_4bit",
|
360 |
+
action="store_true",
|
361 |
+
help="Whether or not to use 4-bit optimizers from `torchao`.",
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers."
|
365 |
+
)
|
366 |
+
parser.add_argument(
|
367 |
+
"--beta1",
|
368 |
+
type=float,
|
369 |
+
default=0.9,
|
370 |
+
help="The beta1 parameter for the Adam and Prodigy optimizers.",
|
371 |
+
)
|
372 |
+
parser.add_argument(
|
373 |
+
"--beta2",
|
374 |
+
type=float,
|
375 |
+
default=0.95,
|
376 |
+
help="The beta2 parameter for the Adam and Prodigy optimizers.",
|
377 |
+
)
|
378 |
+
parser.add_argument(
|
379 |
+
"--beta3",
|
380 |
+
type=float,
|
381 |
+
default=None,
|
382 |
+
help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
|
383 |
+
)
|
384 |
+
parser.add_argument(
|
385 |
+
"--prodigy_decouple",
|
386 |
+
action="store_true",
|
387 |
+
help="Use AdamW style decoupled weight decay.",
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--weight_decay",
|
391 |
+
type=float,
|
392 |
+
default=1e-04,
|
393 |
+
help="Weight decay to use for optimizer.",
|
394 |
+
)
|
395 |
+
parser.add_argument(
|
396 |
+
"--epsilon",
|
397 |
+
type=float,
|
398 |
+
default=1e-8,
|
399 |
+
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
|
400 |
+
)
|
401 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
402 |
+
parser.add_argument(
|
403 |
+
"--prodigy_use_bias_correction",
|
404 |
+
action="store_true",
|
405 |
+
help="Turn on Adam's bias correction.",
|
406 |
+
)
|
407 |
+
parser.add_argument(
|
408 |
+
"--prodigy_safeguard_warmup",
|
409 |
+
action="store_true",
|
410 |
+
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
|
411 |
+
)
|
412 |
+
parser.add_argument(
|
413 |
+
"--use_cpu_offload_optimizer",
|
414 |
+
action="store_true",
|
415 |
+
help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.",
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--offload_gradients",
|
419 |
+
action="store_true",
|
420 |
+
help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.",
|
421 |
+
)
|
422 |
+
|
423 |
+
|
424 |
+
def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
|
425 |
+
parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
|
426 |
+
parser.add_argument(
|
427 |
+
"--push_to_hub",
|
428 |
+
action="store_true",
|
429 |
+
help="Whether or not to push the model to the Hub.",
|
430 |
+
)
|
431 |
+
parser.add_argument(
|
432 |
+
"--hub_token",
|
433 |
+
type=str,
|
434 |
+
default=None,
|
435 |
+
help="The token to use to push to the Model Hub.",
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--hub_model_id",
|
439 |
+
type=str,
|
440 |
+
default=None,
|
441 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
442 |
+
)
|
443 |
+
parser.add_argument(
|
444 |
+
"--logging_dir",
|
445 |
+
type=str,
|
446 |
+
default="logs",
|
447 |
+
help="Directory where logs are stored.",
|
448 |
+
)
|
449 |
+
parser.add_argument(
|
450 |
+
"--allow_tf32",
|
451 |
+
action="store_true",
|
452 |
+
help=(
|
453 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
454 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
455 |
+
),
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--nccl_timeout",
|
459 |
+
type=int,
|
460 |
+
default=600,
|
461 |
+
help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--report_to",
|
465 |
+
type=str,
|
466 |
+
default=None,
|
467 |
+
help=(
|
468 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
469 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
470 |
+
),
|
471 |
+
)
|
472 |
+
|
473 |
+
|
474 |
+
def get_args():
|
475 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
|
476 |
+
|
477 |
+
_get_model_args(parser)
|
478 |
+
_get_dataset_args(parser)
|
479 |
+
_get_training_args(parser)
|
480 |
+
_get_validation_args(parser)
|
481 |
+
_get_optimizer_args(parser)
|
482 |
+
_get_configuration_args(parser)
|
483 |
+
|
484 |
+
return parser.parse_args()
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_lora.py
ADDED
@@ -0,0 +1,1016 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
import shutil
|
22 |
+
from datetime import timedelta
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Dict
|
25 |
+
|
26 |
+
import diffusers
|
27 |
+
import torch
|
28 |
+
import transformers
|
29 |
+
import wandb
|
30 |
+
from accelerate import Accelerator, DistributedType
|
31 |
+
from accelerate.logging import get_logger
|
32 |
+
from accelerate.utils import (
|
33 |
+
DistributedDataParallelKwargs,
|
34 |
+
InitProcessGroupKwargs,
|
35 |
+
ProjectConfiguration,
|
36 |
+
set_seed,
|
37 |
+
)
|
38 |
+
from diffusers import (
|
39 |
+
AutoencoderKLCogVideoX,
|
40 |
+
CogVideoXDPMScheduler,
|
41 |
+
CogVideoXImageToVideoPipeline,
|
42 |
+
CogVideoXTransformer3DModel,
|
43 |
+
)
|
44 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
45 |
+
from diffusers.optimization import get_scheduler
|
46 |
+
from diffusers.training_utils import cast_training_params
|
47 |
+
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
|
48 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
49 |
+
from huggingface_hub import create_repo, upload_folder
|
50 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
51 |
+
from torch.utils.data import DataLoader
|
52 |
+
from tqdm.auto import tqdm
|
53 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
54 |
+
|
55 |
+
|
56 |
+
from args import get_args # isort:skip
|
57 |
+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
58 |
+
from text_encoder import compute_prompt_embeddings # isort:skip
|
59 |
+
from utils import (
|
60 |
+
get_gradient_norm,
|
61 |
+
get_optimizer,
|
62 |
+
prepare_rotary_positional_embeddings,
|
63 |
+
print_memory,
|
64 |
+
reset_memory,
|
65 |
+
unwrap_model,
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
logger = get_logger(__name__)
|
70 |
+
|
71 |
+
|
72 |
+
def save_model_card(
|
73 |
+
repo_id: str,
|
74 |
+
videos=None,
|
75 |
+
base_model: str = None,
|
76 |
+
validation_prompt=None,
|
77 |
+
repo_folder=None,
|
78 |
+
fps=8,
|
79 |
+
):
|
80 |
+
widget_dict = []
|
81 |
+
if videos is not None:
|
82 |
+
for i, video in enumerate(videos):
|
83 |
+
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
84 |
+
widget_dict.append(
|
85 |
+
{
|
86 |
+
"text": validation_prompt if validation_prompt else " ",
|
87 |
+
"output": {"url": f"video_{i}.mp4"},
|
88 |
+
}
|
89 |
+
)
|
90 |
+
|
91 |
+
model_description = f"""
|
92 |
+
# CogVideoX LoRA Finetune
|
93 |
+
|
94 |
+
<Gallery />
|
95 |
+
|
96 |
+
## Model description
|
97 |
+
|
98 |
+
This is a lora finetune of the CogVideoX model `{base_model}`.
|
99 |
+
|
100 |
+
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
101 |
+
|
102 |
+
## Download model
|
103 |
+
|
104 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
105 |
+
|
106 |
+
## Usage
|
107 |
+
|
108 |
+
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
109 |
+
|
110 |
+
```py
|
111 |
+
import torch
|
112 |
+
from diffusers import CogVideoXImageToVideoPipeline
|
113 |
+
from diffusers.utils import export_to_video, load_image
|
114 |
+
|
115 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
|
116 |
+
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
|
117 |
+
|
118 |
+
# The LoRA adapter weights are determined by what was used for training.
|
119 |
+
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
|
120 |
+
# It can be made lower or higher from what was used in training to decrease or amplify the effect
|
121 |
+
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
|
122 |
+
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
123 |
+
|
124 |
+
image = load_image("/path/to/image.png")
|
125 |
+
video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
126 |
+
export_to_video(video, "output.mp4", fps=8)
|
127 |
+
```
|
128 |
+
|
129 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
130 |
+
|
131 |
+
## License
|
132 |
+
|
133 |
+
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
|
134 |
+
"""
|
135 |
+
model_card = load_or_create_model_card(
|
136 |
+
repo_id_or_path=repo_id,
|
137 |
+
from_training=True,
|
138 |
+
license="other",
|
139 |
+
base_model=base_model,
|
140 |
+
prompt=validation_prompt,
|
141 |
+
model_description=model_description,
|
142 |
+
widget=widget_dict,
|
143 |
+
)
|
144 |
+
tags = [
|
145 |
+
"text-to-video",
|
146 |
+
"image-to-video",
|
147 |
+
"diffusers-training",
|
148 |
+
"diffusers",
|
149 |
+
"lora",
|
150 |
+
"cogvideox",
|
151 |
+
"cogvideox-diffusers",
|
152 |
+
"template:sd-lora",
|
153 |
+
]
|
154 |
+
|
155 |
+
model_card = populate_model_card(model_card, tags=tags)
|
156 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
157 |
+
|
158 |
+
|
159 |
+
def log_validation(
|
160 |
+
accelerator: Accelerator,
|
161 |
+
pipe: CogVideoXImageToVideoPipeline,
|
162 |
+
args: Dict[str, Any],
|
163 |
+
pipeline_args: Dict[str, Any],
|
164 |
+
is_final_validation: bool = False,
|
165 |
+
):
|
166 |
+
logger.info(
|
167 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
168 |
+
)
|
169 |
+
|
170 |
+
pipe = pipe.to(accelerator.device)
|
171 |
+
|
172 |
+
# run inference
|
173 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
174 |
+
|
175 |
+
videos = []
|
176 |
+
for _ in range(args.num_validation_videos):
|
177 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
178 |
+
videos.append(video)
|
179 |
+
|
180 |
+
for tracker in accelerator.trackers:
|
181 |
+
phase_name = "test" if is_final_validation else "validation"
|
182 |
+
if tracker.name == "wandb":
|
183 |
+
video_filenames = []
|
184 |
+
for i, video in enumerate(videos):
|
185 |
+
prompt = (
|
186 |
+
pipeline_args["prompt"][:25]
|
187 |
+
.replace(" ", "_")
|
188 |
+
.replace(" ", "_")
|
189 |
+
.replace("'", "_")
|
190 |
+
.replace('"', "_")
|
191 |
+
.replace("/", "_")
|
192 |
+
)
|
193 |
+
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
194 |
+
export_to_video(video, filename, fps=8)
|
195 |
+
video_filenames.append(filename)
|
196 |
+
|
197 |
+
tracker.log(
|
198 |
+
{
|
199 |
+
phase_name: [
|
200 |
+
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
201 |
+
for i, filename in enumerate(video_filenames)
|
202 |
+
]
|
203 |
+
}
|
204 |
+
)
|
205 |
+
|
206 |
+
return videos
|
207 |
+
|
208 |
+
|
209 |
+
def run_validation(
|
210 |
+
args: Dict[str, Any],
|
211 |
+
accelerator: Accelerator,
|
212 |
+
transformer,
|
213 |
+
scheduler,
|
214 |
+
model_config: Dict[str, Any],
|
215 |
+
weight_dtype: torch.dtype,
|
216 |
+
) -> None:
|
217 |
+
accelerator.print("===== Memory before validation =====")
|
218 |
+
print_memory(accelerator.device)
|
219 |
+
torch.cuda.synchronize(accelerator.device)
|
220 |
+
|
221 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
222 |
+
args.pretrained_model_name_or_path,
|
223 |
+
transformer=unwrap_model(accelerator, transformer),
|
224 |
+
scheduler=scheduler,
|
225 |
+
revision=args.revision,
|
226 |
+
variant=args.variant,
|
227 |
+
torch_dtype=weight_dtype,
|
228 |
+
)
|
229 |
+
|
230 |
+
if args.enable_slicing:
|
231 |
+
pipe.vae.enable_slicing()
|
232 |
+
if args.enable_tiling:
|
233 |
+
pipe.vae.enable_tiling()
|
234 |
+
if args.enable_model_cpu_offload:
|
235 |
+
pipe.enable_model_cpu_offload()
|
236 |
+
|
237 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
238 |
+
validation_images = args.validation_images.split(args.validation_prompt_separator)
|
239 |
+
for validation_image, validation_prompt in zip(validation_images, validation_prompts):
|
240 |
+
pipeline_args = {
|
241 |
+
"image": load_image(validation_image),
|
242 |
+
"prompt": validation_prompt,
|
243 |
+
"guidance_scale": args.guidance_scale,
|
244 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
245 |
+
"height": args.height,
|
246 |
+
"width": args.width,
|
247 |
+
"max_sequence_length": model_config.max_text_seq_length,
|
248 |
+
}
|
249 |
+
|
250 |
+
log_validation(
|
251 |
+
pipe=pipe,
|
252 |
+
args=args,
|
253 |
+
accelerator=accelerator,
|
254 |
+
pipeline_args=pipeline_args,
|
255 |
+
)
|
256 |
+
|
257 |
+
accelerator.print("===== Memory after validation =====")
|
258 |
+
print_memory(accelerator.device)
|
259 |
+
reset_memory(accelerator.device)
|
260 |
+
|
261 |
+
del pipe
|
262 |
+
gc.collect()
|
263 |
+
torch.cuda.empty_cache()
|
264 |
+
torch.cuda.synchronize(accelerator.device)
|
265 |
+
|
266 |
+
|
267 |
+
class CollateFunction:
|
268 |
+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
269 |
+
self.weight_dtype = weight_dtype
|
270 |
+
self.load_tensors = load_tensors
|
271 |
+
|
272 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
273 |
+
prompts = [x["prompt"] for x in data[0]]
|
274 |
+
|
275 |
+
if self.load_tensors:
|
276 |
+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
277 |
+
|
278 |
+
images = [x["image"] for x in data[0]]
|
279 |
+
images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
|
280 |
+
|
281 |
+
videos = [x["video"] for x in data[0]]
|
282 |
+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
283 |
+
|
284 |
+
return {
|
285 |
+
"images": images,
|
286 |
+
"videos": videos,
|
287 |
+
"prompts": prompts,
|
288 |
+
}
|
289 |
+
|
290 |
+
|
291 |
+
def main(args):
|
292 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
293 |
+
raise ValueError(
|
294 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
295 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
296 |
+
)
|
297 |
+
|
298 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
299 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
300 |
+
raise ValueError(
|
301 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
302 |
+
)
|
303 |
+
|
304 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
305 |
+
|
306 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
307 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
308 |
+
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
309 |
+
accelerator = Accelerator(
|
310 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
311 |
+
mixed_precision=args.mixed_precision,
|
312 |
+
log_with=args.report_to,
|
313 |
+
project_config=accelerator_project_config,
|
314 |
+
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
315 |
+
)
|
316 |
+
|
317 |
+
# Disable AMP for MPS.
|
318 |
+
if torch.backends.mps.is_available():
|
319 |
+
accelerator.native_amp = False
|
320 |
+
|
321 |
+
# Make one log on every process with the configuration for debugging.
|
322 |
+
logging.basicConfig(
|
323 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
324 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
325 |
+
level=logging.INFO,
|
326 |
+
)
|
327 |
+
logger.info(accelerator.state, main_process_only=False)
|
328 |
+
if accelerator.is_local_main_process:
|
329 |
+
transformers.utils.logging.set_verbosity_warning()
|
330 |
+
diffusers.utils.logging.set_verbosity_info()
|
331 |
+
else:
|
332 |
+
transformers.utils.logging.set_verbosity_error()
|
333 |
+
diffusers.utils.logging.set_verbosity_error()
|
334 |
+
|
335 |
+
# If passed along, set the training seed now.
|
336 |
+
if args.seed is not None:
|
337 |
+
set_seed(args.seed)
|
338 |
+
|
339 |
+
# Handle the repository creation
|
340 |
+
if accelerator.is_main_process:
|
341 |
+
if args.output_dir is not None:
|
342 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
343 |
+
|
344 |
+
if args.push_to_hub:
|
345 |
+
repo_id = create_repo(
|
346 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
347 |
+
exist_ok=True,
|
348 |
+
).repo_id
|
349 |
+
|
350 |
+
# Prepare models and scheduler
|
351 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
352 |
+
args.pretrained_model_name_or_path,
|
353 |
+
subfolder="tokenizer",
|
354 |
+
revision=args.revision,
|
355 |
+
)
|
356 |
+
|
357 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
358 |
+
args.pretrained_model_name_or_path,
|
359 |
+
subfolder="text_encoder",
|
360 |
+
revision=args.revision,
|
361 |
+
)
|
362 |
+
|
363 |
+
# CogVideoX-2b weights are stored in float16
|
364 |
+
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
365 |
+
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
366 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
367 |
+
args.pretrained_model_name_or_path,
|
368 |
+
subfolder="transformer",
|
369 |
+
torch_dtype=load_dtype,
|
370 |
+
revision=args.revision,
|
371 |
+
variant=args.variant,
|
372 |
+
)
|
373 |
+
|
374 |
+
# These changes will also be required when trying to run inference with the trained lora
|
375 |
+
if args.ignore_learned_positional_embeddings:
|
376 |
+
del transformer.patch_embed.pos_embedding
|
377 |
+
transformer.patch_embed.use_learned_positional_embeddings = False
|
378 |
+
transformer.config.use_learned_positional_embeddings = False
|
379 |
+
|
380 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
381 |
+
args.pretrained_model_name_or_path,
|
382 |
+
subfolder="vae",
|
383 |
+
revision=args.revision,
|
384 |
+
variant=args.variant,
|
385 |
+
)
|
386 |
+
|
387 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
388 |
+
|
389 |
+
if args.enable_slicing:
|
390 |
+
vae.enable_slicing()
|
391 |
+
if args.enable_tiling:
|
392 |
+
vae.enable_tiling()
|
393 |
+
|
394 |
+
# We only train the additional adapter LoRA layers
|
395 |
+
text_encoder.requires_grad_(False)
|
396 |
+
transformer.requires_grad_(False)
|
397 |
+
vae.requires_grad_(False)
|
398 |
+
|
399 |
+
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
400 |
+
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
401 |
+
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
402 |
+
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
403 |
+
|
404 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
405 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
406 |
+
weight_dtype = torch.float32
|
407 |
+
if accelerator.state.deepspeed_plugin:
|
408 |
+
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
409 |
+
if (
|
410 |
+
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
411 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
412 |
+
):
|
413 |
+
weight_dtype = torch.float16
|
414 |
+
if (
|
415 |
+
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
416 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
417 |
+
):
|
418 |
+
weight_dtype = torch.bfloat16
|
419 |
+
else:
|
420 |
+
if accelerator.mixed_precision == "fp16":
|
421 |
+
weight_dtype = torch.float16
|
422 |
+
elif accelerator.mixed_precision == "bf16":
|
423 |
+
weight_dtype = torch.bfloat16
|
424 |
+
|
425 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
426 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
427 |
+
raise ValueError(
|
428 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
429 |
+
)
|
430 |
+
|
431 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
432 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
433 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
434 |
+
|
435 |
+
if args.gradient_checkpointing:
|
436 |
+
transformer.enable_gradient_checkpointing()
|
437 |
+
|
438 |
+
# now we will add new LoRA weights to the attention layers
|
439 |
+
transformer_lora_config = LoraConfig(
|
440 |
+
r=args.rank,
|
441 |
+
lora_alpha=args.lora_alpha,
|
442 |
+
init_lora_weights=True,
|
443 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
444 |
+
)
|
445 |
+
transformer.add_adapter(transformer_lora_config)
|
446 |
+
|
447 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
448 |
+
def save_model_hook(models, weights, output_dir):
|
449 |
+
if accelerator.is_main_process:
|
450 |
+
transformer_lora_layers_to_save = None
|
451 |
+
|
452 |
+
for model in models:
|
453 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
454 |
+
model = unwrap_model(accelerator, model)
|
455 |
+
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
456 |
+
else:
|
457 |
+
raise ValueError(f"Unexpected save model: {model.__class__}")
|
458 |
+
|
459 |
+
# make sure to pop weight so that corresponding model is not saved again
|
460 |
+
if weights:
|
461 |
+
weights.pop()
|
462 |
+
|
463 |
+
CogVideoXImageToVideoPipeline.save_lora_weights(
|
464 |
+
output_dir,
|
465 |
+
transformer_lora_layers=transformer_lora_layers_to_save,
|
466 |
+
)
|
467 |
+
|
468 |
+
def load_model_hook(models, input_dir):
|
469 |
+
transformer_ = None
|
470 |
+
|
471 |
+
# This is a bit of a hack but I don't know any other solution.
|
472 |
+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
473 |
+
while len(models) > 0:
|
474 |
+
model = models.pop()
|
475 |
+
|
476 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
477 |
+
transformer_ = unwrap_model(accelerator, model)
|
478 |
+
else:
|
479 |
+
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
480 |
+
else:
|
481 |
+
transformer_ = CogVideoXTransformer3DModel.from_pretrained(
|
482 |
+
args.pretrained_model_name_or_path, subfolder="transformer"
|
483 |
+
)
|
484 |
+
transformer_.add_adapter(transformer_lora_config)
|
485 |
+
|
486 |
+
lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
|
487 |
+
|
488 |
+
transformer_state_dict = {
|
489 |
+
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
490 |
+
}
|
491 |
+
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
492 |
+
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
493 |
+
if incompatible_keys is not None:
|
494 |
+
# check only for unexpected keys
|
495 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
496 |
+
if unexpected_keys:
|
497 |
+
logger.warning(
|
498 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
499 |
+
f" {unexpected_keys}. "
|
500 |
+
)
|
501 |
+
|
502 |
+
# Make sure the trainable params are in float32. This is again needed since the base models
|
503 |
+
# are in `weight_dtype`. More details:
|
504 |
+
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
505 |
+
if args.mixed_precision == "fp16":
|
506 |
+
# only upcast trainable parameters (LoRA) into fp32
|
507 |
+
cast_training_params([transformer_])
|
508 |
+
|
509 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
510 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
511 |
+
|
512 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
513 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
514 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
515 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
516 |
+
|
517 |
+
if args.scale_lr:
|
518 |
+
args.learning_rate = (
|
519 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
520 |
+
)
|
521 |
+
|
522 |
+
# Make sure the trainable params are in float32.
|
523 |
+
if args.mixed_precision == "fp16":
|
524 |
+
# only upcast trainable parameters (LoRA) into fp32
|
525 |
+
cast_training_params([transformer], dtype=torch.float32)
|
526 |
+
|
527 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
528 |
+
|
529 |
+
# Optimization parameters
|
530 |
+
transformer_parameters_with_lr = {
|
531 |
+
"params": transformer_lora_parameters,
|
532 |
+
"lr": args.learning_rate,
|
533 |
+
}
|
534 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
535 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
536 |
+
|
537 |
+
use_deepspeed_optimizer = (
|
538 |
+
accelerator.state.deepspeed_plugin is not None
|
539 |
+
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
540 |
+
)
|
541 |
+
use_deepspeed_scheduler = (
|
542 |
+
accelerator.state.deepspeed_plugin is not None
|
543 |
+
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
544 |
+
)
|
545 |
+
|
546 |
+
optimizer = get_optimizer(
|
547 |
+
params_to_optimize=params_to_optimize,
|
548 |
+
optimizer_name=args.optimizer,
|
549 |
+
learning_rate=args.learning_rate,
|
550 |
+
beta1=args.beta1,
|
551 |
+
beta2=args.beta2,
|
552 |
+
beta3=args.beta3,
|
553 |
+
epsilon=args.epsilon,
|
554 |
+
weight_decay=args.weight_decay,
|
555 |
+
prodigy_decouple=args.prodigy_decouple,
|
556 |
+
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
557 |
+
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
558 |
+
use_8bit=args.use_8bit,
|
559 |
+
use_4bit=args.use_4bit,
|
560 |
+
use_torchao=args.use_torchao,
|
561 |
+
use_deepspeed=use_deepspeed_optimizer,
|
562 |
+
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
563 |
+
offload_gradients=args.offload_gradients,
|
564 |
+
)
|
565 |
+
|
566 |
+
# Dataset and DataLoader
|
567 |
+
dataset_init_kwargs = {
|
568 |
+
"data_root": args.data_root,
|
569 |
+
"dataset_file": args.dataset_file,
|
570 |
+
"caption_column": args.caption_column,
|
571 |
+
"video_column": args.video_column,
|
572 |
+
"max_num_frames": args.max_num_frames,
|
573 |
+
"id_token": args.id_token,
|
574 |
+
"height_buckets": args.height_buckets,
|
575 |
+
"width_buckets": args.width_buckets,
|
576 |
+
"frame_buckets": args.frame_buckets,
|
577 |
+
"load_tensors": args.load_tensors,
|
578 |
+
"random_flip": args.random_flip,
|
579 |
+
"image_to_video": True,
|
580 |
+
}
|
581 |
+
if args.video_reshape_mode is None:
|
582 |
+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
583 |
+
else:
|
584 |
+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
585 |
+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
586 |
+
)
|
587 |
+
|
588 |
+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
589 |
+
|
590 |
+
train_dataloader = DataLoader(
|
591 |
+
train_dataset,
|
592 |
+
batch_size=1,
|
593 |
+
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
594 |
+
collate_fn=collate_fn,
|
595 |
+
num_workers=args.dataloader_num_workers,
|
596 |
+
pin_memory=args.pin_memory,
|
597 |
+
)
|
598 |
+
|
599 |
+
# Scheduler and math around the number of training steps.
|
600 |
+
overrode_max_train_steps = False
|
601 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
602 |
+
if args.max_train_steps is None:
|
603 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
604 |
+
overrode_max_train_steps = True
|
605 |
+
|
606 |
+
if args.use_cpu_offload_optimizer:
|
607 |
+
lr_scheduler = None
|
608 |
+
accelerator.print(
|
609 |
+
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
610 |
+
"you are training with those settings, they will be ignored."
|
611 |
+
)
|
612 |
+
else:
|
613 |
+
if use_deepspeed_scheduler:
|
614 |
+
from accelerate.utils import DummyScheduler
|
615 |
+
|
616 |
+
lr_scheduler = DummyScheduler(
|
617 |
+
name=args.lr_scheduler,
|
618 |
+
optimizer=optimizer,
|
619 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
620 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
621 |
+
)
|
622 |
+
else:
|
623 |
+
lr_scheduler = get_scheduler(
|
624 |
+
args.lr_scheduler,
|
625 |
+
optimizer=optimizer,
|
626 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
627 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
628 |
+
num_cycles=args.lr_num_cycles,
|
629 |
+
power=args.lr_power,
|
630 |
+
)
|
631 |
+
|
632 |
+
# Prepare everything with our `accelerator`.
|
633 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
634 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
635 |
+
)
|
636 |
+
|
637 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
638 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
639 |
+
if overrode_max_train_steps:
|
640 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
641 |
+
# Afterwards we recalculate our number of training epochs
|
642 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
643 |
+
|
644 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
645 |
+
# The trackers initializes automatically on the main process.
|
646 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
647 |
+
tracker_name = args.tracker_name or "cogvideox-lora"
|
648 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
649 |
+
|
650 |
+
accelerator.print("===== Memory before training =====")
|
651 |
+
reset_memory(accelerator.device)
|
652 |
+
print_memory(accelerator.device)
|
653 |
+
|
654 |
+
# Train!
|
655 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
656 |
+
|
657 |
+
accelerator.print("***** Running training *****")
|
658 |
+
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
659 |
+
accelerator.print(f" Num examples = {len(train_dataset)}")
|
660 |
+
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
661 |
+
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
662 |
+
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
663 |
+
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
664 |
+
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
665 |
+
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
666 |
+
global_step = 0
|
667 |
+
first_epoch = 0
|
668 |
+
|
669 |
+
# Potentially load in the weights and states from a previous save
|
670 |
+
if not args.resume_from_checkpoint:
|
671 |
+
initial_global_step = 0
|
672 |
+
else:
|
673 |
+
if args.resume_from_checkpoint != "latest":
|
674 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
675 |
+
else:
|
676 |
+
# Get the most recent checkpoint
|
677 |
+
dirs = os.listdir(args.output_dir)
|
678 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
679 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
680 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
681 |
+
|
682 |
+
if path is None:
|
683 |
+
accelerator.print(
|
684 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
685 |
+
)
|
686 |
+
args.resume_from_checkpoint = None
|
687 |
+
initial_global_step = 0
|
688 |
+
else:
|
689 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
690 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
691 |
+
global_step = int(path.split("-")[1])
|
692 |
+
|
693 |
+
initial_global_step = global_step
|
694 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
695 |
+
|
696 |
+
progress_bar = tqdm(
|
697 |
+
range(0, args.max_train_steps),
|
698 |
+
initial=initial_global_step,
|
699 |
+
desc="Steps",
|
700 |
+
# Only show the progress bar once on each machine.
|
701 |
+
disable=not accelerator.is_local_main_process,
|
702 |
+
)
|
703 |
+
|
704 |
+
# For DeepSpeed training
|
705 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
706 |
+
|
707 |
+
if args.load_tensors:
|
708 |
+
del vae, text_encoder
|
709 |
+
gc.collect()
|
710 |
+
torch.cuda.empty_cache()
|
711 |
+
torch.cuda.synchronize(accelerator.device)
|
712 |
+
|
713 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
714 |
+
|
715 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
716 |
+
transformer.train()
|
717 |
+
|
718 |
+
for step, batch in enumerate(train_dataloader):
|
719 |
+
models_to_accumulate = [transformer]
|
720 |
+
logs = {}
|
721 |
+
|
722 |
+
with accelerator.accumulate(models_to_accumulate):
|
723 |
+
images = batch["images"].to(accelerator.device, non_blocking=True)
|
724 |
+
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
725 |
+
prompts = batch["prompts"]
|
726 |
+
|
727 |
+
# Encode videos
|
728 |
+
if not args.load_tensors:
|
729 |
+
images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
730 |
+
image_noise_sigma = torch.normal(
|
731 |
+
mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
|
732 |
+
)
|
733 |
+
image_noise_sigma = torch.exp(image_noise_sigma)
|
734 |
+
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
735 |
+
image_latent_dist = vae.encode(noisy_images).latent_dist
|
736 |
+
|
737 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
738 |
+
latent_dist = vae.encode(videos).latent_dist
|
739 |
+
else:
|
740 |
+
image_latent_dist = DiagonalGaussianDistribution(images)
|
741 |
+
latent_dist = DiagonalGaussianDistribution(videos)
|
742 |
+
|
743 |
+
image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
|
744 |
+
image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
745 |
+
image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
746 |
+
|
747 |
+
video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
|
748 |
+
video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
749 |
+
video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
750 |
+
|
751 |
+
padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
|
752 |
+
latent_padding = image_latents.new_zeros(padding_shape)
|
753 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
754 |
+
|
755 |
+
if random.random() < args.noised_image_dropout:
|
756 |
+
image_latents = torch.zeros_like(image_latents)
|
757 |
+
|
758 |
+
# Encode prompts
|
759 |
+
if not args.load_tensors:
|
760 |
+
prompt_embeds = compute_prompt_embeddings(
|
761 |
+
tokenizer,
|
762 |
+
text_encoder,
|
763 |
+
prompts,
|
764 |
+
model_config.max_text_seq_length,
|
765 |
+
accelerator.device,
|
766 |
+
weight_dtype,
|
767 |
+
requires_grad=False,
|
768 |
+
)
|
769 |
+
else:
|
770 |
+
prompt_embeds = prompts.to(dtype=weight_dtype)
|
771 |
+
|
772 |
+
# Sample noise that will be added to the latents
|
773 |
+
noise = torch.randn_like(video_latents)
|
774 |
+
batch_size, num_frames, num_channels, height, width = video_latents.shape
|
775 |
+
|
776 |
+
# Sample a random timestep for each image
|
777 |
+
timesteps = torch.randint(
|
778 |
+
0,
|
779 |
+
scheduler.config.num_train_timesteps,
|
780 |
+
(batch_size,),
|
781 |
+
dtype=torch.int64,
|
782 |
+
device=accelerator.device,
|
783 |
+
)
|
784 |
+
|
785 |
+
# Prepare rotary embeds
|
786 |
+
image_rotary_emb = (
|
787 |
+
prepare_rotary_positional_embeddings(
|
788 |
+
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
789 |
+
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
790 |
+
num_frames=num_frames,
|
791 |
+
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
792 |
+
patch_size=model_config.patch_size,
|
793 |
+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
794 |
+
attention_head_dim=model_config.attention_head_dim,
|
795 |
+
device=accelerator.device,
|
796 |
+
base_height=RoPE_BASE_HEIGHT,
|
797 |
+
base_width=RoPE_BASE_WIDTH,
|
798 |
+
)
|
799 |
+
if model_config.use_rotary_positional_embeddings
|
800 |
+
else None
|
801 |
+
)
|
802 |
+
|
803 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
804 |
+
# (this is the forward diffusion process)
|
805 |
+
noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
|
806 |
+
noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
|
807 |
+
|
808 |
+
ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None,
|
809 |
+
ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
|
810 |
+
# Predict the noise residual
|
811 |
+
model_output = transformer(
|
812 |
+
hidden_states=noisy_model_input,
|
813 |
+
encoder_hidden_states=prompt_embeds,
|
814 |
+
timestep=timesteps,
|
815 |
+
ofs=ofs_emb,
|
816 |
+
image_rotary_emb=image_rotary_emb,
|
817 |
+
return_dict=False,
|
818 |
+
)[0]
|
819 |
+
|
820 |
+
model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
|
821 |
+
|
822 |
+
weights = 1 / (1 - alphas_cumprod[timesteps])
|
823 |
+
while len(weights.shape) < len(model_pred.shape):
|
824 |
+
weights = weights.unsqueeze(-1)
|
825 |
+
|
826 |
+
target = video_latents
|
827 |
+
|
828 |
+
loss = torch.mean(
|
829 |
+
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
830 |
+
dim=1,
|
831 |
+
)
|
832 |
+
loss = loss.mean()
|
833 |
+
accelerator.backward(loss)
|
834 |
+
|
835 |
+
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
836 |
+
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
837 |
+
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
838 |
+
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
839 |
+
logs.update(
|
840 |
+
{
|
841 |
+
"gradient_norm_before_clip": gradient_norm_before_clip,
|
842 |
+
"gradient_norm_after_clip": gradient_norm_after_clip,
|
843 |
+
}
|
844 |
+
)
|
845 |
+
|
846 |
+
if accelerator.state.deepspeed_plugin is None:
|
847 |
+
optimizer.step()
|
848 |
+
optimizer.zero_grad()
|
849 |
+
|
850 |
+
if not args.use_cpu_offload_optimizer:
|
851 |
+
lr_scheduler.step()
|
852 |
+
|
853 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
854 |
+
if accelerator.sync_gradients:
|
855 |
+
progress_bar.update(1)
|
856 |
+
global_step += 1
|
857 |
+
|
858 |
+
# Checkpointing
|
859 |
+
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
860 |
+
if global_step % args.checkpointing_steps == 0:
|
861 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
862 |
+
if args.checkpoints_total_limit is not None:
|
863 |
+
checkpoints = os.listdir(args.output_dir)
|
864 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
865 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
866 |
+
|
867 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
868 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
869 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
870 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
871 |
+
|
872 |
+
logger.info(
|
873 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
874 |
+
)
|
875 |
+
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
876 |
+
|
877 |
+
for removing_checkpoint in removing_checkpoints:
|
878 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
879 |
+
shutil.rmtree(removing_checkpoint)
|
880 |
+
|
881 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
882 |
+
accelerator.save_state(save_path)
|
883 |
+
logger.info(f"Saved state to {save_path}")
|
884 |
+
|
885 |
+
# Validation
|
886 |
+
should_run_validation = args.validation_prompt is not None and (
|
887 |
+
args.validation_steps is not None and global_step % args.validation_steps == 0
|
888 |
+
)
|
889 |
+
if should_run_validation:
|
890 |
+
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
891 |
+
|
892 |
+
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
893 |
+
logs.update(
|
894 |
+
{
|
895 |
+
"loss": loss.detach().item(),
|
896 |
+
"lr": last_lr,
|
897 |
+
}
|
898 |
+
)
|
899 |
+
progress_bar.set_postfix(**logs)
|
900 |
+
accelerator.log(logs, step=global_step)
|
901 |
+
|
902 |
+
if global_step >= args.max_train_steps:
|
903 |
+
break
|
904 |
+
|
905 |
+
if accelerator.is_main_process:
|
906 |
+
should_run_validation = args.validation_prompt is not None and (
|
907 |
+
args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
|
908 |
+
)
|
909 |
+
if should_run_validation:
|
910 |
+
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
911 |
+
|
912 |
+
accelerator.wait_for_everyone()
|
913 |
+
|
914 |
+
if accelerator.is_main_process:
|
915 |
+
transformer = unwrap_model(accelerator, transformer)
|
916 |
+
dtype = (
|
917 |
+
torch.float16
|
918 |
+
if args.mixed_precision == "fp16"
|
919 |
+
else torch.bfloat16
|
920 |
+
if args.mixed_precision == "bf16"
|
921 |
+
else torch.float32
|
922 |
+
)
|
923 |
+
transformer = transformer.to(dtype)
|
924 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
925 |
+
|
926 |
+
CogVideoXImageToVideoPipeline.save_lora_weights(
|
927 |
+
save_directory=args.output_dir,
|
928 |
+
transformer_lora_layers=transformer_lora_layers,
|
929 |
+
)
|
930 |
+
|
931 |
+
# Cleanup trained models to save memory
|
932 |
+
if args.load_tensors:
|
933 |
+
del transformer
|
934 |
+
else:
|
935 |
+
del transformer, text_encoder, vae
|
936 |
+
|
937 |
+
gc.collect()
|
938 |
+
torch.cuda.empty_cache()
|
939 |
+
torch.cuda.synchronize(accelerator.device)
|
940 |
+
|
941 |
+
accelerator.print("===== Memory before testing =====")
|
942 |
+
print_memory(accelerator.device)
|
943 |
+
reset_memory(accelerator.device)
|
944 |
+
|
945 |
+
# Final test inference
|
946 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
947 |
+
args.pretrained_model_name_or_path,
|
948 |
+
revision=args.revision,
|
949 |
+
variant=args.variant,
|
950 |
+
torch_dtype=weight_dtype,
|
951 |
+
)
|
952 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
953 |
+
|
954 |
+
if args.enable_slicing:
|
955 |
+
pipe.vae.enable_slicing()
|
956 |
+
if args.enable_tiling:
|
957 |
+
pipe.vae.enable_tiling()
|
958 |
+
if args.enable_model_cpu_offload:
|
959 |
+
pipe.enable_model_cpu_offload()
|
960 |
+
|
961 |
+
# Load LoRA weights
|
962 |
+
lora_scaling = args.lora_alpha / args.rank
|
963 |
+
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
964 |
+
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
|
965 |
+
|
966 |
+
# Run inference
|
967 |
+
validation_outputs = []
|
968 |
+
if args.validation_prompt and args.num_validation_videos > 0:
|
969 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
970 |
+
validation_images = args.validation_images.split(args.validation_prompt_separator)
|
971 |
+
for validation_image, validation_prompt in zip(validation_images, validation_prompts):
|
972 |
+
pipeline_args = {
|
973 |
+
"image": load_image(validation_image),
|
974 |
+
"prompt": validation_prompt,
|
975 |
+
"guidance_scale": args.guidance_scale,
|
976 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
977 |
+
"height": args.height,
|
978 |
+
"width": args.width,
|
979 |
+
}
|
980 |
+
|
981 |
+
video = log_validation(
|
982 |
+
accelerator=accelerator,
|
983 |
+
pipe=pipe,
|
984 |
+
args=args,
|
985 |
+
pipeline_args=pipeline_args,
|
986 |
+
is_final_validation=True,
|
987 |
+
)
|
988 |
+
validation_outputs.extend(video)
|
989 |
+
|
990 |
+
accelerator.print("===== Memory after testing =====")
|
991 |
+
print_memory(accelerator.device)
|
992 |
+
reset_memory(accelerator.device)
|
993 |
+
torch.cuda.synchronize(accelerator.device)
|
994 |
+
|
995 |
+
if args.push_to_hub:
|
996 |
+
save_model_card(
|
997 |
+
repo_id,
|
998 |
+
videos=validation_outputs,
|
999 |
+
base_model=args.pretrained_model_name_or_path,
|
1000 |
+
validation_prompt=args.validation_prompt,
|
1001 |
+
repo_folder=args.output_dir,
|
1002 |
+
fps=args.fps,
|
1003 |
+
)
|
1004 |
+
upload_folder(
|
1005 |
+
repo_id=repo_id,
|
1006 |
+
folder_path=args.output_dir,
|
1007 |
+
commit_message="End of training",
|
1008 |
+
ignore_patterns=["step_*", "epoch_*"],
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
accelerator.end_training()
|
1012 |
+
|
1013 |
+
|
1014 |
+
if __name__ == "__main__":
|
1015 |
+
args = get_args()
|
1016 |
+
main(args)
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_image_to_video_sft.py
ADDED
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
import shutil
|
22 |
+
from datetime import timedelta
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Dict
|
25 |
+
|
26 |
+
import diffusers
|
27 |
+
import torch
|
28 |
+
import transformers
|
29 |
+
import wandb
|
30 |
+
from accelerate import Accelerator, DistributedType, init_empty_weights
|
31 |
+
from accelerate.logging import get_logger
|
32 |
+
from accelerate.utils import (
|
33 |
+
DistributedDataParallelKwargs,
|
34 |
+
InitProcessGroupKwargs,
|
35 |
+
ProjectConfiguration,
|
36 |
+
set_seed,
|
37 |
+
)
|
38 |
+
from diffusers import (
|
39 |
+
AutoencoderKLCogVideoX,
|
40 |
+
CogVideoXDPMScheduler,
|
41 |
+
CogVideoXImageToVideoPipeline,
|
42 |
+
CogVideoXTransformer3DModel,
|
43 |
+
)
|
44 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
45 |
+
from diffusers.optimization import get_scheduler
|
46 |
+
from diffusers.training_utils import cast_training_params
|
47 |
+
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
|
48 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
49 |
+
from huggingface_hub import create_repo, upload_folder
|
50 |
+
from torch.utils.data import DataLoader
|
51 |
+
from tqdm.auto import tqdm
|
52 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
53 |
+
|
54 |
+
|
55 |
+
from args import get_args # isort:skip
|
56 |
+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
57 |
+
from text_encoder import compute_prompt_embeddings # isort:skip
|
58 |
+
from utils import (
|
59 |
+
get_gradient_norm,
|
60 |
+
get_optimizer,
|
61 |
+
prepare_rotary_positional_embeddings,
|
62 |
+
print_memory,
|
63 |
+
reset_memory,
|
64 |
+
unwrap_model,
|
65 |
+
)
|
66 |
+
|
67 |
+
|
68 |
+
logger = get_logger(__name__)
|
69 |
+
|
70 |
+
|
71 |
+
def save_model_card(
|
72 |
+
repo_id: str,
|
73 |
+
videos=None,
|
74 |
+
base_model: str = None,
|
75 |
+
validation_prompt=None,
|
76 |
+
repo_folder=None,
|
77 |
+
fps=8,
|
78 |
+
):
|
79 |
+
widget_dict = []
|
80 |
+
if videos is not None:
|
81 |
+
for i, video in enumerate(videos):
|
82 |
+
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
83 |
+
widget_dict.append(
|
84 |
+
{
|
85 |
+
"text": validation_prompt if validation_prompt else " ",
|
86 |
+
"output": {"url": f"video_{i}.mp4"},
|
87 |
+
}
|
88 |
+
)
|
89 |
+
|
90 |
+
model_description = f"""
|
91 |
+
# CogVideoX Full Finetune
|
92 |
+
|
93 |
+
<Gallery />
|
94 |
+
|
95 |
+
## Model description
|
96 |
+
|
97 |
+
This is a full finetune of the CogVideoX model `{base_model}`.
|
98 |
+
|
99 |
+
## License
|
100 |
+
|
101 |
+
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
|
102 |
+
"""
|
103 |
+
model_card = load_or_create_model_card(
|
104 |
+
repo_id_or_path=repo_id,
|
105 |
+
from_training=True,
|
106 |
+
license="other",
|
107 |
+
base_model=base_model,
|
108 |
+
prompt=validation_prompt,
|
109 |
+
model_description=model_description,
|
110 |
+
widget=widget_dict,
|
111 |
+
)
|
112 |
+
tags = [
|
113 |
+
"text-to-video",
|
114 |
+
"image-to-video",
|
115 |
+
"diffusers-training",
|
116 |
+
"diffusers",
|
117 |
+
"cogvideox",
|
118 |
+
"cogvideox-diffusers",
|
119 |
+
]
|
120 |
+
|
121 |
+
model_card = populate_model_card(model_card, tags=tags)
|
122 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
123 |
+
|
124 |
+
|
125 |
+
def log_validation(
|
126 |
+
accelerator: Accelerator,
|
127 |
+
pipe: CogVideoXImageToVideoPipeline,
|
128 |
+
args: Dict[str, Any],
|
129 |
+
pipeline_args: Dict[str, Any],
|
130 |
+
is_final_validation: bool = False,
|
131 |
+
):
|
132 |
+
logger.info(
|
133 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
134 |
+
)
|
135 |
+
|
136 |
+
pipe = pipe.to(accelerator.device)
|
137 |
+
|
138 |
+
# run inference
|
139 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
140 |
+
|
141 |
+
videos = []
|
142 |
+
for _ in range(args.num_validation_videos):
|
143 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
144 |
+
videos.append(video)
|
145 |
+
|
146 |
+
for tracker in accelerator.trackers:
|
147 |
+
phase_name = "test" if is_final_validation else "validation"
|
148 |
+
if tracker.name == "wandb":
|
149 |
+
video_filenames = []
|
150 |
+
for i, video in enumerate(videos):
|
151 |
+
prompt = (
|
152 |
+
pipeline_args["prompt"][:25]
|
153 |
+
.replace(" ", "_")
|
154 |
+
.replace(" ", "_")
|
155 |
+
.replace("'", "_")
|
156 |
+
.replace('"', "_")
|
157 |
+
.replace("/", "_")
|
158 |
+
)
|
159 |
+
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
160 |
+
export_to_video(video, filename, fps=8)
|
161 |
+
video_filenames.append(filename)
|
162 |
+
|
163 |
+
tracker.log(
|
164 |
+
{
|
165 |
+
phase_name: [
|
166 |
+
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
167 |
+
for i, filename in enumerate(video_filenames)
|
168 |
+
]
|
169 |
+
}
|
170 |
+
)
|
171 |
+
|
172 |
+
return videos
|
173 |
+
|
174 |
+
|
175 |
+
def run_validation(
|
176 |
+
args: Dict[str, Any],
|
177 |
+
accelerator: Accelerator,
|
178 |
+
transformer,
|
179 |
+
scheduler,
|
180 |
+
model_config: Dict[str, Any],
|
181 |
+
weight_dtype: torch.dtype,
|
182 |
+
) -> None:
|
183 |
+
accelerator.print("===== Memory before validation =====")
|
184 |
+
print_memory(accelerator.device)
|
185 |
+
torch.cuda.synchronize(accelerator.device)
|
186 |
+
|
187 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
188 |
+
args.pretrained_model_name_or_path,
|
189 |
+
transformer=unwrap_model(accelerator, transformer),
|
190 |
+
scheduler=scheduler,
|
191 |
+
revision=args.revision,
|
192 |
+
variant=args.variant,
|
193 |
+
torch_dtype=weight_dtype,
|
194 |
+
)
|
195 |
+
|
196 |
+
if args.enable_slicing:
|
197 |
+
pipe.vae.enable_slicing()
|
198 |
+
if args.enable_tiling:
|
199 |
+
pipe.vae.enable_tiling()
|
200 |
+
if args.enable_model_cpu_offload:
|
201 |
+
pipe.enable_model_cpu_offload()
|
202 |
+
|
203 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
204 |
+
validation_images = args.validation_images.split(args.validation_prompt_separator)
|
205 |
+
for validation_image, validation_prompt in zip(validation_images, validation_prompts):
|
206 |
+
pipeline_args = {
|
207 |
+
"image": load_image(validation_image),
|
208 |
+
"prompt": validation_prompt,
|
209 |
+
"guidance_scale": args.guidance_scale,
|
210 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
211 |
+
"height": args.height,
|
212 |
+
"width": args.width,
|
213 |
+
"max_sequence_length": model_config.max_text_seq_length,
|
214 |
+
}
|
215 |
+
|
216 |
+
log_validation(
|
217 |
+
pipe=pipe,
|
218 |
+
args=args,
|
219 |
+
accelerator=accelerator,
|
220 |
+
pipeline_args=pipeline_args,
|
221 |
+
)
|
222 |
+
|
223 |
+
accelerator.print("===== Memory after validation =====")
|
224 |
+
print_memory(accelerator.device)
|
225 |
+
reset_memory(accelerator.device)
|
226 |
+
|
227 |
+
del pipe
|
228 |
+
gc.collect()
|
229 |
+
torch.cuda.empty_cache()
|
230 |
+
torch.cuda.synchronize(accelerator.device)
|
231 |
+
|
232 |
+
|
233 |
+
class CollateFunction:
|
234 |
+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
235 |
+
self.weight_dtype = weight_dtype
|
236 |
+
self.load_tensors = load_tensors
|
237 |
+
|
238 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
239 |
+
prompts = [x["prompt"] for x in data[0]]
|
240 |
+
|
241 |
+
if self.load_tensors:
|
242 |
+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
243 |
+
|
244 |
+
images = [x["image"] for x in data[0]]
|
245 |
+
images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
|
246 |
+
|
247 |
+
videos = [x["video"] for x in data[0]]
|
248 |
+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
249 |
+
|
250 |
+
return {
|
251 |
+
"images": images,
|
252 |
+
"videos": videos,
|
253 |
+
"prompts": prompts,
|
254 |
+
}
|
255 |
+
|
256 |
+
|
257 |
+
def main(args):
|
258 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
259 |
+
raise ValueError(
|
260 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
261 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
262 |
+
)
|
263 |
+
|
264 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
265 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
266 |
+
raise ValueError(
|
267 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
268 |
+
)
|
269 |
+
|
270 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
271 |
+
|
272 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
273 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
274 |
+
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
275 |
+
accelerator = Accelerator(
|
276 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
277 |
+
mixed_precision=args.mixed_precision,
|
278 |
+
log_with=args.report_to,
|
279 |
+
project_config=accelerator_project_config,
|
280 |
+
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
281 |
+
)
|
282 |
+
|
283 |
+
# Disable AMP for MPS.
|
284 |
+
if torch.backends.mps.is_available():
|
285 |
+
accelerator.native_amp = False
|
286 |
+
|
287 |
+
# Make one log on every process with the configuration for debugging.
|
288 |
+
logging.basicConfig(
|
289 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
290 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
291 |
+
level=logging.INFO,
|
292 |
+
)
|
293 |
+
logger.info(accelerator.state, main_process_only=False)
|
294 |
+
if accelerator.is_local_main_process:
|
295 |
+
transformers.utils.logging.set_verbosity_warning()
|
296 |
+
diffusers.utils.logging.set_verbosity_info()
|
297 |
+
else:
|
298 |
+
transformers.utils.logging.set_verbosity_error()
|
299 |
+
diffusers.utils.logging.set_verbosity_error()
|
300 |
+
|
301 |
+
# If passed along, set the training seed now.
|
302 |
+
if args.seed is not None:
|
303 |
+
set_seed(args.seed)
|
304 |
+
|
305 |
+
# Handle the repository creation
|
306 |
+
if accelerator.is_main_process:
|
307 |
+
if args.output_dir is not None:
|
308 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
309 |
+
|
310 |
+
if args.push_to_hub:
|
311 |
+
repo_id = create_repo(
|
312 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
313 |
+
exist_ok=True,
|
314 |
+
).repo_id
|
315 |
+
|
316 |
+
# Prepare models and scheduler
|
317 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
318 |
+
args.pretrained_model_name_or_path,
|
319 |
+
subfolder="tokenizer",
|
320 |
+
revision=args.revision,
|
321 |
+
)
|
322 |
+
|
323 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
324 |
+
args.pretrained_model_name_or_path,
|
325 |
+
subfolder="text_encoder",
|
326 |
+
revision=args.revision,
|
327 |
+
)
|
328 |
+
|
329 |
+
# CogVideoX-2b weights are stored in float16
|
330 |
+
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
331 |
+
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
332 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
333 |
+
args.pretrained_model_name_or_path,
|
334 |
+
subfolder="transformer",
|
335 |
+
torch_dtype=load_dtype,
|
336 |
+
revision=args.revision,
|
337 |
+
variant=args.variant,
|
338 |
+
)
|
339 |
+
|
340 |
+
if args.ignore_learned_positional_embeddings:
|
341 |
+
del transformer.patch_embed.pos_embedding
|
342 |
+
transformer.patch_embed.use_learned_positional_embeddings = False
|
343 |
+
transformer.config.use_learned_positional_embeddings = False
|
344 |
+
|
345 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
346 |
+
args.pretrained_model_name_or_path,
|
347 |
+
subfolder="vae",
|
348 |
+
revision=args.revision,
|
349 |
+
variant=args.variant,
|
350 |
+
)
|
351 |
+
|
352 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
353 |
+
|
354 |
+
if args.enable_slicing:
|
355 |
+
vae.enable_slicing()
|
356 |
+
if args.enable_tiling:
|
357 |
+
vae.enable_tiling()
|
358 |
+
|
359 |
+
text_encoder.requires_grad_(False)
|
360 |
+
vae.requires_grad_(False)
|
361 |
+
transformer.requires_grad_(True)
|
362 |
+
|
363 |
+
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
364 |
+
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
365 |
+
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
366 |
+
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
367 |
+
|
368 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
369 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
370 |
+
weight_dtype = torch.float32
|
371 |
+
if accelerator.state.deepspeed_plugin:
|
372 |
+
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
373 |
+
if (
|
374 |
+
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
375 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
376 |
+
):
|
377 |
+
weight_dtype = torch.float16
|
378 |
+
if (
|
379 |
+
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
380 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
381 |
+
):
|
382 |
+
weight_dtype = torch.bfloat16
|
383 |
+
else:
|
384 |
+
if accelerator.mixed_precision == "fp16":
|
385 |
+
weight_dtype = torch.float16
|
386 |
+
elif accelerator.mixed_precision == "bf16":
|
387 |
+
weight_dtype = torch.bfloat16
|
388 |
+
|
389 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
390 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
391 |
+
raise ValueError(
|
392 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
393 |
+
)
|
394 |
+
|
395 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
396 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
397 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
398 |
+
|
399 |
+
if args.gradient_checkpointing:
|
400 |
+
transformer.enable_gradient_checkpointing()
|
401 |
+
|
402 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
403 |
+
def save_model_hook(models, weights, output_dir):
|
404 |
+
if accelerator.is_main_process:
|
405 |
+
for model in models:
|
406 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
407 |
+
model = unwrap_model(accelerator, model)
|
408 |
+
model.save_pretrained(
|
409 |
+
os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
raise ValueError(f"Unexpected save model: {model.__class__}")
|
413 |
+
|
414 |
+
# make sure to pop weight so that corresponding model is not saved again
|
415 |
+
if weights:
|
416 |
+
weights.pop()
|
417 |
+
|
418 |
+
def load_model_hook(models, input_dir):
|
419 |
+
transformer_ = None
|
420 |
+
init_under_meta = False
|
421 |
+
|
422 |
+
# This is a bit of a hack but I don't know any other solution.
|
423 |
+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
424 |
+
while len(models) > 0:
|
425 |
+
model = models.pop()
|
426 |
+
|
427 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
428 |
+
transformer_ = unwrap_model(accelerator, model)
|
429 |
+
else:
|
430 |
+
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
431 |
+
else:
|
432 |
+
with init_empty_weights():
|
433 |
+
transformer_ = CogVideoXTransformer3DModel.from_config(
|
434 |
+
args.pretrained_model_name_or_path, subfolder="transformer"
|
435 |
+
)
|
436 |
+
init_under_meta = True
|
437 |
+
|
438 |
+
load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
|
439 |
+
transformer_.register_to_config(**load_model.config)
|
440 |
+
transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
|
441 |
+
del load_model
|
442 |
+
|
443 |
+
# Make sure the trainable params are in float32. This is again needed since the base models
|
444 |
+
# are in `weight_dtype`. More details:
|
445 |
+
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
446 |
+
if args.mixed_precision == "fp16":
|
447 |
+
cast_training_params([transformer_])
|
448 |
+
|
449 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
450 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
451 |
+
|
452 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
453 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
454 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
455 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
456 |
+
|
457 |
+
if args.scale_lr:
|
458 |
+
args.learning_rate = (
|
459 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
460 |
+
)
|
461 |
+
|
462 |
+
# Make sure the trainable params are in float32.
|
463 |
+
if args.mixed_precision == "fp16":
|
464 |
+
cast_training_params([transformer], dtype=torch.float32)
|
465 |
+
|
466 |
+
transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
467 |
+
|
468 |
+
# Optimization parameters
|
469 |
+
transformer_parameters_with_lr = {
|
470 |
+
"params": transformer_parameters,
|
471 |
+
"lr": args.learning_rate,
|
472 |
+
}
|
473 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
474 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
475 |
+
|
476 |
+
use_deepspeed_optimizer = (
|
477 |
+
accelerator.state.deepspeed_plugin is not None
|
478 |
+
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
479 |
+
)
|
480 |
+
use_deepspeed_scheduler = (
|
481 |
+
accelerator.state.deepspeed_plugin is not None
|
482 |
+
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
483 |
+
)
|
484 |
+
|
485 |
+
optimizer = get_optimizer(
|
486 |
+
params_to_optimize=params_to_optimize,
|
487 |
+
optimizer_name=args.optimizer,
|
488 |
+
learning_rate=args.learning_rate,
|
489 |
+
beta1=args.beta1,
|
490 |
+
beta2=args.beta2,
|
491 |
+
beta3=args.beta3,
|
492 |
+
epsilon=args.epsilon,
|
493 |
+
weight_decay=args.weight_decay,
|
494 |
+
prodigy_decouple=args.prodigy_decouple,
|
495 |
+
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
496 |
+
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
497 |
+
use_8bit=args.use_8bit,
|
498 |
+
use_4bit=args.use_4bit,
|
499 |
+
use_torchao=args.use_torchao,
|
500 |
+
use_deepspeed=use_deepspeed_optimizer,
|
501 |
+
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
502 |
+
offload_gradients=args.offload_gradients,
|
503 |
+
)
|
504 |
+
|
505 |
+
# Dataset and DataLoader
|
506 |
+
dataset_init_kwargs = {
|
507 |
+
"data_root": args.data_root,
|
508 |
+
"dataset_file": args.dataset_file,
|
509 |
+
"caption_column": args.caption_column,
|
510 |
+
"video_column": args.video_column,
|
511 |
+
"max_num_frames": args.max_num_frames,
|
512 |
+
"id_token": args.id_token,
|
513 |
+
"height_buckets": args.height_buckets,
|
514 |
+
"width_buckets": args.width_buckets,
|
515 |
+
"frame_buckets": args.frame_buckets,
|
516 |
+
"load_tensors": args.load_tensors,
|
517 |
+
"random_flip": args.random_flip,
|
518 |
+
"image_to_video": True,
|
519 |
+
}
|
520 |
+
if args.video_reshape_mode is None:
|
521 |
+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
522 |
+
else:
|
523 |
+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
524 |
+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
525 |
+
)
|
526 |
+
|
527 |
+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
528 |
+
|
529 |
+
train_dataloader = DataLoader(
|
530 |
+
train_dataset,
|
531 |
+
batch_size=1,
|
532 |
+
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
533 |
+
collate_fn=collate_fn,
|
534 |
+
num_workers=args.dataloader_num_workers,
|
535 |
+
pin_memory=args.pin_memory,
|
536 |
+
)
|
537 |
+
|
538 |
+
# Scheduler and math around the number of training steps.
|
539 |
+
overrode_max_train_steps = False
|
540 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
541 |
+
if args.max_train_steps is None:
|
542 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
543 |
+
overrode_max_train_steps = True
|
544 |
+
|
545 |
+
if args.use_cpu_offload_optimizer:
|
546 |
+
lr_scheduler = None
|
547 |
+
accelerator.print(
|
548 |
+
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
549 |
+
"you are training with those settings, they will be ignored."
|
550 |
+
)
|
551 |
+
else:
|
552 |
+
if use_deepspeed_scheduler:
|
553 |
+
from accelerate.utils import DummyScheduler
|
554 |
+
|
555 |
+
lr_scheduler = DummyScheduler(
|
556 |
+
name=args.lr_scheduler,
|
557 |
+
optimizer=optimizer,
|
558 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
559 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
560 |
+
)
|
561 |
+
else:
|
562 |
+
lr_scheduler = get_scheduler(
|
563 |
+
args.lr_scheduler,
|
564 |
+
optimizer=optimizer,
|
565 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
566 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
567 |
+
num_cycles=args.lr_num_cycles,
|
568 |
+
power=args.lr_power,
|
569 |
+
)
|
570 |
+
|
571 |
+
# Prepare everything with our `accelerator`.
|
572 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
573 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
574 |
+
)
|
575 |
+
|
576 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
577 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
578 |
+
if overrode_max_train_steps:
|
579 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
580 |
+
# Afterwards we recalculate our number of training epochs
|
581 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
582 |
+
|
583 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
584 |
+
# The trackers initializes automatically on the main process.
|
585 |
+
if accelerator.is_main_process:
|
586 |
+
tracker_name = args.tracker_name or "cogvideox-sft"
|
587 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
588 |
+
|
589 |
+
accelerator.print("===== Memory before training =====")
|
590 |
+
reset_memory(accelerator.device)
|
591 |
+
print_memory(accelerator.device)
|
592 |
+
|
593 |
+
# Train!
|
594 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
595 |
+
|
596 |
+
accelerator.print("***** Running training *****")
|
597 |
+
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
598 |
+
accelerator.print(f" Num examples = {len(train_dataset)}")
|
599 |
+
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
600 |
+
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
601 |
+
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
602 |
+
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
603 |
+
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
604 |
+
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
605 |
+
global_step = 0
|
606 |
+
first_epoch = 0
|
607 |
+
|
608 |
+
# Potentially load in the weights and states from a previous save
|
609 |
+
if not args.resume_from_checkpoint:
|
610 |
+
initial_global_step = 0
|
611 |
+
else:
|
612 |
+
if args.resume_from_checkpoint != "latest":
|
613 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
614 |
+
else:
|
615 |
+
# Get the most recent checkpoint
|
616 |
+
dirs = os.listdir(args.output_dir)
|
617 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
618 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
619 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
620 |
+
|
621 |
+
if path is None:
|
622 |
+
accelerator.print(
|
623 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
624 |
+
)
|
625 |
+
args.resume_from_checkpoint = None
|
626 |
+
initial_global_step = 0
|
627 |
+
else:
|
628 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
629 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
630 |
+
global_step = int(path.split("-")[1])
|
631 |
+
|
632 |
+
initial_global_step = global_step
|
633 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
634 |
+
|
635 |
+
progress_bar = tqdm(
|
636 |
+
range(0, args.max_train_steps),
|
637 |
+
initial=initial_global_step,
|
638 |
+
desc="Steps",
|
639 |
+
# Only show the progress bar once on each machine.
|
640 |
+
disable=not accelerator.is_local_main_process,
|
641 |
+
)
|
642 |
+
|
643 |
+
# For DeepSpeed training
|
644 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
645 |
+
|
646 |
+
if args.load_tensors:
|
647 |
+
del vae, text_encoder
|
648 |
+
gc.collect()
|
649 |
+
torch.cuda.empty_cache()
|
650 |
+
torch.cuda.synchronize(accelerator.device)
|
651 |
+
|
652 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
653 |
+
|
654 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
655 |
+
transformer.train()
|
656 |
+
for step, batch in enumerate(train_dataloader):
|
657 |
+
models_to_accumulate = [transformer]
|
658 |
+
logs = {}
|
659 |
+
|
660 |
+
with accelerator.accumulate(models_to_accumulate):
|
661 |
+
images = batch["images"].to(accelerator.device, non_blocking=True)
|
662 |
+
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
663 |
+
prompts = batch["prompts"]
|
664 |
+
|
665 |
+
# Encode videos
|
666 |
+
if not args.load_tensors:
|
667 |
+
images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
668 |
+
image_noise_sigma = torch.normal(
|
669 |
+
mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
|
670 |
+
)
|
671 |
+
image_noise_sigma = torch.exp(image_noise_sigma)
|
672 |
+
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
673 |
+
image_latent_dist = vae.encode(noisy_images).latent_dist
|
674 |
+
|
675 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
676 |
+
latent_dist = vae.encode(videos).latent_dist
|
677 |
+
else:
|
678 |
+
image_latent_dist = DiagonalGaussianDistribution(images)
|
679 |
+
latent_dist = DiagonalGaussianDistribution(videos)
|
680 |
+
|
681 |
+
image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
|
682 |
+
image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
683 |
+
image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
684 |
+
|
685 |
+
video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
|
686 |
+
video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
687 |
+
video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
688 |
+
|
689 |
+
padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
|
690 |
+
latent_padding = image_latents.new_zeros(padding_shape)
|
691 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
692 |
+
|
693 |
+
if random.random() < args.noised_image_dropout:
|
694 |
+
image_latents = torch.zeros_like(image_latents)
|
695 |
+
|
696 |
+
# Encode prompts
|
697 |
+
if not args.load_tensors:
|
698 |
+
prompt_embeds = compute_prompt_embeddings(
|
699 |
+
tokenizer,
|
700 |
+
text_encoder,
|
701 |
+
prompts,
|
702 |
+
model_config.max_text_seq_length,
|
703 |
+
accelerator.device,
|
704 |
+
weight_dtype,
|
705 |
+
requires_grad=False,
|
706 |
+
)
|
707 |
+
else:
|
708 |
+
prompt_embeds = prompts.to(dtype=weight_dtype)
|
709 |
+
|
710 |
+
# Sample noise that will be added to the latents
|
711 |
+
noise = torch.randn_like(video_latents)
|
712 |
+
batch_size, num_frames, num_channels, height, width = video_latents.shape
|
713 |
+
|
714 |
+
# Sample a random timestep for each image
|
715 |
+
timesteps = torch.randint(
|
716 |
+
0,
|
717 |
+
scheduler.config.num_train_timesteps,
|
718 |
+
(batch_size,),
|
719 |
+
dtype=torch.int64,
|
720 |
+
device=accelerator.device,
|
721 |
+
)
|
722 |
+
|
723 |
+
# Prepare rotary embeds
|
724 |
+
image_rotary_emb = (
|
725 |
+
prepare_rotary_positional_embeddings(
|
726 |
+
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
727 |
+
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
728 |
+
num_frames=num_frames,
|
729 |
+
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
730 |
+
patch_size=model_config.patch_size,
|
731 |
+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
732 |
+
attention_head_dim=model_config.attention_head_dim,
|
733 |
+
device=accelerator.device,
|
734 |
+
base_height=RoPE_BASE_HEIGHT,
|
735 |
+
base_width=RoPE_BASE_WIDTH,
|
736 |
+
)
|
737 |
+
if model_config.use_rotary_positional_embeddings
|
738 |
+
else None
|
739 |
+
)
|
740 |
+
|
741 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
742 |
+
# (this is the forward diffusion process)
|
743 |
+
noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
|
744 |
+
noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
|
745 |
+
model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
746 |
+
ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None,
|
747 |
+
ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
|
748 |
+
# Predict the noise residual
|
749 |
+
model_output = transformer(
|
750 |
+
hidden_states=noisy_model_input,
|
751 |
+
encoder_hidden_states=prompt_embeds,
|
752 |
+
timestep=timesteps,
|
753 |
+
ofs=ofs_emb,
|
754 |
+
image_rotary_emb=image_rotary_emb,
|
755 |
+
return_dict=False,
|
756 |
+
)[0]
|
757 |
+
|
758 |
+
model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
|
759 |
+
|
760 |
+
weights = 1 / (1 - alphas_cumprod[timesteps])
|
761 |
+
while len(weights.shape) < len(model_pred.shape):
|
762 |
+
weights = weights.unsqueeze(-1)
|
763 |
+
|
764 |
+
target = video_latents
|
765 |
+
|
766 |
+
loss = torch.mean(
|
767 |
+
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
768 |
+
dim=1,
|
769 |
+
)
|
770 |
+
loss = loss.mean()
|
771 |
+
accelerator.backward(loss)
|
772 |
+
|
773 |
+
if accelerator.sync_gradients:
|
774 |
+
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
775 |
+
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
776 |
+
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
777 |
+
logs.update(
|
778 |
+
{
|
779 |
+
"gradient_norm_before_clip": gradient_norm_before_clip,
|
780 |
+
"gradient_norm_after_clip": gradient_norm_after_clip,
|
781 |
+
}
|
782 |
+
)
|
783 |
+
if accelerator.state.deepspeed_plugin is None:
|
784 |
+
optimizer.step()
|
785 |
+
optimizer.zero_grad()
|
786 |
+
|
787 |
+
if not args.use_cpu_offload_optimizer:
|
788 |
+
lr_scheduler.step()
|
789 |
+
|
790 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
791 |
+
if accelerator.sync_gradients:
|
792 |
+
progress_bar.update(1)
|
793 |
+
global_step += 1
|
794 |
+
|
795 |
+
# Checkpointing
|
796 |
+
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
797 |
+
if global_step % args.checkpointing_steps == 0:
|
798 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
799 |
+
if args.checkpoints_total_limit is not None:
|
800 |
+
checkpoints = os.listdir(args.output_dir)
|
801 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
802 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
803 |
+
|
804 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
805 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
806 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
807 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
808 |
+
|
809 |
+
logger.info(
|
810 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
811 |
+
)
|
812 |
+
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
813 |
+
|
814 |
+
for removing_checkpoint in removing_checkpoints:
|
815 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
816 |
+
shutil.rmtree(removing_checkpoint)
|
817 |
+
|
818 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
819 |
+
accelerator.save_state(save_path)
|
820 |
+
logger.info(f"Saved state to {save_path}")
|
821 |
+
|
822 |
+
# Validation
|
823 |
+
should_run_validation = args.validation_prompt is not None and (
|
824 |
+
args.validation_steps is not None and global_step % args.validation_steps == 0
|
825 |
+
)
|
826 |
+
if should_run_validation:
|
827 |
+
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
828 |
+
|
829 |
+
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
830 |
+
logs.update(
|
831 |
+
{
|
832 |
+
"loss": loss.detach().item(),
|
833 |
+
"lr": last_lr,
|
834 |
+
}
|
835 |
+
)
|
836 |
+
progress_bar.set_postfix(**logs)
|
837 |
+
accelerator.log(logs, step=global_step)
|
838 |
+
|
839 |
+
if global_step >= args.max_train_steps:
|
840 |
+
break
|
841 |
+
|
842 |
+
if accelerator.is_main_process:
|
843 |
+
should_run_validation = args.validation_prompt is not None and (
|
844 |
+
args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
|
845 |
+
)
|
846 |
+
if should_run_validation:
|
847 |
+
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
848 |
+
accelerator.wait_for_everyone()
|
849 |
+
|
850 |
+
if accelerator.is_main_process:
|
851 |
+
transformer = unwrap_model(accelerator, transformer)
|
852 |
+
dtype = (
|
853 |
+
torch.float16
|
854 |
+
if args.mixed_precision == "fp16"
|
855 |
+
else torch.bfloat16
|
856 |
+
if args.mixed_precision == "bf16"
|
857 |
+
else torch.float32
|
858 |
+
)
|
859 |
+
transformer = transformer.to(dtype)
|
860 |
+
|
861 |
+
transformer.save_pretrained(
|
862 |
+
os.path.join(args.output_dir, "transformer"),
|
863 |
+
safe_serialization=True,
|
864 |
+
max_shard_size="5GB",
|
865 |
+
)
|
866 |
+
|
867 |
+
# Cleanup trained models to save memory
|
868 |
+
if args.load_tensors:
|
869 |
+
del transformer
|
870 |
+
else:
|
871 |
+
del transformer, text_encoder, vae
|
872 |
+
|
873 |
+
gc.collect()
|
874 |
+
torch.cuda.empty_cache()
|
875 |
+
torch.cuda.synchronize(accelerator.device)
|
876 |
+
|
877 |
+
accelerator.print("===== Memory before testing =====")
|
878 |
+
print_memory(accelerator.device)
|
879 |
+
reset_memory(accelerator.device)
|
880 |
+
|
881 |
+
# Final test inference
|
882 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
883 |
+
args.pretrained_model_name_or_path,
|
884 |
+
revision=args.revision,
|
885 |
+
variant=args.variant,
|
886 |
+
torch_dtype=weight_dtype,
|
887 |
+
)
|
888 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
889 |
+
|
890 |
+
if args.enable_slicing:
|
891 |
+
pipe.vae.enable_slicing()
|
892 |
+
if args.enable_tiling:
|
893 |
+
pipe.vae.enable_tiling()
|
894 |
+
if args.enable_model_cpu_offload:
|
895 |
+
pipe.enable_model_cpu_offload()
|
896 |
+
|
897 |
+
# Run inference
|
898 |
+
validation_outputs = []
|
899 |
+
if args.validation_prompt and args.num_validation_videos > 0:
|
900 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
901 |
+
validation_images = args.validation_images.split(args.validation_prompt_separator)
|
902 |
+
for validation_image, validation_prompt in zip(validation_images, validation_prompts):
|
903 |
+
pipeline_args = {
|
904 |
+
"image": load_image(validation_image),
|
905 |
+
"prompt": validation_prompt,
|
906 |
+
"guidance_scale": args.guidance_scale,
|
907 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
908 |
+
"height": args.height,
|
909 |
+
"width": args.width,
|
910 |
+
}
|
911 |
+
|
912 |
+
video = log_validation(
|
913 |
+
accelerator=accelerator,
|
914 |
+
pipe=pipe,
|
915 |
+
args=args,
|
916 |
+
pipeline_args=pipeline_args,
|
917 |
+
is_final_validation=True,
|
918 |
+
)
|
919 |
+
validation_outputs.extend(video)
|
920 |
+
|
921 |
+
accelerator.print("===== Memory after testing =====")
|
922 |
+
print_memory(accelerator.device)
|
923 |
+
reset_memory(accelerator.device)
|
924 |
+
torch.cuda.synchronize(accelerator.device)
|
925 |
+
|
926 |
+
if args.push_to_hub:
|
927 |
+
save_model_card(
|
928 |
+
repo_id,
|
929 |
+
videos=validation_outputs,
|
930 |
+
base_model=args.pretrained_model_name_or_path,
|
931 |
+
validation_prompt=args.validation_prompt,
|
932 |
+
repo_folder=args.output_dir,
|
933 |
+
fps=args.fps,
|
934 |
+
)
|
935 |
+
upload_folder(
|
936 |
+
repo_id=repo_id,
|
937 |
+
folder_path=args.output_dir,
|
938 |
+
commit_message="End of training",
|
939 |
+
ignore_patterns=["step_*", "epoch_*"],
|
940 |
+
)
|
941 |
+
|
942 |
+
accelerator.end_training()
|
943 |
+
|
944 |
+
|
945 |
+
if __name__ == "__main__":
|
946 |
+
args = get_args()
|
947 |
+
main(args)
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_lora.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import shutil
|
21 |
+
from datetime import timedelta
|
22 |
+
from pathlib import Path
|
23 |
+
from typing import Any, Dict
|
24 |
+
|
25 |
+
import diffusers
|
26 |
+
import torch
|
27 |
+
import transformers
|
28 |
+
import wandb
|
29 |
+
from accelerate import Accelerator, DistributedType
|
30 |
+
from accelerate.logging import get_logger
|
31 |
+
from accelerate.utils import (
|
32 |
+
DistributedDataParallelKwargs,
|
33 |
+
InitProcessGroupKwargs,
|
34 |
+
ProjectConfiguration,
|
35 |
+
set_seed,
|
36 |
+
)
|
37 |
+
from diffusers import (
|
38 |
+
AutoencoderKLCogVideoX,
|
39 |
+
CogVideoXDPMScheduler,
|
40 |
+
CogVideoXPipeline,
|
41 |
+
CogVideoXTransformer3DModel,
|
42 |
+
)
|
43 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
44 |
+
from diffusers.optimization import get_scheduler
|
45 |
+
from diffusers.training_utils import cast_training_params
|
46 |
+
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video
|
47 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
48 |
+
from huggingface_hub import create_repo, upload_folder
|
49 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
50 |
+
from torch.utils.data import DataLoader
|
51 |
+
from tqdm.auto import tqdm
|
52 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
53 |
+
|
54 |
+
|
55 |
+
from args import get_args # isort:skip
|
56 |
+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
57 |
+
from text_encoder import compute_prompt_embeddings # isort:skip
|
58 |
+
from utils import (
|
59 |
+
get_gradient_norm,
|
60 |
+
get_optimizer,
|
61 |
+
prepare_rotary_positional_embeddings,
|
62 |
+
print_memory,
|
63 |
+
reset_memory,
|
64 |
+
unwrap_model,
|
65 |
+
) # isort:skip
|
66 |
+
|
67 |
+
|
68 |
+
logger = get_logger(__name__)
|
69 |
+
|
70 |
+
|
71 |
+
def save_model_card(
|
72 |
+
repo_id: str,
|
73 |
+
videos=None,
|
74 |
+
base_model: str = None,
|
75 |
+
validation_prompt=None,
|
76 |
+
repo_folder=None,
|
77 |
+
fps=8,
|
78 |
+
):
|
79 |
+
widget_dict = []
|
80 |
+
if videos is not None:
|
81 |
+
for i, video in enumerate(videos):
|
82 |
+
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
83 |
+
widget_dict.append(
|
84 |
+
{
|
85 |
+
"text": validation_prompt if validation_prompt else " ",
|
86 |
+
"output": {"url": f"video_{i}.mp4"},
|
87 |
+
}
|
88 |
+
)
|
89 |
+
|
90 |
+
model_description = f"""
|
91 |
+
# CogVideoX LoRA Finetune
|
92 |
+
|
93 |
+
<Gallery />
|
94 |
+
|
95 |
+
## Model description
|
96 |
+
|
97 |
+
This is a lora finetune of the CogVideoX model `{base_model}`.
|
98 |
+
|
99 |
+
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
100 |
+
|
101 |
+
## Download model
|
102 |
+
|
103 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
104 |
+
|
105 |
+
## Usage
|
106 |
+
|
107 |
+
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
108 |
+
|
109 |
+
```py
|
110 |
+
import torch
|
111 |
+
from diffusers import CogVideoXPipeline
|
112 |
+
from diffusers.utils import export_to_video
|
113 |
+
|
114 |
+
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
|
115 |
+
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
|
116 |
+
|
117 |
+
# The LoRA adapter weights are determined by what was used for training.
|
118 |
+
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
|
119 |
+
# It can be made lower or higher from what was used in training to decrease or amplify the effect
|
120 |
+
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
|
121 |
+
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
122 |
+
|
123 |
+
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
124 |
+
export_to_video(video, "output.mp4", fps=8)
|
125 |
+
```
|
126 |
+
|
127 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
128 |
+
|
129 |
+
## License
|
130 |
+
|
131 |
+
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
|
132 |
+
"""
|
133 |
+
model_card = load_or_create_model_card(
|
134 |
+
repo_id_or_path=repo_id,
|
135 |
+
from_training=True,
|
136 |
+
license="other",
|
137 |
+
base_model=base_model,
|
138 |
+
prompt=validation_prompt,
|
139 |
+
model_description=model_description,
|
140 |
+
widget=widget_dict,
|
141 |
+
)
|
142 |
+
tags = [
|
143 |
+
"text-to-video",
|
144 |
+
"diffusers-training",
|
145 |
+
"diffusers",
|
146 |
+
"lora",
|
147 |
+
"cogvideox",
|
148 |
+
"cogvideox-diffusers",
|
149 |
+
"template:sd-lora",
|
150 |
+
]
|
151 |
+
|
152 |
+
model_card = populate_model_card(model_card, tags=tags)
|
153 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
154 |
+
|
155 |
+
|
156 |
+
def log_validation(
|
157 |
+
accelerator: Accelerator,
|
158 |
+
pipe: CogVideoXPipeline,
|
159 |
+
args: Dict[str, Any],
|
160 |
+
pipeline_args: Dict[str, Any],
|
161 |
+
epoch,
|
162 |
+
is_final_validation: bool = False,
|
163 |
+
):
|
164 |
+
logger.info(
|
165 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
166 |
+
)
|
167 |
+
|
168 |
+
pipe = pipe.to(accelerator.device)
|
169 |
+
|
170 |
+
# run inference
|
171 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
172 |
+
|
173 |
+
videos = []
|
174 |
+
for _ in range(args.num_validation_videos):
|
175 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
176 |
+
videos.append(video)
|
177 |
+
|
178 |
+
for tracker in accelerator.trackers:
|
179 |
+
phase_name = "test" if is_final_validation else "validation"
|
180 |
+
if tracker.name == "wandb":
|
181 |
+
video_filenames = []
|
182 |
+
for i, video in enumerate(videos):
|
183 |
+
prompt = (
|
184 |
+
pipeline_args["prompt"][:25]
|
185 |
+
.replace(" ", "_")
|
186 |
+
.replace(" ", "_")
|
187 |
+
.replace("'", "_")
|
188 |
+
.replace('"', "_")
|
189 |
+
.replace("/", "_")
|
190 |
+
)
|
191 |
+
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
192 |
+
export_to_video(video, filename, fps=8)
|
193 |
+
video_filenames.append(filename)
|
194 |
+
|
195 |
+
tracker.log(
|
196 |
+
{
|
197 |
+
phase_name: [
|
198 |
+
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
199 |
+
for i, filename in enumerate(video_filenames)
|
200 |
+
]
|
201 |
+
}
|
202 |
+
)
|
203 |
+
|
204 |
+
return videos
|
205 |
+
|
206 |
+
|
207 |
+
class CollateFunction:
|
208 |
+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
209 |
+
self.weight_dtype = weight_dtype
|
210 |
+
self.load_tensors = load_tensors
|
211 |
+
|
212 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
213 |
+
prompts = [x["prompt"] for x in data[0]]
|
214 |
+
|
215 |
+
if self.load_tensors:
|
216 |
+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
217 |
+
|
218 |
+
videos = [x["video"] for x in data[0]]
|
219 |
+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
220 |
+
|
221 |
+
return {
|
222 |
+
"videos": videos,
|
223 |
+
"prompts": prompts,
|
224 |
+
}
|
225 |
+
|
226 |
+
|
227 |
+
def main(args):
|
228 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
229 |
+
raise ValueError(
|
230 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
231 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
232 |
+
)
|
233 |
+
|
234 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
235 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
236 |
+
raise ValueError(
|
237 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
238 |
+
)
|
239 |
+
|
240 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
241 |
+
|
242 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
243 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
244 |
+
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
245 |
+
accelerator = Accelerator(
|
246 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
247 |
+
mixed_precision=args.mixed_precision,
|
248 |
+
log_with=args.report_to,
|
249 |
+
project_config=accelerator_project_config,
|
250 |
+
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
251 |
+
)
|
252 |
+
|
253 |
+
# Disable AMP for MPS.
|
254 |
+
if torch.backends.mps.is_available():
|
255 |
+
accelerator.native_amp = False
|
256 |
+
|
257 |
+
# Make one log on every process with the configuration for debugging.
|
258 |
+
logging.basicConfig(
|
259 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
260 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
261 |
+
level=logging.INFO,
|
262 |
+
)
|
263 |
+
logger.info(accelerator.state, main_process_only=False)
|
264 |
+
if accelerator.is_local_main_process:
|
265 |
+
transformers.utils.logging.set_verbosity_warning()
|
266 |
+
diffusers.utils.logging.set_verbosity_info()
|
267 |
+
else:
|
268 |
+
transformers.utils.logging.set_verbosity_error()
|
269 |
+
diffusers.utils.logging.set_verbosity_error()
|
270 |
+
|
271 |
+
# If passed along, set the training seed now.
|
272 |
+
if args.seed is not None:
|
273 |
+
set_seed(args.seed)
|
274 |
+
|
275 |
+
# Handle the repository creation
|
276 |
+
if accelerator.is_main_process:
|
277 |
+
if args.output_dir is not None:
|
278 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
279 |
+
|
280 |
+
if args.push_to_hub:
|
281 |
+
repo_id = create_repo(
|
282 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
283 |
+
exist_ok=True,
|
284 |
+
).repo_id
|
285 |
+
|
286 |
+
# Prepare models and scheduler
|
287 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
288 |
+
args.pretrained_model_name_or_path,
|
289 |
+
subfolder="tokenizer",
|
290 |
+
revision=args.revision,
|
291 |
+
)
|
292 |
+
|
293 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
294 |
+
args.pretrained_model_name_or_path,
|
295 |
+
subfolder="text_encoder",
|
296 |
+
revision=args.revision,
|
297 |
+
)
|
298 |
+
|
299 |
+
# CogVideoX-2b weights are stored in float16
|
300 |
+
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
301 |
+
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
302 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
303 |
+
args.pretrained_model_name_or_path,
|
304 |
+
subfolder="transformer",
|
305 |
+
torch_dtype=load_dtype,
|
306 |
+
revision=args.revision,
|
307 |
+
variant=args.variant,
|
308 |
+
)
|
309 |
+
|
310 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
311 |
+
args.pretrained_model_name_or_path,
|
312 |
+
subfolder="vae",
|
313 |
+
revision=args.revision,
|
314 |
+
variant=args.variant,
|
315 |
+
)
|
316 |
+
|
317 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
318 |
+
|
319 |
+
if args.enable_slicing:
|
320 |
+
vae.enable_slicing()
|
321 |
+
if args.enable_tiling:
|
322 |
+
vae.enable_tiling()
|
323 |
+
|
324 |
+
# We only train the additional adapter LoRA layers
|
325 |
+
text_encoder.requires_grad_(False)
|
326 |
+
transformer.requires_grad_(False)
|
327 |
+
vae.requires_grad_(False)
|
328 |
+
|
329 |
+
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
330 |
+
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
331 |
+
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
332 |
+
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
333 |
+
|
334 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
335 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
336 |
+
weight_dtype = torch.float32
|
337 |
+
if accelerator.state.deepspeed_plugin:
|
338 |
+
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
339 |
+
if (
|
340 |
+
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
341 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
342 |
+
):
|
343 |
+
weight_dtype = torch.float16
|
344 |
+
if (
|
345 |
+
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
346 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
347 |
+
):
|
348 |
+
weight_dtype = torch.bfloat16
|
349 |
+
else:
|
350 |
+
if accelerator.mixed_precision == "fp16":
|
351 |
+
weight_dtype = torch.float16
|
352 |
+
elif accelerator.mixed_precision == "bf16":
|
353 |
+
weight_dtype = torch.bfloat16
|
354 |
+
|
355 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
356 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
357 |
+
raise ValueError(
|
358 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
359 |
+
)
|
360 |
+
|
361 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
362 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
363 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
364 |
+
|
365 |
+
if args.gradient_checkpointing:
|
366 |
+
transformer.enable_gradient_checkpointing()
|
367 |
+
|
368 |
+
# now we will add new LoRA weights to the attention layers
|
369 |
+
transformer_lora_config = LoraConfig(
|
370 |
+
r=args.rank,
|
371 |
+
lora_alpha=args.lora_alpha,
|
372 |
+
init_lora_weights=True,
|
373 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
374 |
+
)
|
375 |
+
transformer.add_adapter(transformer_lora_config)
|
376 |
+
|
377 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
378 |
+
def save_model_hook(models, weights, output_dir):
|
379 |
+
if accelerator.is_main_process:
|
380 |
+
transformer_lora_layers_to_save = None
|
381 |
+
|
382 |
+
for model in models:
|
383 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
384 |
+
model = unwrap_model(accelerator, model)
|
385 |
+
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
386 |
+
else:
|
387 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
388 |
+
|
389 |
+
# make sure to pop weight so that corresponding model is not saved again
|
390 |
+
if weights:
|
391 |
+
weights.pop()
|
392 |
+
|
393 |
+
CogVideoXPipeline.save_lora_weights(
|
394 |
+
output_dir,
|
395 |
+
transformer_lora_layers=transformer_lora_layers_to_save,
|
396 |
+
)
|
397 |
+
|
398 |
+
def load_model_hook(models, input_dir):
|
399 |
+
transformer_ = None
|
400 |
+
|
401 |
+
# This is a bit of a hack but I don't know any other solution.
|
402 |
+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
403 |
+
while len(models) > 0:
|
404 |
+
model = models.pop()
|
405 |
+
|
406 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
407 |
+
transformer_ = unwrap_model(accelerator, model)
|
408 |
+
else:
|
409 |
+
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
410 |
+
else:
|
411 |
+
transformer_ = CogVideoXTransformer3DModel.from_pretrained(
|
412 |
+
args.pretrained_model_name_or_path, subfolder="transformer"
|
413 |
+
)
|
414 |
+
transformer_.add_adapter(transformer_lora_config)
|
415 |
+
|
416 |
+
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
417 |
+
|
418 |
+
transformer_state_dict = {
|
419 |
+
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
420 |
+
}
|
421 |
+
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
422 |
+
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
423 |
+
if incompatible_keys is not None:
|
424 |
+
# check only for unexpected keys
|
425 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
426 |
+
if unexpected_keys:
|
427 |
+
logger.warning(
|
428 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
429 |
+
f" {unexpected_keys}. "
|
430 |
+
)
|
431 |
+
|
432 |
+
# Make sure the trainable params are in float32. This is again needed since the base models
|
433 |
+
# are in `weight_dtype`. More details:
|
434 |
+
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
435 |
+
if args.mixed_precision == "fp16":
|
436 |
+
# only upcast trainable parameters (LoRA) into fp32
|
437 |
+
cast_training_params([transformer_])
|
438 |
+
|
439 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
440 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
441 |
+
|
442 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
443 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
444 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
445 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
446 |
+
|
447 |
+
if args.scale_lr:
|
448 |
+
args.learning_rate = (
|
449 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
450 |
+
)
|
451 |
+
|
452 |
+
# Make sure the trainable params are in float32.
|
453 |
+
if args.mixed_precision == "fp16":
|
454 |
+
# only upcast trainable parameters (LoRA) into fp32
|
455 |
+
cast_training_params([transformer], dtype=torch.float32)
|
456 |
+
|
457 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
458 |
+
|
459 |
+
# Optimization parameters
|
460 |
+
transformer_parameters_with_lr = {
|
461 |
+
"params": transformer_lora_parameters,
|
462 |
+
"lr": args.learning_rate,
|
463 |
+
}
|
464 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
465 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
466 |
+
|
467 |
+
use_deepspeed_optimizer = (
|
468 |
+
accelerator.state.deepspeed_plugin is not None
|
469 |
+
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
470 |
+
)
|
471 |
+
use_deepspeed_scheduler = (
|
472 |
+
accelerator.state.deepspeed_plugin is not None
|
473 |
+
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
474 |
+
)
|
475 |
+
|
476 |
+
optimizer = get_optimizer(
|
477 |
+
params_to_optimize=params_to_optimize,
|
478 |
+
optimizer_name=args.optimizer,
|
479 |
+
learning_rate=args.learning_rate,
|
480 |
+
beta1=args.beta1,
|
481 |
+
beta2=args.beta2,
|
482 |
+
beta3=args.beta3,
|
483 |
+
epsilon=args.epsilon,
|
484 |
+
weight_decay=args.weight_decay,
|
485 |
+
prodigy_decouple=args.prodigy_decouple,
|
486 |
+
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
487 |
+
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
488 |
+
use_8bit=args.use_8bit,
|
489 |
+
use_4bit=args.use_4bit,
|
490 |
+
use_torchao=args.use_torchao,
|
491 |
+
use_deepspeed=use_deepspeed_optimizer,
|
492 |
+
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
493 |
+
offload_gradients=args.offload_gradients,
|
494 |
+
)
|
495 |
+
|
496 |
+
# Dataset and DataLoader
|
497 |
+
dataset_init_kwargs = {
|
498 |
+
"data_root": args.data_root,
|
499 |
+
"dataset_file": args.dataset_file,
|
500 |
+
"caption_column": args.caption_column,
|
501 |
+
"video_column": args.video_column,
|
502 |
+
"max_num_frames": args.max_num_frames,
|
503 |
+
"id_token": args.id_token,
|
504 |
+
"height_buckets": args.height_buckets,
|
505 |
+
"width_buckets": args.width_buckets,
|
506 |
+
"frame_buckets": args.frame_buckets,
|
507 |
+
"load_tensors": args.load_tensors,
|
508 |
+
"random_flip": args.random_flip,
|
509 |
+
}
|
510 |
+
if args.video_reshape_mode is None:
|
511 |
+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
512 |
+
else:
|
513 |
+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
514 |
+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
515 |
+
)
|
516 |
+
|
517 |
+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
518 |
+
|
519 |
+
train_dataloader = DataLoader(
|
520 |
+
train_dataset,
|
521 |
+
batch_size=1,
|
522 |
+
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
523 |
+
collate_fn=collate_fn,
|
524 |
+
num_workers=args.dataloader_num_workers,
|
525 |
+
pin_memory=args.pin_memory,
|
526 |
+
)
|
527 |
+
|
528 |
+
# Scheduler and math around the number of training steps.
|
529 |
+
overrode_max_train_steps = False
|
530 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
531 |
+
if args.max_train_steps is None:
|
532 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
533 |
+
overrode_max_train_steps = True
|
534 |
+
|
535 |
+
if args.use_cpu_offload_optimizer:
|
536 |
+
lr_scheduler = None
|
537 |
+
accelerator.print(
|
538 |
+
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
539 |
+
"you are training with those settings, they will be ignored."
|
540 |
+
)
|
541 |
+
else:
|
542 |
+
if use_deepspeed_scheduler:
|
543 |
+
from accelerate.utils import DummyScheduler
|
544 |
+
|
545 |
+
lr_scheduler = DummyScheduler(
|
546 |
+
name=args.lr_scheduler,
|
547 |
+
optimizer=optimizer,
|
548 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
549 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
550 |
+
)
|
551 |
+
else:
|
552 |
+
lr_scheduler = get_scheduler(
|
553 |
+
args.lr_scheduler,
|
554 |
+
optimizer=optimizer,
|
555 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
556 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
557 |
+
num_cycles=args.lr_num_cycles,
|
558 |
+
power=args.lr_power,
|
559 |
+
)
|
560 |
+
|
561 |
+
# Prepare everything with our `accelerator`.
|
562 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
563 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
564 |
+
)
|
565 |
+
|
566 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
567 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
568 |
+
if overrode_max_train_steps:
|
569 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
570 |
+
# Afterwards we recalculate our number of training epochs
|
571 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
572 |
+
|
573 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
574 |
+
# The trackers initializes automatically on the main process.
|
575 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
576 |
+
tracker_name = args.tracker_name or "cogvideox-lora"
|
577 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
578 |
+
|
579 |
+
accelerator.print("===== Memory before training =====")
|
580 |
+
reset_memory(accelerator.device)
|
581 |
+
print_memory(accelerator.device)
|
582 |
+
|
583 |
+
# Train!
|
584 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
585 |
+
|
586 |
+
accelerator.print("***** Running training *****")
|
587 |
+
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
588 |
+
accelerator.print(f" Num examples = {len(train_dataset)}")
|
589 |
+
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
590 |
+
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
591 |
+
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
592 |
+
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
593 |
+
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
594 |
+
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
595 |
+
global_step = 0
|
596 |
+
first_epoch = 0
|
597 |
+
|
598 |
+
# Potentially load in the weights and states from a previous save
|
599 |
+
if not args.resume_from_checkpoint:
|
600 |
+
initial_global_step = 0
|
601 |
+
else:
|
602 |
+
if args.resume_from_checkpoint != "latest":
|
603 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
604 |
+
else:
|
605 |
+
# Get the most recent checkpoint
|
606 |
+
dirs = os.listdir(args.output_dir)
|
607 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
608 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
609 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
610 |
+
|
611 |
+
if path is None:
|
612 |
+
accelerator.print(
|
613 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
614 |
+
)
|
615 |
+
args.resume_from_checkpoint = None
|
616 |
+
initial_global_step = 0
|
617 |
+
else:
|
618 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
619 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
620 |
+
global_step = int(path.split("-")[1])
|
621 |
+
|
622 |
+
initial_global_step = global_step
|
623 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
624 |
+
|
625 |
+
progress_bar = tqdm(
|
626 |
+
range(0, args.max_train_steps),
|
627 |
+
initial=initial_global_step,
|
628 |
+
desc="Steps",
|
629 |
+
# Only show the progress bar once on each machine.
|
630 |
+
disable=not accelerator.is_local_main_process,
|
631 |
+
)
|
632 |
+
|
633 |
+
# For DeepSpeed training
|
634 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
635 |
+
|
636 |
+
if args.load_tensors:
|
637 |
+
del vae, text_encoder
|
638 |
+
gc.collect()
|
639 |
+
torch.cuda.empty_cache()
|
640 |
+
torch.cuda.synchronize(accelerator.device)
|
641 |
+
|
642 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
643 |
+
|
644 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
645 |
+
transformer.train()
|
646 |
+
|
647 |
+
for step, batch in enumerate(train_dataloader):
|
648 |
+
models_to_accumulate = [transformer]
|
649 |
+
logs = {}
|
650 |
+
|
651 |
+
with accelerator.accumulate(models_to_accumulate):
|
652 |
+
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
653 |
+
prompts = batch["prompts"]
|
654 |
+
|
655 |
+
# Encode videos
|
656 |
+
if not args.load_tensors:
|
657 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
658 |
+
latent_dist = vae.encode(videos).latent_dist
|
659 |
+
else:
|
660 |
+
latent_dist = DiagonalGaussianDistribution(videos)
|
661 |
+
|
662 |
+
videos = latent_dist.sample() * VAE_SCALING_FACTOR
|
663 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
664 |
+
videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
665 |
+
model_input = videos
|
666 |
+
|
667 |
+
# Encode prompts
|
668 |
+
if not args.load_tensors:
|
669 |
+
prompt_embeds = compute_prompt_embeddings(
|
670 |
+
tokenizer,
|
671 |
+
text_encoder,
|
672 |
+
prompts,
|
673 |
+
model_config.max_text_seq_length,
|
674 |
+
accelerator.device,
|
675 |
+
weight_dtype,
|
676 |
+
requires_grad=False,
|
677 |
+
)
|
678 |
+
else:
|
679 |
+
prompt_embeds = prompts.to(dtype=weight_dtype)
|
680 |
+
|
681 |
+
# Sample noise that will be added to the latents
|
682 |
+
noise = torch.randn_like(model_input)
|
683 |
+
batch_size, num_frames, num_channels, height, width = model_input.shape
|
684 |
+
|
685 |
+
# Sample a random timestep for each image
|
686 |
+
timesteps = torch.randint(
|
687 |
+
0,
|
688 |
+
scheduler.config.num_train_timesteps,
|
689 |
+
(batch_size,),
|
690 |
+
dtype=torch.int64,
|
691 |
+
device=model_input.device,
|
692 |
+
)
|
693 |
+
|
694 |
+
# Prepare rotary embeds
|
695 |
+
image_rotary_emb = (
|
696 |
+
prepare_rotary_positional_embeddings(
|
697 |
+
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
698 |
+
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
699 |
+
num_frames=num_frames,
|
700 |
+
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
701 |
+
patch_size=model_config.patch_size,
|
702 |
+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
703 |
+
attention_head_dim=model_config.attention_head_dim,
|
704 |
+
device=accelerator.device,
|
705 |
+
base_height=RoPE_BASE_HEIGHT,
|
706 |
+
base_width=RoPE_BASE_WIDTH,
|
707 |
+
)
|
708 |
+
if model_config.use_rotary_positional_embeddings
|
709 |
+
else None
|
710 |
+
)
|
711 |
+
|
712 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
713 |
+
# (this is the forward diffusion process)
|
714 |
+
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
|
715 |
+
|
716 |
+
# Predict the noise residual
|
717 |
+
model_output = transformer(
|
718 |
+
hidden_states=noisy_model_input,
|
719 |
+
encoder_hidden_states=prompt_embeds,
|
720 |
+
timestep=timesteps,
|
721 |
+
image_rotary_emb=image_rotary_emb,
|
722 |
+
return_dict=False,
|
723 |
+
)[0]
|
724 |
+
|
725 |
+
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
|
726 |
+
|
727 |
+
weights = 1 / (1 - alphas_cumprod[timesteps])
|
728 |
+
while len(weights.shape) < len(model_pred.shape):
|
729 |
+
weights = weights.unsqueeze(-1)
|
730 |
+
|
731 |
+
target = model_input
|
732 |
+
|
733 |
+
loss = torch.mean(
|
734 |
+
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
735 |
+
dim=1,
|
736 |
+
)
|
737 |
+
loss = loss.mean()
|
738 |
+
accelerator.backward(loss)
|
739 |
+
|
740 |
+
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
741 |
+
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
742 |
+
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
743 |
+
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
744 |
+
logs.update(
|
745 |
+
{
|
746 |
+
"gradient_norm_before_clip": gradient_norm_before_clip,
|
747 |
+
"gradient_norm_after_clip": gradient_norm_after_clip,
|
748 |
+
}
|
749 |
+
)
|
750 |
+
|
751 |
+
if accelerator.state.deepspeed_plugin is None:
|
752 |
+
optimizer.step()
|
753 |
+
optimizer.zero_grad()
|
754 |
+
|
755 |
+
if not args.use_cpu_offload_optimizer:
|
756 |
+
lr_scheduler.step()
|
757 |
+
|
758 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
759 |
+
if accelerator.sync_gradients:
|
760 |
+
progress_bar.update(1)
|
761 |
+
global_step += 1
|
762 |
+
|
763 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
764 |
+
if global_step % args.checkpointing_steps == 0:
|
765 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
766 |
+
if args.checkpoints_total_limit is not None:
|
767 |
+
checkpoints = os.listdir(args.output_dir)
|
768 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
769 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
770 |
+
|
771 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
772 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
773 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
774 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
775 |
+
|
776 |
+
logger.info(
|
777 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
778 |
+
)
|
779 |
+
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
780 |
+
|
781 |
+
for removing_checkpoint in removing_checkpoints:
|
782 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
783 |
+
shutil.rmtree(removing_checkpoint)
|
784 |
+
|
785 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
786 |
+
accelerator.save_state(save_path)
|
787 |
+
logger.info(f"Saved state to {save_path}")
|
788 |
+
|
789 |
+
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
790 |
+
logs.update(
|
791 |
+
{
|
792 |
+
"loss": loss.detach().item(),
|
793 |
+
"lr": last_lr,
|
794 |
+
}
|
795 |
+
)
|
796 |
+
progress_bar.set_postfix(**logs)
|
797 |
+
accelerator.log(logs, step=global_step)
|
798 |
+
|
799 |
+
if global_step >= args.max_train_steps:
|
800 |
+
break
|
801 |
+
|
802 |
+
if accelerator.is_main_process:
|
803 |
+
if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
|
804 |
+
accelerator.print("===== Memory before validation =====")
|
805 |
+
print_memory(accelerator.device)
|
806 |
+
torch.cuda.synchronize(accelerator.device)
|
807 |
+
|
808 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
809 |
+
args.pretrained_model_name_or_path,
|
810 |
+
transformer=unwrap_model(accelerator, transformer),
|
811 |
+
scheduler=scheduler,
|
812 |
+
revision=args.revision,
|
813 |
+
variant=args.variant,
|
814 |
+
torch_dtype=weight_dtype,
|
815 |
+
)
|
816 |
+
|
817 |
+
if args.enable_slicing:
|
818 |
+
pipe.vae.enable_slicing()
|
819 |
+
if args.enable_tiling:
|
820 |
+
pipe.vae.enable_tiling()
|
821 |
+
if args.enable_model_cpu_offload:
|
822 |
+
pipe.enable_model_cpu_offload()
|
823 |
+
|
824 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
825 |
+
for validation_prompt in validation_prompts:
|
826 |
+
pipeline_args = {
|
827 |
+
"prompt": validation_prompt,
|
828 |
+
"guidance_scale": args.guidance_scale,
|
829 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
830 |
+
"height": args.height,
|
831 |
+
"width": args.width,
|
832 |
+
"max_sequence_length": model_config.max_text_seq_length,
|
833 |
+
}
|
834 |
+
|
835 |
+
log_validation(
|
836 |
+
pipe=pipe,
|
837 |
+
args=args,
|
838 |
+
accelerator=accelerator,
|
839 |
+
pipeline_args=pipeline_args,
|
840 |
+
epoch=epoch,
|
841 |
+
)
|
842 |
+
|
843 |
+
accelerator.print("===== Memory after validation =====")
|
844 |
+
print_memory(accelerator.device)
|
845 |
+
reset_memory(accelerator.device)
|
846 |
+
|
847 |
+
del pipe
|
848 |
+
gc.collect()
|
849 |
+
torch.cuda.empty_cache()
|
850 |
+
torch.cuda.synchronize(accelerator.device)
|
851 |
+
|
852 |
+
accelerator.wait_for_everyone()
|
853 |
+
|
854 |
+
if accelerator.is_main_process:
|
855 |
+
transformer = unwrap_model(accelerator, transformer)
|
856 |
+
dtype = (
|
857 |
+
torch.float16
|
858 |
+
if args.mixed_precision == "fp16"
|
859 |
+
else torch.bfloat16
|
860 |
+
if args.mixed_precision == "bf16"
|
861 |
+
else torch.float32
|
862 |
+
)
|
863 |
+
transformer = transformer.to(dtype)
|
864 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
865 |
+
|
866 |
+
CogVideoXPipeline.save_lora_weights(
|
867 |
+
save_directory=args.output_dir,
|
868 |
+
transformer_lora_layers=transformer_lora_layers,
|
869 |
+
)
|
870 |
+
|
871 |
+
# Cleanup trained models to save memory
|
872 |
+
if args.load_tensors:
|
873 |
+
del transformer
|
874 |
+
else:
|
875 |
+
del transformer, text_encoder, vae
|
876 |
+
|
877 |
+
gc.collect()
|
878 |
+
torch.cuda.empty_cache()
|
879 |
+
torch.cuda.synchronize(accelerator.device)
|
880 |
+
|
881 |
+
accelerator.print("===== Memory before testing =====")
|
882 |
+
print_memory(accelerator.device)
|
883 |
+
reset_memory(accelerator.device)
|
884 |
+
|
885 |
+
# Final test inference
|
886 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
887 |
+
args.pretrained_model_name_or_path,
|
888 |
+
revision=args.revision,
|
889 |
+
variant=args.variant,
|
890 |
+
torch_dtype=weight_dtype,
|
891 |
+
)
|
892 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
893 |
+
|
894 |
+
if args.enable_slicing:
|
895 |
+
pipe.vae.enable_slicing()
|
896 |
+
if args.enable_tiling:
|
897 |
+
pipe.vae.enable_tiling()
|
898 |
+
if args.enable_model_cpu_offload:
|
899 |
+
pipe.enable_model_cpu_offload()
|
900 |
+
|
901 |
+
# Load LoRA weights
|
902 |
+
lora_scaling = args.lora_alpha / args.rank
|
903 |
+
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
904 |
+
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
|
905 |
+
|
906 |
+
# Run inference
|
907 |
+
validation_outputs = []
|
908 |
+
if args.validation_prompt and args.num_validation_videos > 0:
|
909 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
910 |
+
for validation_prompt in validation_prompts:
|
911 |
+
pipeline_args = {
|
912 |
+
"prompt": validation_prompt,
|
913 |
+
"guidance_scale": args.guidance_scale,
|
914 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
915 |
+
"height": args.height,
|
916 |
+
"width": args.width,
|
917 |
+
}
|
918 |
+
|
919 |
+
video = log_validation(
|
920 |
+
accelerator=accelerator,
|
921 |
+
pipe=pipe,
|
922 |
+
args=args,
|
923 |
+
pipeline_args=pipeline_args,
|
924 |
+
epoch=epoch,
|
925 |
+
is_final_validation=True,
|
926 |
+
)
|
927 |
+
validation_outputs.extend(video)
|
928 |
+
|
929 |
+
accelerator.print("===== Memory after testing =====")
|
930 |
+
print_memory(accelerator.device)
|
931 |
+
reset_memory(accelerator.device)
|
932 |
+
torch.cuda.synchronize(accelerator.device)
|
933 |
+
|
934 |
+
if args.push_to_hub:
|
935 |
+
save_model_card(
|
936 |
+
repo_id,
|
937 |
+
videos=validation_outputs,
|
938 |
+
base_model=args.pretrained_model_name_or_path,
|
939 |
+
validation_prompt=args.validation_prompt,
|
940 |
+
repo_folder=args.output_dir,
|
941 |
+
fps=args.fps,
|
942 |
+
)
|
943 |
+
upload_folder(
|
944 |
+
repo_id=repo_id,
|
945 |
+
folder_path=args.output_dir,
|
946 |
+
commit_message="End of training",
|
947 |
+
ignore_patterns=["step_*", "epoch_*"],
|
948 |
+
)
|
949 |
+
|
950 |
+
accelerator.end_training()
|
951 |
+
|
952 |
+
|
953 |
+
if __name__ == "__main__":
|
954 |
+
args = get_args()
|
955 |
+
main(args)
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/cogvideox_text_to_video_sft.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import shutil
|
21 |
+
from datetime import timedelta
|
22 |
+
from pathlib import Path
|
23 |
+
from typing import Any, Dict
|
24 |
+
|
25 |
+
import diffusers
|
26 |
+
import torch
|
27 |
+
import transformers
|
28 |
+
import wandb
|
29 |
+
from accelerate import Accelerator, DistributedType, init_empty_weights
|
30 |
+
from accelerate.logging import get_logger
|
31 |
+
from accelerate.utils import (
|
32 |
+
DistributedDataParallelKwargs,
|
33 |
+
InitProcessGroupKwargs,
|
34 |
+
ProjectConfiguration,
|
35 |
+
set_seed,
|
36 |
+
)
|
37 |
+
from diffusers import (
|
38 |
+
AutoencoderKLCogVideoX,
|
39 |
+
CogVideoXDPMScheduler,
|
40 |
+
CogVideoXPipeline,
|
41 |
+
CogVideoXTransformer3DModel,
|
42 |
+
)
|
43 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
44 |
+
from diffusers.optimization import get_scheduler
|
45 |
+
from diffusers.training_utils import cast_training_params
|
46 |
+
from diffusers.utils import export_to_video
|
47 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
48 |
+
from huggingface_hub import create_repo, upload_folder
|
49 |
+
from torch.utils.data import DataLoader
|
50 |
+
from tqdm.auto import tqdm
|
51 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
52 |
+
|
53 |
+
|
54 |
+
from args import get_args # isort:skip
|
55 |
+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
56 |
+
from text_encoder import compute_prompt_embeddings # isort:skip
|
57 |
+
from utils import (
|
58 |
+
get_gradient_norm,
|
59 |
+
get_optimizer,
|
60 |
+
prepare_rotary_positional_embeddings,
|
61 |
+
print_memory,
|
62 |
+
reset_memory,
|
63 |
+
unwrap_model,
|
64 |
+
) # isort:skip
|
65 |
+
|
66 |
+
|
67 |
+
logger = get_logger(__name__)
|
68 |
+
|
69 |
+
|
70 |
+
def save_model_card(
|
71 |
+
repo_id: str,
|
72 |
+
videos=None,
|
73 |
+
base_model: str = None,
|
74 |
+
validation_prompt=None,
|
75 |
+
repo_folder=None,
|
76 |
+
fps=8,
|
77 |
+
):
|
78 |
+
widget_dict = []
|
79 |
+
if videos is not None:
|
80 |
+
for i, video in enumerate(videos):
|
81 |
+
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
82 |
+
widget_dict.append(
|
83 |
+
{
|
84 |
+
"text": validation_prompt if validation_prompt else " ",
|
85 |
+
"output": {"url": f"video_{i}.mp4"},
|
86 |
+
}
|
87 |
+
)
|
88 |
+
|
89 |
+
model_description = f"""
|
90 |
+
# CogVideoX Full Finetune
|
91 |
+
|
92 |
+
<Gallery />
|
93 |
+
|
94 |
+
## Model description
|
95 |
+
|
96 |
+
This is a full finetune of the CogVideoX model `{base_model}`.
|
97 |
+
|
98 |
+
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
99 |
+
|
100 |
+
## Download model
|
101 |
+
|
102 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
103 |
+
|
104 |
+
## Usage
|
105 |
+
|
106 |
+
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
107 |
+
|
108 |
+
```py
|
109 |
+
import torch
|
110 |
+
from diffusers import CogVideoXPipeline
|
111 |
+
from diffusers.utils import export_to_video
|
112 |
+
|
113 |
+
pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda")
|
114 |
+
|
115 |
+
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
116 |
+
export_to_video(video, "output.mp4", fps=8)
|
117 |
+
```
|
118 |
+
|
119 |
+
For more details, checkout the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX.
|
120 |
+
|
121 |
+
## License
|
122 |
+
|
123 |
+
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
|
124 |
+
"""
|
125 |
+
model_card = load_or_create_model_card(
|
126 |
+
repo_id_or_path=repo_id,
|
127 |
+
from_training=True,
|
128 |
+
license="other",
|
129 |
+
base_model=base_model,
|
130 |
+
prompt=validation_prompt,
|
131 |
+
model_description=model_description,
|
132 |
+
widget=widget_dict,
|
133 |
+
)
|
134 |
+
tags = [
|
135 |
+
"text-to-video",
|
136 |
+
"diffusers-training",
|
137 |
+
"diffusers",
|
138 |
+
"cogvideox",
|
139 |
+
"cogvideox-diffusers",
|
140 |
+
]
|
141 |
+
|
142 |
+
model_card = populate_model_card(model_card, tags=tags)
|
143 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
144 |
+
|
145 |
+
|
146 |
+
def log_validation(
|
147 |
+
accelerator: Accelerator,
|
148 |
+
pipe: CogVideoXPipeline,
|
149 |
+
args: Dict[str, Any],
|
150 |
+
pipeline_args: Dict[str, Any],
|
151 |
+
epoch,
|
152 |
+
is_final_validation: bool = False,
|
153 |
+
):
|
154 |
+
logger.info(
|
155 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
156 |
+
)
|
157 |
+
|
158 |
+
pipe = pipe.to(accelerator.device)
|
159 |
+
|
160 |
+
# run inference
|
161 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
162 |
+
|
163 |
+
videos = []
|
164 |
+
for _ in range(args.num_validation_videos):
|
165 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
166 |
+
videos.append(video)
|
167 |
+
|
168 |
+
for tracker in accelerator.trackers:
|
169 |
+
phase_name = "test" if is_final_validation else "validation"
|
170 |
+
if tracker.name == "wandb":
|
171 |
+
video_filenames = []
|
172 |
+
for i, video in enumerate(videos):
|
173 |
+
prompt = (
|
174 |
+
pipeline_args["prompt"][:25]
|
175 |
+
.replace(" ", "_")
|
176 |
+
.replace(" ", "_")
|
177 |
+
.replace("'", "_")
|
178 |
+
.replace('"', "_")
|
179 |
+
.replace("/", "_")
|
180 |
+
)
|
181 |
+
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
182 |
+
export_to_video(video, filename, fps=8)
|
183 |
+
video_filenames.append(filename)
|
184 |
+
|
185 |
+
tracker.log(
|
186 |
+
{
|
187 |
+
phase_name: [
|
188 |
+
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
189 |
+
for i, filename in enumerate(video_filenames)
|
190 |
+
]
|
191 |
+
}
|
192 |
+
)
|
193 |
+
|
194 |
+
return videos
|
195 |
+
|
196 |
+
|
197 |
+
class CollateFunction:
|
198 |
+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
199 |
+
self.weight_dtype = weight_dtype
|
200 |
+
self.load_tensors = load_tensors
|
201 |
+
|
202 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
203 |
+
prompts = [x["prompt"] for x in data[0]]
|
204 |
+
|
205 |
+
if self.load_tensors:
|
206 |
+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
207 |
+
|
208 |
+
videos = [x["video"] for x in data[0]]
|
209 |
+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
210 |
+
|
211 |
+
return {
|
212 |
+
"videos": videos,
|
213 |
+
"prompts": prompts,
|
214 |
+
}
|
215 |
+
|
216 |
+
|
217 |
+
def main(args):
|
218 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
219 |
+
raise ValueError(
|
220 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
221 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
222 |
+
)
|
223 |
+
|
224 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
225 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
226 |
+
raise ValueError(
|
227 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
228 |
+
)
|
229 |
+
|
230 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
231 |
+
|
232 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
233 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
234 |
+
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
235 |
+
accelerator = Accelerator(
|
236 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
237 |
+
mixed_precision=args.mixed_precision,
|
238 |
+
log_with=args.report_to,
|
239 |
+
project_config=accelerator_project_config,
|
240 |
+
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
241 |
+
)
|
242 |
+
|
243 |
+
# Disable AMP for MPS.
|
244 |
+
if torch.backends.mps.is_available():
|
245 |
+
accelerator.native_amp = False
|
246 |
+
|
247 |
+
# Make one log on every process with the configuration for debugging.
|
248 |
+
logging.basicConfig(
|
249 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
250 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
251 |
+
level=logging.INFO,
|
252 |
+
)
|
253 |
+
logger.info(accelerator.state, main_process_only=False)
|
254 |
+
if accelerator.is_local_main_process:
|
255 |
+
transformers.utils.logging.set_verbosity_warning()
|
256 |
+
diffusers.utils.logging.set_verbosity_info()
|
257 |
+
else:
|
258 |
+
transformers.utils.logging.set_verbosity_error()
|
259 |
+
diffusers.utils.logging.set_verbosity_error()
|
260 |
+
|
261 |
+
# If passed along, set the training seed now.
|
262 |
+
if args.seed is not None:
|
263 |
+
set_seed(args.seed)
|
264 |
+
|
265 |
+
# Handle the repository creation
|
266 |
+
if accelerator.is_main_process:
|
267 |
+
if args.output_dir is not None:
|
268 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
269 |
+
|
270 |
+
if args.push_to_hub:
|
271 |
+
repo_id = create_repo(
|
272 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
273 |
+
exist_ok=True,
|
274 |
+
).repo_id
|
275 |
+
|
276 |
+
# Prepare models and scheduler
|
277 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
278 |
+
args.pretrained_model_name_or_path,
|
279 |
+
subfolder="tokenizer",
|
280 |
+
revision=args.revision,
|
281 |
+
)
|
282 |
+
|
283 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
284 |
+
args.pretrained_model_name_or_path,
|
285 |
+
subfolder="text_encoder",
|
286 |
+
revision=args.revision,
|
287 |
+
)
|
288 |
+
|
289 |
+
# CogVideoX-2b weights are stored in float16
|
290 |
+
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
291 |
+
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
292 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
293 |
+
args.pretrained_model_name_or_path,
|
294 |
+
subfolder="transformer",
|
295 |
+
torch_dtype=load_dtype,
|
296 |
+
revision=args.revision,
|
297 |
+
variant=args.variant,
|
298 |
+
)
|
299 |
+
|
300 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
301 |
+
args.pretrained_model_name_or_path,
|
302 |
+
subfolder="vae",
|
303 |
+
revision=args.revision,
|
304 |
+
variant=args.variant,
|
305 |
+
)
|
306 |
+
|
307 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
308 |
+
|
309 |
+
if args.enable_slicing:
|
310 |
+
vae.enable_slicing()
|
311 |
+
if args.enable_tiling:
|
312 |
+
vae.enable_tiling()
|
313 |
+
|
314 |
+
text_encoder.requires_grad_(False)
|
315 |
+
vae.requires_grad_(False)
|
316 |
+
transformer.requires_grad_(True)
|
317 |
+
|
318 |
+
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
319 |
+
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
320 |
+
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
321 |
+
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
322 |
+
|
323 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
324 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
325 |
+
weight_dtype = torch.float32
|
326 |
+
if accelerator.state.deepspeed_plugin:
|
327 |
+
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
328 |
+
if (
|
329 |
+
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
330 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
331 |
+
):
|
332 |
+
weight_dtype = torch.float16
|
333 |
+
if (
|
334 |
+
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
335 |
+
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
336 |
+
):
|
337 |
+
weight_dtype = torch.bfloat16
|
338 |
+
else:
|
339 |
+
if accelerator.mixed_precision == "fp16":
|
340 |
+
weight_dtype = torch.float16
|
341 |
+
elif accelerator.mixed_precision == "bf16":
|
342 |
+
weight_dtype = torch.bfloat16
|
343 |
+
|
344 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
345 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
346 |
+
raise ValueError(
|
347 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
348 |
+
)
|
349 |
+
|
350 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
351 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
352 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
353 |
+
|
354 |
+
if args.gradient_checkpointing:
|
355 |
+
transformer.enable_gradient_checkpointing()
|
356 |
+
|
357 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
358 |
+
def save_model_hook(models, weights, output_dir):
|
359 |
+
if accelerator.is_main_process:
|
360 |
+
for model in models:
|
361 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
362 |
+
model: CogVideoXTransformer3DModel
|
363 |
+
model = unwrap_model(accelerator, model)
|
364 |
+
model.save_pretrained(
|
365 |
+
os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
366 |
+
)
|
367 |
+
else:
|
368 |
+
raise ValueError(f"Unexpected save model: {model.__class__}")
|
369 |
+
|
370 |
+
# make sure to pop weight so that corresponding model is not saved again
|
371 |
+
if weights:
|
372 |
+
weights.pop()
|
373 |
+
|
374 |
+
def load_model_hook(models, input_dir):
|
375 |
+
transformer_ = None
|
376 |
+
init_under_meta = False
|
377 |
+
|
378 |
+
# This is a bit of a hack but I don't know any other solution.
|
379 |
+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
380 |
+
while len(models) > 0:
|
381 |
+
model = models.pop()
|
382 |
+
|
383 |
+
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
384 |
+
transformer_ = unwrap_model(accelerator, model)
|
385 |
+
else:
|
386 |
+
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
387 |
+
else:
|
388 |
+
with init_empty_weights():
|
389 |
+
transformer_ = CogVideoXTransformer3DModel.from_config(
|
390 |
+
args.pretrained_model_name_or_path, subfolder="transformer"
|
391 |
+
)
|
392 |
+
init_under_meta = True
|
393 |
+
|
394 |
+
load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
|
395 |
+
transformer_.register_to_config(**load_model.config)
|
396 |
+
transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
|
397 |
+
del load_model
|
398 |
+
|
399 |
+
# Make sure the trainable params are in float32. This is again needed since the base models
|
400 |
+
# are in `weight_dtype`. More details:
|
401 |
+
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
402 |
+
if args.mixed_precision == "fp16":
|
403 |
+
cast_training_params([transformer_])
|
404 |
+
|
405 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
406 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
407 |
+
|
408 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
409 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
410 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
411 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
412 |
+
|
413 |
+
if args.scale_lr:
|
414 |
+
args.learning_rate = (
|
415 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
416 |
+
)
|
417 |
+
|
418 |
+
# Make sure the trainable params are in float32.
|
419 |
+
if args.mixed_precision == "fp16":
|
420 |
+
# only upcast trainable parameters (LoRA) into fp32
|
421 |
+
cast_training_params([transformer], dtype=torch.float32)
|
422 |
+
|
423 |
+
transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
424 |
+
|
425 |
+
# Optimization parameters
|
426 |
+
transformer_parameters_with_lr = {
|
427 |
+
"params": transformer_parameters,
|
428 |
+
"lr": args.learning_rate,
|
429 |
+
}
|
430 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
431 |
+
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
432 |
+
|
433 |
+
use_deepspeed_optimizer = (
|
434 |
+
accelerator.state.deepspeed_plugin is not None
|
435 |
+
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
436 |
+
)
|
437 |
+
use_deepspeed_scheduler = (
|
438 |
+
accelerator.state.deepspeed_plugin is not None
|
439 |
+
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
440 |
+
)
|
441 |
+
|
442 |
+
optimizer = get_optimizer(
|
443 |
+
params_to_optimize=params_to_optimize,
|
444 |
+
optimizer_name=args.optimizer,
|
445 |
+
learning_rate=args.learning_rate,
|
446 |
+
beta1=args.beta1,
|
447 |
+
beta2=args.beta2,
|
448 |
+
beta3=args.beta3,
|
449 |
+
epsilon=args.epsilon,
|
450 |
+
weight_decay=args.weight_decay,
|
451 |
+
prodigy_decouple=args.prodigy_decouple,
|
452 |
+
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
453 |
+
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
454 |
+
use_8bit=args.use_8bit,
|
455 |
+
use_4bit=args.use_4bit,
|
456 |
+
use_torchao=args.use_torchao,
|
457 |
+
use_deepspeed=use_deepspeed_optimizer,
|
458 |
+
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
459 |
+
offload_gradients=args.offload_gradients,
|
460 |
+
)
|
461 |
+
|
462 |
+
# Dataset and DataLoader
|
463 |
+
dataset_init_kwargs = {
|
464 |
+
"data_root": args.data_root,
|
465 |
+
"dataset_file": args.dataset_file,
|
466 |
+
"caption_column": args.caption_column,
|
467 |
+
"video_column": args.video_column,
|
468 |
+
"max_num_frames": args.max_num_frames,
|
469 |
+
"id_token": args.id_token,
|
470 |
+
"height_buckets": args.height_buckets,
|
471 |
+
"width_buckets": args.width_buckets,
|
472 |
+
"frame_buckets": args.frame_buckets,
|
473 |
+
"load_tensors": args.load_tensors,
|
474 |
+
"random_flip": args.random_flip,
|
475 |
+
}
|
476 |
+
if args.video_reshape_mode is None:
|
477 |
+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
478 |
+
else:
|
479 |
+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
480 |
+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
481 |
+
)
|
482 |
+
|
483 |
+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
484 |
+
|
485 |
+
train_dataloader = DataLoader(
|
486 |
+
train_dataset,
|
487 |
+
batch_size=1,
|
488 |
+
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
489 |
+
collate_fn=collate_fn,
|
490 |
+
num_workers=args.dataloader_num_workers,
|
491 |
+
pin_memory=args.pin_memory,
|
492 |
+
)
|
493 |
+
|
494 |
+
# Scheduler and math around the number of training steps.
|
495 |
+
overrode_max_train_steps = False
|
496 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
497 |
+
if args.max_train_steps is None:
|
498 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
499 |
+
overrode_max_train_steps = True
|
500 |
+
|
501 |
+
if args.use_cpu_offload_optimizer:
|
502 |
+
lr_scheduler = None
|
503 |
+
accelerator.print(
|
504 |
+
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
505 |
+
"you are training with those settings, they will be ignored."
|
506 |
+
)
|
507 |
+
else:
|
508 |
+
if use_deepspeed_scheduler:
|
509 |
+
from accelerate.utils import DummyScheduler
|
510 |
+
|
511 |
+
lr_scheduler = DummyScheduler(
|
512 |
+
name=args.lr_scheduler,
|
513 |
+
optimizer=optimizer,
|
514 |
+
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
515 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
516 |
+
)
|
517 |
+
else:
|
518 |
+
lr_scheduler = get_scheduler(
|
519 |
+
args.lr_scheduler,
|
520 |
+
optimizer=optimizer,
|
521 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
522 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
523 |
+
num_cycles=args.lr_num_cycles,
|
524 |
+
power=args.lr_power,
|
525 |
+
)
|
526 |
+
|
527 |
+
# Prepare everything with our `accelerator`.
|
528 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
529 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
530 |
+
)
|
531 |
+
|
532 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
533 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
534 |
+
if overrode_max_train_steps:
|
535 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
536 |
+
# Afterwards we recalculate our number of training epochs
|
537 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
538 |
+
|
539 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
540 |
+
# The trackers initializes automatically on the main process.
|
541 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
542 |
+
tracker_name = args.tracker_name or "cogvideox-sft"
|
543 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
544 |
+
|
545 |
+
accelerator.print("===== Memory before training =====")
|
546 |
+
reset_memory(accelerator.device)
|
547 |
+
print_memory(accelerator.device)
|
548 |
+
|
549 |
+
# Train!
|
550 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
551 |
+
|
552 |
+
accelerator.print("***** Running training *****")
|
553 |
+
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
554 |
+
accelerator.print(f" Num examples = {len(train_dataset)}")
|
555 |
+
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
556 |
+
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
557 |
+
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
558 |
+
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
559 |
+
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
560 |
+
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
561 |
+
global_step = 0
|
562 |
+
first_epoch = 0
|
563 |
+
|
564 |
+
# Potentially load in the weights and states from a previous save
|
565 |
+
if not args.resume_from_checkpoint:
|
566 |
+
initial_global_step = 0
|
567 |
+
else:
|
568 |
+
if args.resume_from_checkpoint != "latest":
|
569 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
570 |
+
else:
|
571 |
+
# Get the most recent checkpoint
|
572 |
+
dirs = os.listdir(args.output_dir)
|
573 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
574 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
575 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
576 |
+
|
577 |
+
if path is None:
|
578 |
+
accelerator.print(
|
579 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
580 |
+
)
|
581 |
+
args.resume_from_checkpoint = None
|
582 |
+
initial_global_step = 0
|
583 |
+
else:
|
584 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
585 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
586 |
+
global_step = int(path.split("-")[1])
|
587 |
+
|
588 |
+
initial_global_step = global_step
|
589 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
590 |
+
|
591 |
+
progress_bar = tqdm(
|
592 |
+
range(0, args.max_train_steps),
|
593 |
+
initial=initial_global_step,
|
594 |
+
desc="Steps",
|
595 |
+
# Only show the progress bar once on each machine.
|
596 |
+
disable=not accelerator.is_local_main_process,
|
597 |
+
)
|
598 |
+
|
599 |
+
# For DeepSpeed training
|
600 |
+
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
601 |
+
|
602 |
+
if args.load_tensors:
|
603 |
+
del vae, text_encoder
|
604 |
+
gc.collect()
|
605 |
+
torch.cuda.empty_cache()
|
606 |
+
torch.cuda.synchronize(accelerator.device)
|
607 |
+
|
608 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
609 |
+
|
610 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
611 |
+
transformer.train()
|
612 |
+
|
613 |
+
for step, batch in enumerate(train_dataloader):
|
614 |
+
models_to_accumulate = [transformer]
|
615 |
+
logs = {}
|
616 |
+
|
617 |
+
with accelerator.accumulate(models_to_accumulate):
|
618 |
+
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
619 |
+
prompts = batch["prompts"]
|
620 |
+
|
621 |
+
# Encode videos
|
622 |
+
if not args.load_tensors:
|
623 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
624 |
+
latent_dist = vae.encode(videos).latent_dist
|
625 |
+
else:
|
626 |
+
latent_dist = DiagonalGaussianDistribution(videos)
|
627 |
+
|
628 |
+
videos = latent_dist.sample() * VAE_SCALING_FACTOR
|
629 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
630 |
+
videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
631 |
+
model_input = videos
|
632 |
+
|
633 |
+
# Encode prompts
|
634 |
+
if not args.load_tensors:
|
635 |
+
prompt_embeds = compute_prompt_embeddings(
|
636 |
+
tokenizer,
|
637 |
+
text_encoder,
|
638 |
+
prompts,
|
639 |
+
model_config.max_text_seq_length,
|
640 |
+
accelerator.device,
|
641 |
+
weight_dtype,
|
642 |
+
requires_grad=False,
|
643 |
+
)
|
644 |
+
else:
|
645 |
+
prompt_embeds = prompts.to(dtype=weight_dtype)
|
646 |
+
|
647 |
+
# Sample noise that will be added to the latents
|
648 |
+
noise = torch.randn_like(model_input)
|
649 |
+
batch_size, num_frames, num_channels, height, width = model_input.shape
|
650 |
+
|
651 |
+
# Sample a random timestep for each image
|
652 |
+
timesteps = torch.randint(
|
653 |
+
0,
|
654 |
+
scheduler.config.num_train_timesteps,
|
655 |
+
(batch_size,),
|
656 |
+
dtype=torch.int64,
|
657 |
+
device=model_input.device,
|
658 |
+
)
|
659 |
+
|
660 |
+
# Prepare rotary embeds
|
661 |
+
image_rotary_emb = (
|
662 |
+
prepare_rotary_positional_embeddings(
|
663 |
+
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
664 |
+
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
665 |
+
num_frames=num_frames,
|
666 |
+
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
667 |
+
patch_size=model_config.patch_size,
|
668 |
+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
669 |
+
attention_head_dim=model_config.attention_head_dim,
|
670 |
+
device=accelerator.device,
|
671 |
+
base_height=RoPE_BASE_HEIGHT,
|
672 |
+
base_width=RoPE_BASE_WIDTH,
|
673 |
+
)
|
674 |
+
if model_config.use_rotary_positional_embeddings
|
675 |
+
else None
|
676 |
+
)
|
677 |
+
|
678 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
679 |
+
# (this is the forward diffusion process)
|
680 |
+
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
|
681 |
+
|
682 |
+
# Predict the noise residual
|
683 |
+
model_output = transformer(
|
684 |
+
hidden_states=noisy_model_input,
|
685 |
+
encoder_hidden_states=prompt_embeds,
|
686 |
+
timestep=timesteps,
|
687 |
+
image_rotary_emb=image_rotary_emb,
|
688 |
+
return_dict=False,
|
689 |
+
)[0]
|
690 |
+
|
691 |
+
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
|
692 |
+
|
693 |
+
weights = 1 / (1 - alphas_cumprod[timesteps])
|
694 |
+
while len(weights.shape) < len(model_pred.shape):
|
695 |
+
weights = weights.unsqueeze(-1)
|
696 |
+
|
697 |
+
target = model_input
|
698 |
+
|
699 |
+
loss = torch.mean(
|
700 |
+
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
701 |
+
dim=1,
|
702 |
+
)
|
703 |
+
loss = loss.mean()
|
704 |
+
accelerator.backward(loss)
|
705 |
+
|
706 |
+
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
707 |
+
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
708 |
+
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
709 |
+
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
710 |
+
logs.update(
|
711 |
+
{
|
712 |
+
"gradient_norm_before_clip": gradient_norm_before_clip,
|
713 |
+
"gradient_norm_after_clip": gradient_norm_after_clip,
|
714 |
+
}
|
715 |
+
)
|
716 |
+
|
717 |
+
if accelerator.state.deepspeed_plugin is None:
|
718 |
+
optimizer.step()
|
719 |
+
optimizer.zero_grad()
|
720 |
+
|
721 |
+
if not args.use_cpu_offload_optimizer:
|
722 |
+
lr_scheduler.step()
|
723 |
+
|
724 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
725 |
+
if accelerator.sync_gradients:
|
726 |
+
progress_bar.update(1)
|
727 |
+
global_step += 1
|
728 |
+
|
729 |
+
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
730 |
+
if global_step % args.checkpointing_steps == 0:
|
731 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
732 |
+
if args.checkpoints_total_limit is not None:
|
733 |
+
checkpoints = os.listdir(args.output_dir)
|
734 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
735 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
736 |
+
|
737 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
738 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
739 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
740 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
741 |
+
|
742 |
+
logger.info(
|
743 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
744 |
+
)
|
745 |
+
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
746 |
+
|
747 |
+
for removing_checkpoint in removing_checkpoints:
|
748 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
749 |
+
shutil.rmtree(removing_checkpoint)
|
750 |
+
|
751 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
752 |
+
accelerator.save_state(save_path)
|
753 |
+
logger.info(f"Saved state to {save_path}")
|
754 |
+
|
755 |
+
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
756 |
+
logs.update(
|
757 |
+
{
|
758 |
+
"loss": loss.detach().item(),
|
759 |
+
"lr": last_lr,
|
760 |
+
}
|
761 |
+
)
|
762 |
+
progress_bar.set_postfix(**logs)
|
763 |
+
accelerator.log(logs, step=global_step)
|
764 |
+
|
765 |
+
if global_step >= args.max_train_steps:
|
766 |
+
break
|
767 |
+
|
768 |
+
if accelerator.is_main_process:
|
769 |
+
if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
|
770 |
+
accelerator.print("===== Memory before validation =====")
|
771 |
+
print_memory(accelerator.device)
|
772 |
+
torch.cuda.synchronize(accelerator.device)
|
773 |
+
|
774 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
775 |
+
args.pretrained_model_name_or_path,
|
776 |
+
transformer=unwrap_model(accelerator, transformer),
|
777 |
+
scheduler=scheduler,
|
778 |
+
revision=args.revision,
|
779 |
+
variant=args.variant,
|
780 |
+
torch_dtype=weight_dtype,
|
781 |
+
)
|
782 |
+
|
783 |
+
if args.enable_slicing:
|
784 |
+
pipe.vae.enable_slicing()
|
785 |
+
if args.enable_tiling:
|
786 |
+
pipe.vae.enable_tiling()
|
787 |
+
if args.enable_model_cpu_offload:
|
788 |
+
pipe.enable_model_cpu_offload()
|
789 |
+
|
790 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
791 |
+
for validation_prompt in validation_prompts:
|
792 |
+
pipeline_args = {
|
793 |
+
"prompt": validation_prompt,
|
794 |
+
"guidance_scale": args.guidance_scale,
|
795 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
796 |
+
"height": args.height,
|
797 |
+
"width": args.width,
|
798 |
+
"max_sequence_length": model_config.max_text_seq_length,
|
799 |
+
}
|
800 |
+
|
801 |
+
log_validation(
|
802 |
+
accelerator=accelerator,
|
803 |
+
pipe=pipe,
|
804 |
+
args=args,
|
805 |
+
pipeline_args=pipeline_args,
|
806 |
+
epoch=epoch,
|
807 |
+
is_final_validation=False,
|
808 |
+
)
|
809 |
+
|
810 |
+
accelerator.print("===== Memory after validation =====")
|
811 |
+
print_memory(accelerator.device)
|
812 |
+
reset_memory(accelerator.device)
|
813 |
+
|
814 |
+
del pipe
|
815 |
+
gc.collect()
|
816 |
+
torch.cuda.empty_cache()
|
817 |
+
torch.cuda.synchronize(accelerator.device)
|
818 |
+
|
819 |
+
accelerator.wait_for_everyone()
|
820 |
+
|
821 |
+
if accelerator.is_main_process:
|
822 |
+
transformer = unwrap_model(accelerator, transformer)
|
823 |
+
dtype = (
|
824 |
+
torch.float16
|
825 |
+
if args.mixed_precision == "fp16"
|
826 |
+
else torch.bfloat16
|
827 |
+
if args.mixed_precision == "bf16"
|
828 |
+
else torch.float32
|
829 |
+
)
|
830 |
+
transformer = transformer.to(dtype)
|
831 |
+
|
832 |
+
transformer.save_pretrained(
|
833 |
+
os.path.join(args.output_dir, "transformer"),
|
834 |
+
safe_serialization=True,
|
835 |
+
max_shard_size="5GB",
|
836 |
+
)
|
837 |
+
|
838 |
+
# Cleanup trained models to save memory
|
839 |
+
if args.load_tensors:
|
840 |
+
del transformer
|
841 |
+
else:
|
842 |
+
del transformer, text_encoder, vae
|
843 |
+
|
844 |
+
gc.collect()
|
845 |
+
torch.cuda.empty_cache()
|
846 |
+
torch.cuda.synchronize(accelerator.device)
|
847 |
+
|
848 |
+
accelerator.print("===== Memory before testing =====")
|
849 |
+
print_memory(accelerator.device)
|
850 |
+
reset_memory(accelerator.device)
|
851 |
+
|
852 |
+
# Final test inference
|
853 |
+
pipe = CogVideoXPipeline.from_pretrained(
|
854 |
+
args.pretrained_model_name_or_path,
|
855 |
+
revision=args.revision,
|
856 |
+
variant=args.variant,
|
857 |
+
torch_dtype=weight_dtype,
|
858 |
+
)
|
859 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
860 |
+
|
861 |
+
if args.enable_slicing:
|
862 |
+
pipe.vae.enable_slicing()
|
863 |
+
if args.enable_tiling:
|
864 |
+
pipe.vae.enable_tiling()
|
865 |
+
if args.enable_model_cpu_offload:
|
866 |
+
pipe.enable_model_cpu_offload()
|
867 |
+
|
868 |
+
# Run inference
|
869 |
+
validation_outputs = []
|
870 |
+
if args.validation_prompt and args.num_validation_videos > 0:
|
871 |
+
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
872 |
+
for validation_prompt in validation_prompts:
|
873 |
+
pipeline_args = {
|
874 |
+
"prompt": validation_prompt,
|
875 |
+
"guidance_scale": args.guidance_scale,
|
876 |
+
"use_dynamic_cfg": args.use_dynamic_cfg,
|
877 |
+
"height": args.height,
|
878 |
+
"width": args.width,
|
879 |
+
}
|
880 |
+
|
881 |
+
video = log_validation(
|
882 |
+
accelerator=accelerator,
|
883 |
+
pipe=pipe,
|
884 |
+
args=args,
|
885 |
+
pipeline_args=pipeline_args,
|
886 |
+
epoch=epoch,
|
887 |
+
is_final_validation=True,
|
888 |
+
)
|
889 |
+
validation_outputs.extend(video)
|
890 |
+
|
891 |
+
accelerator.print("===== Memory after testing =====")
|
892 |
+
print_memory(accelerator.device)
|
893 |
+
reset_memory(accelerator.device)
|
894 |
+
torch.cuda.synchronize(accelerator.device)
|
895 |
+
|
896 |
+
if args.push_to_hub:
|
897 |
+
save_model_card(
|
898 |
+
repo_id,
|
899 |
+
videos=validation_outputs,
|
900 |
+
base_model=args.pretrained_model_name_or_path,
|
901 |
+
validation_prompt=args.validation_prompt,
|
902 |
+
repo_folder=args.output_dir,
|
903 |
+
fps=args.fps,
|
904 |
+
)
|
905 |
+
upload_folder(
|
906 |
+
repo_id=repo_id,
|
907 |
+
folder_path=args.output_dir,
|
908 |
+
commit_message="End of training",
|
909 |
+
ignore_patterns=["step_*", "epoch_*"],
|
910 |
+
)
|
911 |
+
|
912 |
+
accelerator.end_training()
|
913 |
+
|
914 |
+
|
915 |
+
if __name__ == "__main__":
|
916 |
+
args = get_args()
|
917 |
+
main(args)
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/dataset.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as TT
|
9 |
+
from accelerate.logging import get_logger
|
10 |
+
from torch.utils.data import Dataset, Sampler
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms import InterpolationMode
|
13 |
+
from torchvision.transforms.functional import resize
|
14 |
+
|
15 |
+
|
16 |
+
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
17 |
+
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
18 |
+
import decord # isort:skip
|
19 |
+
|
20 |
+
decord.bridge.set_bridge("torch")
|
21 |
+
|
22 |
+
logger = get_logger(__name__)
|
23 |
+
|
24 |
+
HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
|
25 |
+
WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
|
26 |
+
FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]
|
27 |
+
|
28 |
+
|
29 |
+
class VideoDataset(Dataset):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
data_root: str,
|
33 |
+
dataset_file: Optional[str] = None,
|
34 |
+
caption_column: str = "text",
|
35 |
+
video_column: str = "video",
|
36 |
+
max_num_frames: int = 49,
|
37 |
+
id_token: Optional[str] = None,
|
38 |
+
height_buckets: List[int] = None,
|
39 |
+
width_buckets: List[int] = None,
|
40 |
+
frame_buckets: List[int] = None,
|
41 |
+
load_tensors: bool = False,
|
42 |
+
random_flip: Optional[float] = None,
|
43 |
+
image_to_video: bool = False,
|
44 |
+
) -> None:
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.data_root = Path(data_root)
|
48 |
+
self.dataset_file = dataset_file
|
49 |
+
self.caption_column = caption_column
|
50 |
+
self.video_column = video_column
|
51 |
+
self.max_num_frames = max_num_frames
|
52 |
+
self.id_token = f"{id_token.strip()} " if id_token else ""
|
53 |
+
self.height_buckets = height_buckets or HEIGHT_BUCKETS
|
54 |
+
self.width_buckets = width_buckets or WIDTH_BUCKETS
|
55 |
+
self.frame_buckets = frame_buckets or FRAME_BUCKETS
|
56 |
+
self.load_tensors = load_tensors
|
57 |
+
self.random_flip = random_flip
|
58 |
+
self.image_to_video = image_to_video
|
59 |
+
|
60 |
+
self.resolutions = [
|
61 |
+
(f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
|
62 |
+
]
|
63 |
+
|
64 |
+
# Two methods of loading data are supported.
|
65 |
+
# - Using a CSV: caption_column and video_column must be some column in the CSV. One could
|
66 |
+
# make use of other columns too, such as a motion score or aesthetic score, by modifying the
|
67 |
+
# logic in CSV processing.
|
68 |
+
# - Using two files containing line-separate captions and relative paths to videos.
|
69 |
+
# For a more detailed explanation about preparing dataset format, checkout the README.
|
70 |
+
if dataset_file is None:
|
71 |
+
(
|
72 |
+
self.prompts,
|
73 |
+
self.video_paths,
|
74 |
+
) = self._load_dataset_from_local_path()
|
75 |
+
else:
|
76 |
+
(
|
77 |
+
self.prompts,
|
78 |
+
self.video_paths,
|
79 |
+
) = self._load_dataset_from_csv()
|
80 |
+
|
81 |
+
if len(self.video_paths) != len(self.prompts):
|
82 |
+
raise ValueError(
|
83 |
+
f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
|
84 |
+
)
|
85 |
+
|
86 |
+
self.video_transforms = transforms.Compose(
|
87 |
+
[
|
88 |
+
transforms.RandomHorizontalFlip(random_flip)
|
89 |
+
if random_flip
|
90 |
+
else transforms.Lambda(self.identity_transform),
|
91 |
+
transforms.Lambda(self.scale_transform),
|
92 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def identity_transform(x):
|
98 |
+
return x
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def scale_transform(x):
|
102 |
+
return x / 255.0
|
103 |
+
|
104 |
+
def __len__(self) -> int:
|
105 |
+
return len(self.video_paths)
|
106 |
+
|
107 |
+
def __getitem__(self, index: int) -> Dict[str, Any]:
|
108 |
+
if isinstance(index, list):
|
109 |
+
# Here, index is actually a list of data objects that we need to return.
|
110 |
+
# The BucketSampler should ideally return indices. But, in the sampler, we'd like
|
111 |
+
# to have information about num_frames, height and width. Since this is not stored
|
112 |
+
# as metadata, we need to read the video to get this information. You could read this
|
113 |
+
# information without loading the full video in memory, but we do it anyway. In order
|
114 |
+
# to not load the video twice (once to get the metadata, and once to return the loaded video
|
115 |
+
# based on sampled indices), we cache it in the BucketSampler. When the sampler is
|
116 |
+
# to yield, we yield the cache data instead of indices. So, this special check ensures
|
117 |
+
# that data is not loaded a second time. PRs are welcome for improvements.
|
118 |
+
return index
|
119 |
+
|
120 |
+
if self.load_tensors:
|
121 |
+
image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
|
122 |
+
|
123 |
+
# This is hardcoded for now.
|
124 |
+
# The VAE's temporal compression ratio is 4.
|
125 |
+
# The VAE's spatial compression ratio is 8.
|
126 |
+
latent_num_frames = video_latents.size(1)
|
127 |
+
if latent_num_frames % 2 == 0:
|
128 |
+
num_frames = latent_num_frames * 4
|
129 |
+
else:
|
130 |
+
num_frames = (latent_num_frames - 1) * 4 + 1
|
131 |
+
|
132 |
+
height = video_latents.size(2) * 8
|
133 |
+
width = video_latents.size(3) * 8
|
134 |
+
|
135 |
+
return {
|
136 |
+
"prompt": prompt_embeds,
|
137 |
+
"image": image_latents,
|
138 |
+
"video": video_latents,
|
139 |
+
"video_metadata": {
|
140 |
+
"num_frames": num_frames,
|
141 |
+
"height": height,
|
142 |
+
"width": width,
|
143 |
+
},
|
144 |
+
}
|
145 |
+
else:
|
146 |
+
image, video, _ = self._preprocess_video(self.video_paths[index])
|
147 |
+
|
148 |
+
return {
|
149 |
+
"prompt": self.id_token + self.prompts[index],
|
150 |
+
"image": image,
|
151 |
+
"video": video,
|
152 |
+
"video_metadata": {
|
153 |
+
"num_frames": video.shape[0],
|
154 |
+
"height": video.shape[2],
|
155 |
+
"width": video.shape[3],
|
156 |
+
},
|
157 |
+
}
|
158 |
+
|
159 |
+
def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
|
160 |
+
if not self.data_root.exists():
|
161 |
+
raise ValueError("Root folder for videos does not exist")
|
162 |
+
|
163 |
+
prompt_path = self.data_root.joinpath(self.caption_column)
|
164 |
+
video_path = self.data_root.joinpath(self.video_column)
|
165 |
+
|
166 |
+
if not prompt_path.exists() or not prompt_path.is_file():
|
167 |
+
raise ValueError(
|
168 |
+
"Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
|
169 |
+
)
|
170 |
+
if not video_path.exists() or not video_path.is_file():
|
171 |
+
raise ValueError(
|
172 |
+
"Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
|
173 |
+
)
|
174 |
+
|
175 |
+
with open(prompt_path, "r", encoding="utf-8") as file:
|
176 |
+
prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
177 |
+
with open(video_path, "r", encoding="utf-8") as file:
|
178 |
+
video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
|
179 |
+
|
180 |
+
if not self.load_tensors and any(not path.is_file() for path in video_paths):
|
181 |
+
raise ValueError(
|
182 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
183 |
+
)
|
184 |
+
|
185 |
+
return prompts, video_paths
|
186 |
+
|
187 |
+
def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
|
188 |
+
df = pd.read_csv(self.dataset_file)
|
189 |
+
prompts = df[self.caption_column].tolist()
|
190 |
+
video_paths = df[self.video_column].tolist()
|
191 |
+
video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
|
192 |
+
|
193 |
+
if any(not path.is_file() for path in video_paths):
|
194 |
+
raise ValueError(
|
195 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
196 |
+
)
|
197 |
+
|
198 |
+
return prompts, video_paths
|
199 |
+
|
200 |
+
def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
201 |
+
r"""
|
202 |
+
Loads a single video, or latent and prompt embedding, based on initialization parameters.
|
203 |
+
|
204 |
+
If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here,
|
205 |
+
F, C, H and W are the frames, channels, height and width of the input video.
|
206 |
+
|
207 |
+
If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D].
|
208 |
+
F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length
|
209 |
+
and embedding dimension of prompt embeddings.
|
210 |
+
"""
|
211 |
+
if self.load_tensors:
|
212 |
+
return self._load_preprocessed_latents_and_embeds(path)
|
213 |
+
else:
|
214 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
215 |
+
video_num_frames = len(video_reader)
|
216 |
+
|
217 |
+
indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
|
218 |
+
frames = video_reader.get_batch(indices)
|
219 |
+
frames = frames[: self.max_num_frames].float()
|
220 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
221 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
|
222 |
+
|
223 |
+
image = frames[:1].clone() if self.image_to_video else None
|
224 |
+
|
225 |
+
return image, frames, None
|
226 |
+
|
227 |
+
def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
|
228 |
+
filename_without_ext = path.name.split(".")[0]
|
229 |
+
pt_filename = f"{filename_without_ext}.pt"
|
230 |
+
|
231 |
+
# The current path is something like: /a/b/c/d/videos/00001.mp4
|
232 |
+
# We need to reach: /a/b/c/d/video_latents/00001.pt
|
233 |
+
image_latents_path = path.parent.parent.joinpath("image_latents")
|
234 |
+
video_latents_path = path.parent.parent.joinpath("video_latents")
|
235 |
+
embeds_path = path.parent.parent.joinpath("prompt_embeds")
|
236 |
+
|
237 |
+
if (
|
238 |
+
not video_latents_path.exists()
|
239 |
+
or not embeds_path.exists()
|
240 |
+
or (self.image_to_video and not image_latents_path.exists())
|
241 |
+
):
|
242 |
+
raise ValueError(
|
243 |
+
f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
|
244 |
+
)
|
245 |
+
|
246 |
+
if self.image_to_video:
|
247 |
+
image_latent_filepath = image_latents_path.joinpath(pt_filename)
|
248 |
+
video_latent_filepath = video_latents_path.joinpath(pt_filename)
|
249 |
+
embeds_filepath = embeds_path.joinpath(pt_filename)
|
250 |
+
|
251 |
+
if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
|
252 |
+
if self.image_to_video:
|
253 |
+
image_latent_filepath = image_latent_filepath.as_posix()
|
254 |
+
video_latent_filepath = video_latent_filepath.as_posix()
|
255 |
+
embeds_filepath = embeds_filepath.as_posix()
|
256 |
+
raise ValueError(
|
257 |
+
f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
|
258 |
+
)
|
259 |
+
|
260 |
+
images = (
|
261 |
+
torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
|
262 |
+
)
|
263 |
+
latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
|
264 |
+
embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
|
265 |
+
|
266 |
+
return images, latents, embeds
|
267 |
+
|
268 |
+
|
269 |
+
class VideoDatasetWithResizing(VideoDataset):
|
270 |
+
def __init__(self, *args, **kwargs) -> None:
|
271 |
+
super().__init__(*args, **kwargs)
|
272 |
+
|
273 |
+
def _preprocess_video(self, path: Path) -> torch.Tensor:
|
274 |
+
if self.load_tensors:
|
275 |
+
return self._load_preprocessed_latents_and_embeds(path)
|
276 |
+
else:
|
277 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
278 |
+
video_num_frames = len(video_reader)
|
279 |
+
nearest_frame_bucket = min(
|
280 |
+
self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
|
281 |
+
)
|
282 |
+
|
283 |
+
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
|
284 |
+
|
285 |
+
frames = video_reader.get_batch(frame_indices)
|
286 |
+
frames = frames[:nearest_frame_bucket].float()
|
287 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
288 |
+
|
289 |
+
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
|
290 |
+
frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
|
291 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
|
292 |
+
|
293 |
+
image = frames[:1].clone() if self.image_to_video else None
|
294 |
+
|
295 |
+
return image, frames, None
|
296 |
+
|
297 |
+
def _find_nearest_resolution(self, height, width):
|
298 |
+
nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
|
299 |
+
return nearest_res[1], nearest_res[2]
|
300 |
+
|
301 |
+
|
302 |
+
class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
|
303 |
+
def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
|
304 |
+
super().__init__(*args, **kwargs)
|
305 |
+
self.video_reshape_mode = video_reshape_mode
|
306 |
+
|
307 |
+
def _resize_for_rectangle_crop(self, arr, image_size):
|
308 |
+
reshape_mode = self.video_reshape_mode
|
309 |
+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
310 |
+
arr = resize(
|
311 |
+
arr,
|
312 |
+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
313 |
+
interpolation=InterpolationMode.BICUBIC,
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
arr = resize(
|
317 |
+
arr,
|
318 |
+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
319 |
+
interpolation=InterpolationMode.BICUBIC,
|
320 |
+
)
|
321 |
+
|
322 |
+
h, w = arr.shape[2], arr.shape[3]
|
323 |
+
arr = arr.squeeze(0)
|
324 |
+
|
325 |
+
delta_h = h - image_size[0]
|
326 |
+
delta_w = w - image_size[1]
|
327 |
+
|
328 |
+
if reshape_mode == "random" or reshape_mode == "none":
|
329 |
+
top = np.random.randint(0, delta_h + 1)
|
330 |
+
left = np.random.randint(0, delta_w + 1)
|
331 |
+
elif reshape_mode == "center":
|
332 |
+
top, left = delta_h // 2, delta_w // 2
|
333 |
+
else:
|
334 |
+
raise NotImplementedError
|
335 |
+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
336 |
+
return arr
|
337 |
+
|
338 |
+
def _preprocess_video(self, path: Path) -> torch.Tensor:
|
339 |
+
if self.load_tensors:
|
340 |
+
return self._load_preprocessed_latents_and_embeds(path)
|
341 |
+
else:
|
342 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
343 |
+
video_num_frames = len(video_reader)
|
344 |
+
nearest_frame_bucket = min(
|
345 |
+
self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
|
346 |
+
)
|
347 |
+
|
348 |
+
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
|
349 |
+
|
350 |
+
frames = video_reader.get_batch(frame_indices)
|
351 |
+
frames = frames[:nearest_frame_bucket].float()
|
352 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
353 |
+
|
354 |
+
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
|
355 |
+
frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
|
356 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
|
357 |
+
|
358 |
+
image = frames[:1].clone() if self.image_to_video else None
|
359 |
+
|
360 |
+
return image, frames, None
|
361 |
+
|
362 |
+
def _find_nearest_resolution(self, height, width):
|
363 |
+
nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
|
364 |
+
return nearest_res[1], nearest_res[2]
|
365 |
+
|
366 |
+
|
367 |
+
class BucketSampler(Sampler):
|
368 |
+
r"""
|
369 |
+
PyTorch Sampler that groups 3D data by height, width and frames.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
data_source (`VideoDataset`):
|
373 |
+
A PyTorch dataset object that is an instance of `VideoDataset`.
|
374 |
+
batch_size (`int`, defaults to `8`):
|
375 |
+
The batch size to use for training.
|
376 |
+
shuffle (`bool`, defaults to `True`):
|
377 |
+
Whether or not to shuffle the data in each batch before dispatching to dataloader.
|
378 |
+
drop_last (`bool`, defaults to `False`):
|
379 |
+
Whether or not to drop incomplete buckets of data after completely iterating over all data
|
380 |
+
in the dataset. If set to True, only batches that have `batch_size` number of entries will
|
381 |
+
be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
|
382 |
+
and batches that do not have `batch_size` number of entries will also be yielded.
|
383 |
+
"""
|
384 |
+
|
385 |
+
def __init__(
|
386 |
+
self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
|
387 |
+
) -> None:
|
388 |
+
self.data_source = data_source
|
389 |
+
self.batch_size = batch_size
|
390 |
+
self.shuffle = shuffle
|
391 |
+
self.drop_last = drop_last
|
392 |
+
|
393 |
+
self.buckets = {resolution: [] for resolution in data_source.resolutions}
|
394 |
+
|
395 |
+
self._raised_warning_for_drop_last = False
|
396 |
+
|
397 |
+
def __len__(self):
|
398 |
+
if self.drop_last and not self._raised_warning_for_drop_last:
|
399 |
+
self._raised_warning_for_drop_last = True
|
400 |
+
logger.warning(
|
401 |
+
"Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
|
402 |
+
)
|
403 |
+
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
|
404 |
+
|
405 |
+
def __iter__(self):
|
406 |
+
for index, data in enumerate(self.data_source):
|
407 |
+
video_metadata = data["video_metadata"]
|
408 |
+
f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
|
409 |
+
|
410 |
+
self.buckets[(f, h, w)].append(data)
|
411 |
+
if len(self.buckets[(f, h, w)]) == self.batch_size:
|
412 |
+
if self.shuffle:
|
413 |
+
random.shuffle(self.buckets[(f, h, w)])
|
414 |
+
yield self.buckets[(f, h, w)]
|
415 |
+
del self.buckets[(f, h, w)]
|
416 |
+
self.buckets[(f, h, w)] = []
|
417 |
+
|
418 |
+
if self.drop_last:
|
419 |
+
return
|
420 |
+
|
421 |
+
for fhw, bucket in list(self.buckets.items()):
|
422 |
+
if len(bucket) == 0:
|
423 |
+
continue
|
424 |
+
if self.shuffle:
|
425 |
+
random.shuffle(bucket)
|
426 |
+
yield bucket
|
427 |
+
del self.buckets[fhw]
|
428 |
+
self.buckets[fhw] = []
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/prepare_dataset.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import functools
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import queue
|
9 |
+
import traceback
|
10 |
+
import uuid
|
11 |
+
from concurrent.futures import ThreadPoolExecutor
|
12 |
+
from typing import Any, Dict, List, Optional, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.distributed as dist
|
16 |
+
from diffusers import AutoencoderKLCogVideoX
|
17 |
+
from diffusers.training_utils import set_seed
|
18 |
+
from diffusers.utils import export_to_video, get_logger
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
from torchvision import transforms
|
21 |
+
from tqdm import tqdm
|
22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
23 |
+
|
24 |
+
|
25 |
+
import decord # isort:skip
|
26 |
+
|
27 |
+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
28 |
+
|
29 |
+
|
30 |
+
decord.bridge.set_bridge("torch")
|
31 |
+
|
32 |
+
logger = get_logger(__name__)
|
33 |
+
|
34 |
+
DTYPE_MAPPING = {
|
35 |
+
"fp32": torch.float32,
|
36 |
+
"fp16": torch.float16,
|
37 |
+
"bf16": torch.bfloat16,
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def check_height(x: Any) -> int:
|
42 |
+
x = int(x)
|
43 |
+
if x % 16 != 0:
|
44 |
+
raise argparse.ArgumentTypeError(
|
45 |
+
f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria."
|
46 |
+
)
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
def check_width(x: Any) -> int:
|
51 |
+
x = int(x)
|
52 |
+
if x % 16 != 0:
|
53 |
+
raise argparse.ArgumentTypeError(
|
54 |
+
f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria."
|
55 |
+
)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def check_frames(x: Any) -> int:
|
60 |
+
x = int(x)
|
61 |
+
if x % 4 != 0 and x % 4 != 1:
|
62 |
+
raise argparse.ArgumentTypeError(
|
63 |
+
f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria."
|
64 |
+
)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
def get_args() -> Dict[str, Any]:
|
69 |
+
parser = argparse.ArgumentParser()
|
70 |
+
parser.add_argument(
|
71 |
+
"--model_id",
|
72 |
+
type=str,
|
73 |
+
default="THUDM/CogVideoX-2b",
|
74 |
+
help="Hugging Face model ID to use for tokenizer, text encoder and VAE.",
|
75 |
+
)
|
76 |
+
parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.")
|
77 |
+
parser.add_argument(
|
78 |
+
"--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data."
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--caption_column",
|
82 |
+
type=str,
|
83 |
+
default="caption",
|
84 |
+
help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--video_column",
|
88 |
+
type=str,
|
89 |
+
default="video",
|
90 |
+
help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--id_token",
|
94 |
+
type=str,
|
95 |
+
default=None,
|
96 |
+
help="Identifier token appended to the start of each prompt if provided.",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--height_buckets",
|
100 |
+
nargs="+",
|
101 |
+
type=check_height,
|
102 |
+
default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--width_buckets",
|
106 |
+
nargs="+",
|
107 |
+
type=check_width,
|
108 |
+
default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--frame_buckets",
|
112 |
+
nargs="+",
|
113 |
+
type=check_frames,
|
114 |
+
default=[49],
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--random_flip",
|
118 |
+
type=float,
|
119 |
+
default=None,
|
120 |
+
help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
|
121 |
+
)
|
122 |
+
parser.add_argument(
|
123 |
+
"--dataloader_num_workers",
|
124 |
+
type=int,
|
125 |
+
default=0,
|
126 |
+
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--pin_memory",
|
130 |
+
action="store_true",
|
131 |
+
help="Whether or not to use the pinned memory setting in pytorch dataloader.",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--video_reshape_mode",
|
135 |
+
type=str,
|
136 |
+
default=None,
|
137 |
+
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--save_image_latents",
|
141 |
+
action="store_true",
|
142 |
+
help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--output_dir",
|
146 |
+
type=str,
|
147 |
+
required=True,
|
148 |
+
help="Path to output directory where preprocessed videos/latents/embeddings will be saved.",
|
149 |
+
)
|
150 |
+
parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.")
|
151 |
+
parser.add_argument(
|
152 |
+
"--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings."
|
153 |
+
)
|
154 |
+
parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.")
|
155 |
+
parser.add_argument(
|
156 |
+
"--save_latents_and_embeddings",
|
157 |
+
action="store_true",
|
158 |
+
help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--use_slicing",
|
162 |
+
action="store_true",
|
163 |
+
help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--use_tiling",
|
167 |
+
action="store_true",
|
168 |
+
help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
|
169 |
+
)
|
170 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.")
|
171 |
+
parser.add_argument(
|
172 |
+
"--num_decode_threads",
|
173 |
+
type=int,
|
174 |
+
default=0,
|
175 |
+
help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--dtype",
|
179 |
+
type=str,
|
180 |
+
choices=["fp32", "fp16", "bf16"],
|
181 |
+
default="fp32",
|
182 |
+
help="Data type to use when generating latents and prompt embeddings.",
|
183 |
+
)
|
184 |
+
parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
|
185 |
+
parser.add_argument(
|
186 |
+
"--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts."
|
187 |
+
)
|
188 |
+
return parser.parse_args()
|
189 |
+
|
190 |
+
|
191 |
+
def _get_t5_prompt_embeds(
|
192 |
+
tokenizer: T5Tokenizer,
|
193 |
+
text_encoder: T5EncoderModel,
|
194 |
+
prompt: Union[str, List[str]],
|
195 |
+
num_videos_per_prompt: int = 1,
|
196 |
+
max_sequence_length: int = 226,
|
197 |
+
device: Optional[torch.device] = None,
|
198 |
+
dtype: Optional[torch.dtype] = None,
|
199 |
+
text_input_ids=None,
|
200 |
+
):
|
201 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
202 |
+
batch_size = len(prompt)
|
203 |
+
|
204 |
+
if tokenizer is not None:
|
205 |
+
text_inputs = tokenizer(
|
206 |
+
prompt,
|
207 |
+
padding="max_length",
|
208 |
+
max_length=max_sequence_length,
|
209 |
+
truncation=True,
|
210 |
+
add_special_tokens=True,
|
211 |
+
return_tensors="pt",
|
212 |
+
)
|
213 |
+
text_input_ids = text_inputs.input_ids
|
214 |
+
else:
|
215 |
+
if text_input_ids is None:
|
216 |
+
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
|
217 |
+
|
218 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
219 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
220 |
+
|
221 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
222 |
+
_, seq_len, _ = prompt_embeds.shape
|
223 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
224 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
225 |
+
|
226 |
+
return prompt_embeds
|
227 |
+
|
228 |
+
|
229 |
+
def encode_prompt(
|
230 |
+
tokenizer: T5Tokenizer,
|
231 |
+
text_encoder: T5EncoderModel,
|
232 |
+
prompt: Union[str, List[str]],
|
233 |
+
num_videos_per_prompt: int = 1,
|
234 |
+
max_sequence_length: int = 226,
|
235 |
+
device: Optional[torch.device] = None,
|
236 |
+
dtype: Optional[torch.dtype] = None,
|
237 |
+
text_input_ids=None,
|
238 |
+
):
|
239 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
240 |
+
prompt_embeds = _get_t5_prompt_embeds(
|
241 |
+
tokenizer,
|
242 |
+
text_encoder,
|
243 |
+
prompt=prompt,
|
244 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
245 |
+
max_sequence_length=max_sequence_length,
|
246 |
+
device=device,
|
247 |
+
dtype=dtype,
|
248 |
+
text_input_ids=text_input_ids,
|
249 |
+
)
|
250 |
+
return prompt_embeds
|
251 |
+
|
252 |
+
|
253 |
+
def compute_prompt_embeddings(
|
254 |
+
tokenizer: T5Tokenizer,
|
255 |
+
text_encoder: T5EncoderModel,
|
256 |
+
prompts: List[str],
|
257 |
+
max_sequence_length: int,
|
258 |
+
device: torch.device,
|
259 |
+
dtype: torch.dtype,
|
260 |
+
requires_grad: bool = False,
|
261 |
+
):
|
262 |
+
if requires_grad:
|
263 |
+
prompt_embeds = encode_prompt(
|
264 |
+
tokenizer,
|
265 |
+
text_encoder,
|
266 |
+
prompts,
|
267 |
+
num_videos_per_prompt=1,
|
268 |
+
max_sequence_length=max_sequence_length,
|
269 |
+
device=device,
|
270 |
+
dtype=dtype,
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
with torch.no_grad():
|
274 |
+
prompt_embeds = encode_prompt(
|
275 |
+
tokenizer,
|
276 |
+
text_encoder,
|
277 |
+
prompts,
|
278 |
+
num_videos_per_prompt=1,
|
279 |
+
max_sequence_length=max_sequence_length,
|
280 |
+
device=device,
|
281 |
+
dtype=dtype,
|
282 |
+
)
|
283 |
+
return prompt_embeds
|
284 |
+
|
285 |
+
|
286 |
+
to_pil_image = transforms.ToPILImage(mode="RGB")
|
287 |
+
|
288 |
+
|
289 |
+
def save_image(image: torch.Tensor, path: pathlib.Path) -> None:
|
290 |
+
image = image.to(dtype=torch.float32).clamp(-1, 1)
|
291 |
+
image = to_pil_image(image.float())
|
292 |
+
image.save(path)
|
293 |
+
|
294 |
+
|
295 |
+
def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None:
|
296 |
+
video = video.to(dtype=torch.float32).clamp(-1, 1)
|
297 |
+
video = [to_pil_image(frame) for frame in video]
|
298 |
+
export_to_video(video, path, fps=fps)
|
299 |
+
|
300 |
+
|
301 |
+
def save_prompt(prompt: str, path: pathlib.Path) -> None:
|
302 |
+
with open(path, "w", encoding="utf-8") as file:
|
303 |
+
file.write(prompt)
|
304 |
+
|
305 |
+
|
306 |
+
def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None:
|
307 |
+
with open(path, "w", encoding="utf-8") as file:
|
308 |
+
file.write(json.dumps(metadata))
|
309 |
+
|
310 |
+
|
311 |
+
@torch.no_grad()
|
312 |
+
def serialize_artifacts(
|
313 |
+
batch_size: int,
|
314 |
+
fps: int,
|
315 |
+
images_dir: Optional[pathlib.Path] = None,
|
316 |
+
image_latents_dir: Optional[pathlib.Path] = None,
|
317 |
+
videos_dir: Optional[pathlib.Path] = None,
|
318 |
+
video_latents_dir: Optional[pathlib.Path] = None,
|
319 |
+
prompts_dir: Optional[pathlib.Path] = None,
|
320 |
+
prompt_embeds_dir: Optional[pathlib.Path] = None,
|
321 |
+
images: Optional[torch.Tensor] = None,
|
322 |
+
image_latents: Optional[torch.Tensor] = None,
|
323 |
+
videos: Optional[torch.Tensor] = None,
|
324 |
+
video_latents: Optional[torch.Tensor] = None,
|
325 |
+
prompts: Optional[List[str]] = None,
|
326 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
327 |
+
) -> None:
|
328 |
+
num_frames, height, width = videos.size(1), videos.size(3), videos.size(4)
|
329 |
+
metadata = [{"num_frames": num_frames, "height": height, "width": width}]
|
330 |
+
|
331 |
+
data_folder_mapper_list = [
|
332 |
+
(images, images_dir, lambda img, path: save_image(img[0], path), "png"),
|
333 |
+
(image_latents, image_latents_dir, torch.save, "pt"),
|
334 |
+
(videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"),
|
335 |
+
(video_latents, video_latents_dir, torch.save, "pt"),
|
336 |
+
(prompts, prompts_dir, save_prompt, "txt"),
|
337 |
+
(prompt_embeds, prompt_embeds_dir, torch.save, "pt"),
|
338 |
+
(metadata, videos_dir, save_metadata, "txt"),
|
339 |
+
]
|
340 |
+
filenames = [uuid.uuid4() for _ in range(batch_size)]
|
341 |
+
|
342 |
+
for data, folder, save_fn, extension in data_folder_mapper_list:
|
343 |
+
if data is None:
|
344 |
+
continue
|
345 |
+
for slice, filename in zip(data, filenames):
|
346 |
+
if isinstance(slice, torch.Tensor):
|
347 |
+
slice = slice.clone().to("cpu")
|
348 |
+
path = folder.joinpath(f"{filename}.{extension}")
|
349 |
+
save_fn(slice, path)
|
350 |
+
|
351 |
+
|
352 |
+
def save_intermediates(output_queue: queue.Queue) -> None:
|
353 |
+
while True:
|
354 |
+
try:
|
355 |
+
item = output_queue.get(timeout=30)
|
356 |
+
if item is None:
|
357 |
+
break
|
358 |
+
serialize_artifacts(**item)
|
359 |
+
|
360 |
+
except queue.Empty:
|
361 |
+
continue
|
362 |
+
|
363 |
+
|
364 |
+
@torch.no_grad()
|
365 |
+
def main():
|
366 |
+
args = get_args()
|
367 |
+
set_seed(args.seed)
|
368 |
+
|
369 |
+
output_dir = pathlib.Path(args.output_dir)
|
370 |
+
tmp_dir = output_dir.joinpath("tmp")
|
371 |
+
|
372 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
373 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
374 |
+
|
375 |
+
# Create task queue for non-blocking serializing of artifacts
|
376 |
+
output_queue = queue.Queue()
|
377 |
+
save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers)
|
378 |
+
save_future = save_thread.submit(save_intermediates, output_queue)
|
379 |
+
|
380 |
+
# Initialize distributed processing
|
381 |
+
if "LOCAL_RANK" in os.environ:
|
382 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
383 |
+
torch.cuda.set_device(local_rank)
|
384 |
+
dist.init_process_group(backend="nccl")
|
385 |
+
world_size = dist.get_world_size()
|
386 |
+
rank = dist.get_rank()
|
387 |
+
else:
|
388 |
+
# Single GPU
|
389 |
+
local_rank = 0
|
390 |
+
world_size = 1
|
391 |
+
rank = 0
|
392 |
+
torch.cuda.set_device(rank)
|
393 |
+
|
394 |
+
# Create folders where intermediate tensors from each rank will be saved
|
395 |
+
images_dir = tmp_dir.joinpath(f"images/{rank}")
|
396 |
+
image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}")
|
397 |
+
videos_dir = tmp_dir.joinpath(f"videos/{rank}")
|
398 |
+
video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}")
|
399 |
+
prompts_dir = tmp_dir.joinpath(f"prompts/{rank}")
|
400 |
+
prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}")
|
401 |
+
|
402 |
+
images_dir.mkdir(parents=True, exist_ok=True)
|
403 |
+
image_latents_dir.mkdir(parents=True, exist_ok=True)
|
404 |
+
videos_dir.mkdir(parents=True, exist_ok=True)
|
405 |
+
video_latents_dir.mkdir(parents=True, exist_ok=True)
|
406 |
+
prompts_dir.mkdir(parents=True, exist_ok=True)
|
407 |
+
prompt_embeds_dir.mkdir(parents=True, exist_ok=True)
|
408 |
+
|
409 |
+
weight_dtype = DTYPE_MAPPING[args.dtype]
|
410 |
+
target_fps = args.target_fps
|
411 |
+
|
412 |
+
# 1. Dataset
|
413 |
+
dataset_init_kwargs = {
|
414 |
+
"data_root": args.data_root,
|
415 |
+
"dataset_file": args.dataset_file,
|
416 |
+
"caption_column": args.caption_column,
|
417 |
+
"video_column": args.video_column,
|
418 |
+
"max_num_frames": args.max_num_frames,
|
419 |
+
"id_token": args.id_token,
|
420 |
+
"height_buckets": args.height_buckets,
|
421 |
+
"width_buckets": args.width_buckets,
|
422 |
+
"frame_buckets": args.frame_buckets,
|
423 |
+
"load_tensors": False,
|
424 |
+
"random_flip": args.random_flip,
|
425 |
+
"image_to_video": args.save_image_latents,
|
426 |
+
}
|
427 |
+
if args.video_reshape_mode is None:
|
428 |
+
dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
429 |
+
else:
|
430 |
+
dataset = VideoDatasetWithResizeAndRectangleCrop(
|
431 |
+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
432 |
+
)
|
433 |
+
|
434 |
+
original_dataset_size = len(dataset)
|
435 |
+
|
436 |
+
# Split data among GPUs
|
437 |
+
if world_size > 1:
|
438 |
+
samples_per_gpu = original_dataset_size // world_size
|
439 |
+
start_index = rank * samples_per_gpu
|
440 |
+
end_index = start_index + samples_per_gpu
|
441 |
+
if rank == world_size - 1:
|
442 |
+
end_index = original_dataset_size # Make sure the last GPU gets the remaining data
|
443 |
+
|
444 |
+
# Slice the data
|
445 |
+
dataset.prompts = dataset.prompts[start_index:end_index]
|
446 |
+
dataset.video_paths = dataset.video_paths[start_index:end_index]
|
447 |
+
else:
|
448 |
+
pass
|
449 |
+
|
450 |
+
rank_dataset_size = len(dataset)
|
451 |
+
|
452 |
+
# 2. Dataloader
|
453 |
+
def collate_fn(data):
|
454 |
+
prompts = [x["prompt"] for x in data[0]]
|
455 |
+
|
456 |
+
images = None
|
457 |
+
if args.save_image_latents:
|
458 |
+
images = [x["image"] for x in data[0]]
|
459 |
+
images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True)
|
460 |
+
|
461 |
+
videos = [x["video"] for x in data[0]]
|
462 |
+
videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
|
463 |
+
|
464 |
+
return {
|
465 |
+
"images": images,
|
466 |
+
"videos": videos,
|
467 |
+
"prompts": prompts,
|
468 |
+
}
|
469 |
+
|
470 |
+
dataloader = DataLoader(
|
471 |
+
dataset,
|
472 |
+
batch_size=1,
|
473 |
+
sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False),
|
474 |
+
collate_fn=collate_fn,
|
475 |
+
num_workers=args.dataloader_num_workers,
|
476 |
+
pin_memory=args.pin_memory,
|
477 |
+
)
|
478 |
+
|
479 |
+
# 3. Prepare models
|
480 |
+
device = f"cuda:{rank}"
|
481 |
+
|
482 |
+
if args.save_latents_and_embeddings:
|
483 |
+
tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
|
484 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
485 |
+
args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype
|
486 |
+
)
|
487 |
+
text_encoder = text_encoder.to(device)
|
488 |
+
|
489 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype)
|
490 |
+
vae = vae.to(device)
|
491 |
+
|
492 |
+
if args.use_slicing:
|
493 |
+
vae.enable_slicing()
|
494 |
+
if args.use_tiling:
|
495 |
+
vae.enable_tiling()
|
496 |
+
|
497 |
+
# 4. Compute latents and embeddings and save
|
498 |
+
if rank == 0:
|
499 |
+
iterator = tqdm(
|
500 |
+
dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size
|
501 |
+
)
|
502 |
+
else:
|
503 |
+
iterator = dataloader
|
504 |
+
|
505 |
+
for step, batch in enumerate(iterator):
|
506 |
+
try:
|
507 |
+
images = None
|
508 |
+
image_latents = None
|
509 |
+
video_latents = None
|
510 |
+
prompt_embeds = None
|
511 |
+
|
512 |
+
if args.save_image_latents:
|
513 |
+
images = batch["images"].to(device, non_blocking=True)
|
514 |
+
images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
515 |
+
|
516 |
+
videos = batch["videos"].to(device, non_blocking=True)
|
517 |
+
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
518 |
+
|
519 |
+
prompts = batch["prompts"]
|
520 |
+
|
521 |
+
# Encode videos & images
|
522 |
+
if args.save_latents_and_embeddings:
|
523 |
+
if args.use_slicing:
|
524 |
+
if args.save_image_latents:
|
525 |
+
encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)]
|
526 |
+
image_latents = torch.cat(encoded_slices)
|
527 |
+
image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
528 |
+
|
529 |
+
encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)]
|
530 |
+
video_latents = torch.cat(encoded_slices)
|
531 |
+
|
532 |
+
else:
|
533 |
+
if args.save_image_latents:
|
534 |
+
image_latents = vae._encode(images)
|
535 |
+
image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
536 |
+
|
537 |
+
video_latents = vae._encode(videos)
|
538 |
+
|
539 |
+
video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
540 |
+
|
541 |
+
# Encode prompts
|
542 |
+
prompt_embeds = compute_prompt_embeddings(
|
543 |
+
tokenizer,
|
544 |
+
text_encoder,
|
545 |
+
prompts,
|
546 |
+
args.max_sequence_length,
|
547 |
+
device,
|
548 |
+
weight_dtype,
|
549 |
+
requires_grad=False,
|
550 |
+
)
|
551 |
+
|
552 |
+
if images is not None:
|
553 |
+
images = (images.permute(0, 2, 1, 3, 4) + 1) / 2
|
554 |
+
|
555 |
+
videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2
|
556 |
+
|
557 |
+
output_queue.put(
|
558 |
+
{
|
559 |
+
"batch_size": len(prompts),
|
560 |
+
"fps": target_fps,
|
561 |
+
"images_dir": images_dir,
|
562 |
+
"image_latents_dir": image_latents_dir,
|
563 |
+
"videos_dir": videos_dir,
|
564 |
+
"video_latents_dir": video_latents_dir,
|
565 |
+
"prompts_dir": prompts_dir,
|
566 |
+
"prompt_embeds_dir": prompt_embeds_dir,
|
567 |
+
"images": images,
|
568 |
+
"image_latents": image_latents,
|
569 |
+
"videos": videos,
|
570 |
+
"video_latents": video_latents,
|
571 |
+
"prompts": prompts,
|
572 |
+
"prompt_embeds": prompt_embeds,
|
573 |
+
}
|
574 |
+
)
|
575 |
+
|
576 |
+
except Exception:
|
577 |
+
print("-------------------------")
|
578 |
+
print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}")
|
579 |
+
traceback.print_exc()
|
580 |
+
print("-------------------------")
|
581 |
+
|
582 |
+
# 5. Complete distributed processing
|
583 |
+
if world_size > 1:
|
584 |
+
dist.barrier()
|
585 |
+
dist.destroy_process_group()
|
586 |
+
|
587 |
+
output_queue.put(None)
|
588 |
+
save_thread.shutdown(wait=True)
|
589 |
+
save_future.result()
|
590 |
+
|
591 |
+
# 6. Combine results from each rank
|
592 |
+
if rank == 0:
|
593 |
+
print(
|
594 |
+
f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`"
|
595 |
+
)
|
596 |
+
|
597 |
+
# Move files from each rank to common directory
|
598 |
+
for subfolder, extension in [
|
599 |
+
("images", "png"),
|
600 |
+
("image_latents", "pt"),
|
601 |
+
("videos", "mp4"),
|
602 |
+
("video_latents", "pt"),
|
603 |
+
("prompts", "txt"),
|
604 |
+
("prompt_embeds", "pt"),
|
605 |
+
("videos", "txt"),
|
606 |
+
]:
|
607 |
+
tmp_subfolder = tmp_dir.joinpath(subfolder)
|
608 |
+
combined_subfolder = output_dir.joinpath(subfolder)
|
609 |
+
combined_subfolder.mkdir(parents=True, exist_ok=True)
|
610 |
+
pattern = f"*.{extension}"
|
611 |
+
|
612 |
+
for file in tmp_subfolder.rglob(pattern):
|
613 |
+
file.replace(combined_subfolder / file.name)
|
614 |
+
|
615 |
+
# Remove temporary directories
|
616 |
+
def rmdir_recursive(dir: pathlib.Path) -> None:
|
617 |
+
for child in dir.iterdir():
|
618 |
+
if child.is_file():
|
619 |
+
child.unlink()
|
620 |
+
else:
|
621 |
+
rmdir_recursive(child)
|
622 |
+
dir.rmdir()
|
623 |
+
|
624 |
+
rmdir_recursive(tmp_dir)
|
625 |
+
|
626 |
+
# Combine prompts and videos into individual text files and single jsonl
|
627 |
+
prompts_folder = output_dir.joinpath("prompts")
|
628 |
+
prompts = []
|
629 |
+
stems = []
|
630 |
+
|
631 |
+
for filename in prompts_folder.rglob("*.txt"):
|
632 |
+
with open(filename, "r") as file:
|
633 |
+
prompts.append(file.read().strip())
|
634 |
+
stems.append(filename.stem)
|
635 |
+
|
636 |
+
prompts_txt = output_dir.joinpath("prompts.txt")
|
637 |
+
videos_txt = output_dir.joinpath("videos.txt")
|
638 |
+
data_jsonl = output_dir.joinpath("data.jsonl")
|
639 |
+
|
640 |
+
with open(prompts_txt, "w") as file:
|
641 |
+
for prompt in prompts:
|
642 |
+
file.write(f"{prompt}\n")
|
643 |
+
|
644 |
+
with open(videos_txt, "w") as file:
|
645 |
+
for stem in stems:
|
646 |
+
file.write(f"videos/{stem}.mp4\n")
|
647 |
+
|
648 |
+
with open(data_jsonl, "w") as file:
|
649 |
+
for prompt, stem in zip(prompts, stems):
|
650 |
+
video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt")
|
651 |
+
with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file:
|
652 |
+
metadata = json.loads(metadata_file.read())
|
653 |
+
|
654 |
+
data = {
|
655 |
+
"prompt": prompt,
|
656 |
+
"prompt_embed": f"prompt_embeds/{stem}.pt",
|
657 |
+
"image": f"images/{stem}.png",
|
658 |
+
"image_latent": f"image_latents/{stem}.pt",
|
659 |
+
"video": f"videos/{stem}.mp4",
|
660 |
+
"video_latent": f"video_latents/{stem}.pt",
|
661 |
+
"metadata": metadata,
|
662 |
+
}
|
663 |
+
file.write(json.dumps(data) + "\n")
|
664 |
+
|
665 |
+
print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`")
|
666 |
+
|
667 |
+
|
668 |
+
if __name__ == "__main__":
|
669 |
+
main()
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .text_encoder import compute_prompt_embeddings
|
docs/finetrainers-src-codebase/examples/_legacy/training/cogvideox/text_encoder/text_encoder.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
5 |
+
|
6 |
+
|
7 |
+
def _get_t5_prompt_embeds(
|
8 |
+
tokenizer: T5Tokenizer,
|
9 |
+
text_encoder: T5EncoderModel,
|
10 |
+
prompt: Union[str, List[str]],
|
11 |
+
num_videos_per_prompt: int = 1,
|
12 |
+
max_sequence_length: int = 226,
|
13 |
+
device: Optional[torch.device] = None,
|
14 |
+
dtype: Optional[torch.dtype] = None,
|
15 |
+
text_input_ids=None,
|
16 |
+
):
|
17 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
18 |
+
batch_size = len(prompt)
|
19 |
+
|
20 |
+
if tokenizer is not None:
|
21 |
+
text_inputs = tokenizer(
|
22 |
+
prompt,
|
23 |
+
padding="max_length",
|
24 |
+
max_length=max_sequence_length,
|
25 |
+
truncation=True,
|
26 |
+
add_special_tokens=True,
|
27 |
+
return_tensors="pt",
|
28 |
+
)
|
29 |
+
text_input_ids = text_inputs.input_ids
|
30 |
+
else:
|
31 |
+
if text_input_ids is None:
|
32 |
+
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
|
33 |
+
|
34 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
35 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
36 |
+
|
37 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
38 |
+
_, seq_len, _ = prompt_embeds.shape
|
39 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
40 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
41 |
+
|
42 |
+
return prompt_embeds
|
43 |
+
|
44 |
+
|
45 |
+
def encode_prompt(
|
46 |
+
tokenizer: T5Tokenizer,
|
47 |
+
text_encoder: T5EncoderModel,
|
48 |
+
prompt: Union[str, List[str]],
|
49 |
+
num_videos_per_prompt: int = 1,
|
50 |
+
max_sequence_length: int = 226,
|
51 |
+
device: Optional[torch.device] = None,
|
52 |
+
dtype: Optional[torch.dtype] = None,
|
53 |
+
text_input_ids=None,
|
54 |
+
):
|
55 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
56 |
+
prompt_embeds = _get_t5_prompt_embeds(
|
57 |
+
tokenizer,
|
58 |
+
text_encoder,
|
59 |
+
prompt=prompt,
|
60 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
61 |
+
max_sequence_length=max_sequence_length,
|
62 |
+
device=device,
|
63 |
+
dtype=dtype,
|
64 |
+
text_input_ids=text_input_ids,
|
65 |
+
)
|
66 |
+
return prompt_embeds
|
67 |
+
|
68 |
+
|
69 |
+
def compute_prompt_embeddings(
|
70 |
+
tokenizer: T5Tokenizer,
|
71 |
+
text_encoder: T5EncoderModel,
|
72 |
+
prompt: str,
|
73 |
+
max_sequence_length: int,
|
74 |
+
device: torch.device,
|
75 |
+
dtype: torch.dtype,
|
76 |
+
requires_grad: bool = False,
|
77 |
+
):
|
78 |
+
if requires_grad:
|
79 |
+
prompt_embeds = encode_prompt(
|
80 |
+
tokenizer,
|
81 |
+
text_encoder,
|
82 |
+
prompt,
|
83 |
+
num_videos_per_prompt=1,
|
84 |
+
max_sequence_length=max_sequence_length,
|
85 |
+
device=device,
|
86 |
+
dtype=dtype,
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
with torch.no_grad():
|
90 |
+
prompt_embeds = encode_prompt(
|
91 |
+
tokenizer,
|
92 |
+
text_encoder,
|
93 |
+
prompt,
|
94 |
+
num_videos_per_prompt=1,
|
95 |
+
max_sequence_length=max_sequence_length,
|
96 |
+
device=device,
|
97 |
+
dtype=dtype,
|
98 |
+
)
|
99 |
+
return prompt_embeds
|