Spaces:
Runtime error
Runtime error
lllyasviel
commited on
Commit
·
06fccba
0
Parent(s):
- .gitignore +168 -0
- LICENSE +201 -0
- README.md +111 -0
- diffusers_helper/cat_cond.py +24 -0
- diffusers_helper/code_cond.py +34 -0
- diffusers_helper/k_diffusion.py +145 -0
- diffusers_helper/utils.py +136 -0
- diffusers_vdm/attention.py +385 -0
- diffusers_vdm/basics.py +148 -0
- diffusers_vdm/dynamic_tsnr_sampler.py +177 -0
- diffusers_vdm/improved_clip_vision.py +58 -0
- diffusers_vdm/pipeline.py +188 -0
- diffusers_vdm/projection.py +160 -0
- diffusers_vdm/unet.py +650 -0
- diffusers_vdm/utils.py +43 -0
- diffusers_vdm/vae.py +826 -0
- gradio_app.py +321 -0
- imgs/1.jpg +0 -0
- imgs/2.jpg +0 -0
- imgs/3.jpg +0 -0
- memory_management.py +67 -0
- requirements.txt +16 -0
- wd14tagger.py +105 -0
.gitignore
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hf_token.txt
|
2 |
+
hf_download/
|
3 |
+
results/
|
4 |
+
*.csv
|
5 |
+
*.onnx
|
6 |
+
|
7 |
+
# Byte-compiled / optimized / DLL files
|
8 |
+
__pycache__/
|
9 |
+
*.py[cod]
|
10 |
+
*$py.class
|
11 |
+
|
12 |
+
# C extensions
|
13 |
+
*.so
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
share/python-wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
MANIFEST
|
34 |
+
|
35 |
+
# PyInstaller
|
36 |
+
# Usually these files are written by a python script from a template
|
37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
38 |
+
*.manifest
|
39 |
+
*.spec
|
40 |
+
|
41 |
+
# Installer logs
|
42 |
+
pip-log.txt
|
43 |
+
pip-delete-this-directory.txt
|
44 |
+
|
45 |
+
# Unit test / coverage reports
|
46 |
+
htmlcov/
|
47 |
+
.tox/
|
48 |
+
.nox/
|
49 |
+
.coverage
|
50 |
+
.coverage.*
|
51 |
+
.cache
|
52 |
+
nosetests.xml
|
53 |
+
coverage.xml
|
54 |
+
*.cover
|
55 |
+
*.py,cover
|
56 |
+
.hypothesis/
|
57 |
+
.pytest_cache/
|
58 |
+
cover/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
.pybuilder/
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
# For a library or package, you might want to ignore these files since the code is
|
93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
94 |
+
# .python-version
|
95 |
+
|
96 |
+
# pipenv
|
97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
100 |
+
# install all needed dependencies.
|
101 |
+
#Pipfile.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
.idea/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Paints-Undo
|
2 |
+
|
3 |
+
PaintsUndo: A Base Model of Drawing Behaviors in Digital Paintings
|
4 |
+
|
5 |
+
Paints-Undo is a project aimed at providing base models of human drawing behaviors with a hope that future AI models can better align with the real needs of human artists.
|
6 |
+
|
7 |
+
The name "Paints-Undo" is inspired by the similarity that, the model's outputs look like pressing the "undo" button (usually Ctrl+Z) many times in digital painting software.
|
8 |
+
|
9 |
+
Paints-Undo presents a family of models that take an image as input and then output the drawing sequence of that image. The model displays all kinds of human behaviors, including but not limited to sketching, inking, coloring, shading, transforming, left-right flipping, color curve tuning, changing the visibility of layers, and even changing the overall idea during the drawing process.
|
10 |
+
|
11 |
+
**This page does not contain any examples. All examples are in the below Git page:**
|
12 |
+
|
13 |
+
[>>> Click Here to See the Example Page <<<](https://lllyasviel.github.io/pages/paints_undo/)
|
14 |
+
|
15 |
+
# Get Started
|
16 |
+
|
17 |
+
You can deploy PaintsUndo locally via:
|
18 |
+
|
19 |
+
git clone https://github.com/lllyasviel/Paints-UNDO.git
|
20 |
+
cd Paints-UNDO
|
21 |
+
conda create -n paints_undo python=3.10
|
22 |
+
conda activate paints_undo
|
23 |
+
pip install xformers
|
24 |
+
pip install -r requirements.txt
|
25 |
+
python gradio_app.py
|
26 |
+
|
27 |
+
(If you do not know how to use these commands, you can paste those commands to ChatGPT and ask ChatGPT to explain and give more detailed instructions.)
|
28 |
+
|
29 |
+
The inference is tested with 24GB VRAM on Nvidia 4090 and 3090TI. It may also work with 16GB VRAM, but does not work with 8GB. My estimation is that, under extreme optimization (including weight offloading and sliced attention), the theoretical minimal VRAM requirement is about 10~12.5 GB.
|
30 |
+
|
31 |
+
You can expect to process one image in about 5 to 10 minutes, depending on your settings. As a typical result, you will get a video of 25 seconds at FPS 4, with resolution 320x512, or 512x320, or 384x448, or 448x384.
|
32 |
+
|
33 |
+
Because the processing time, in most cases, is significantly longer than most tasks/quota in HuggingFace Space, I personally do not highly recommend to deploy this to HuggingFace Space, to avoid placing an unnecessary burden on the HF servers.
|
34 |
+
|
35 |
+
If you do not have required computation devices and still wants an online solution, one option is to wait us to release a Colab notebook (but I am not sure if Colab free tier will work).
|
36 |
+
|
37 |
+
# Model Notes
|
38 |
+
|
39 |
+
We currently release two models `paints_undo_single_frame` and `paints_undo_multi_frame`. Let's call them single-frame model and multi-frame model.
|
40 |
+
|
41 |
+
The single-frame model takes one image and an `operation step` as input, and outputs one single image. Assuming that an artwork can always be created with 1000 human operations (for example, one brush stroke is one operation), and the `operation step` is an int number from 0 to 999. The number 0 is the finished final artwork, and the number 999 is the first brush stroke drawn on the pure white canvas. You can understand this model as an "undo" (or called Ctrl+Z) model. You input the final image, and indicate how many times you want to "Ctrl+Z", and the model will give you a "simulated" screenshot after those "Ctrl+Z"s are pressed. If your `operation step` is 100, then it means you want to simulate "Ctrl+Z" 100 times on this image get the appearance after the 100-th "Ctrl+Z".
|
42 |
+
|
43 |
+
The multi-frame model takes two images as inputs and output 16 intermediate frames between the two input images. The result is much more consistent than the single-frame model, but also much slower, less "creative", and limited in 16 frames.
|
44 |
+
|
45 |
+
In this repo, the default method is to use them together. We will first infer the single-frame model about 5-7 times to get 5-7 "keyframes", and then we use the multi-frame model to "interpolate" those keyframes to actually generate a relatively long video.
|
46 |
+
|
47 |
+
In theory this system can be used in many ways and even give infinitely long video, but in practice results are good when the final frame count is about 100-500.
|
48 |
+
|
49 |
+
### Model Architecture (paints_undo_single_frame)
|
50 |
+
|
51 |
+
The model is a modified architecture of SD1.5 trained on different betas scheduler, clip skip, and the aforementioned `operation step` condition. To be specific, the model is trained with the betas of:
|
52 |
+
|
53 |
+
`betas = torch.linspace(0.00085, 0.020, 1000, dtype=torch.float64)`
|
54 |
+
|
55 |
+
For comparison, the original SD1.5 is trained with the betas of:
|
56 |
+
|
57 |
+
`betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float64) ** 2`
|
58 |
+
|
59 |
+
You can notice the difference in the ending betas and the removed square. The choice of this scheduler is based on our internal user study.
|
60 |
+
|
61 |
+
The last layer of the text encoder CLIP ViT-L/14 is permanently removed. It is now mathematically consistent to always set CLIP Skip to 2 (if you use diffusers).
|
62 |
+
|
63 |
+
The `operation step` condition is added to layer embeddings in a way similar to SDXL's extra embeddings.
|
64 |
+
|
65 |
+
Also, since the solo purpose of this model is to process existing images, the model is strictly aligned with WD14 tagger without any other augmentations. You should always use WD14 tagger (the one in this repo) to process the input image to get the prompt. Otherwise, the results may be defective. Human-written prompts are not tested.
|
66 |
+
|
67 |
+
### Model Architecture (paints_undo_multi_frame)
|
68 |
+
|
69 |
+
This model is trained by resuming from [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter) family, but the original Crafter's `lvdm` is not used and all training/inference codes are completely implemented from scratch. (BTW, now the codes are based on modern Diffusers.) Although the initial weights are resumed from VideoCrafter, the topology of neural network is modified a lot, and the network behavior is now largely different from original Crafter after extensive training.
|
70 |
+
|
71 |
+
The overall architecture is like Crafter with 5 components, 3D-UNet, VAE, CLIP, CLIP-Vision, Image Projection.
|
72 |
+
|
73 |
+
**VAE**: The VAE is the exactly same anime VAE extracted from [ToonCrafter](https://github.com/ToonCrafter/ToonCrafter). Thanks ToonCrafter a lot for providing the excellent anime temporal VAE for Crafters.
|
74 |
+
|
75 |
+
**3D-UNet**: The 3D-UNet is modified from Crafters's `lvdm` with revisions to attention modules. Other than some minor changes in codes, the major change is that now the UNet are trained and supports temporal windows in Spatial Self Attention layers. You can change the codes in `diffusers_vdm.attention.CrossAttention.temporal_window_for_spatial_self_attention` and `temporal_window_type` to activate three types of attention windows:
|
76 |
+
|
77 |
+
1. "prv" mode: Each frame's Spatial Self-Attention also attend to full spatial contexts of its previous frame. The first frame only attend itself.
|
78 |
+
2. "first": Each frame's Spatial Self-Attention also attend to full spatial contexts of the first frame of the entire sequence. The first frame only attend its self.
|
79 |
+
3. "roll": Each frame's Spatial Self-Attention also attend to full spatial contexts of its previous and next frames, based on the ordering of `torch.roll`.
|
80 |
+
|
81 |
+
Note that this is by default disabled in inference to save GPU memory.
|
82 |
+
|
83 |
+
**CLIP**: The CLIP of SD2.1.
|
84 |
+
|
85 |
+
**CLIP-Vision**: Our implementation of Clip Vision (ViT/H) that supports arbitrary aspect ratios by interpolating the positional embedding. After experimenting with linear interpolation, nearest neighbor, and Rotary Positional Encoding (RoPE), our final choice is nearest neighbor. Note that this is different from Crafter methods that resize or center-crop images to 224x224.
|
86 |
+
|
87 |
+
**Image Projection**: Our implementation of a tiny transformer that takes two frames as inputs and outputs 16 image embeddings for each frame. Note that this is different from Crafter methods that only use one image.
|
88 |
+
|
89 |
+
# Tutorial
|
90 |
+
|
91 |
+
After you get into the Gradio interface:
|
92 |
+
|
93 |
+
Step 0: Upload an image or just click an Example image on the bottom of the page.
|
94 |
+
|
95 |
+
Step 1: In the UI titled "step 1", click generate prompts to get the global prompt.
|
96 |
+
|
97 |
+
Step 2: In the UI titled "step 2", click "Generate Key Frames". You can change seeds or other parameters on the left.
|
98 |
+
|
99 |
+
Step 3: In the UI titled "step 3", click "Generate Video". You can change seeds or other parameters on the left.
|
100 |
+
|
101 |
+
# Cite
|
102 |
+
|
103 |
+
@Misc{paintsundo,
|
104 |
+
author = {Paints-Undo Team},
|
105 |
+
title = {Paints-Undo GitHub Page},
|
106 |
+
year = {2024},
|
107 |
+
}
|
108 |
+
|
109 |
+
# Disclaimer
|
110 |
+
|
111 |
+
This project aims to develop base models of human drawing behaviors, facilitating future AI systems to better meet the real needs of human artists. Users are granted the freedom to create content using this tool, but they are expected to comply with local laws and use it responsibly. Users must not employ the tool to generate false information or incite confrontation. The developers do not assume any responsibility for potential misuse by users.
|
diffusers_helper/cat_cond.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def unet_add_concat_conds(unet, new_channels=4):
|
5 |
+
with torch.no_grad():
|
6 |
+
new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
7 |
+
new_conv_in.weight.zero_()
|
8 |
+
new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
|
9 |
+
new_conv_in.bias = unet.conv_in.bias
|
10 |
+
unet.conv_in = new_conv_in
|
11 |
+
|
12 |
+
unet_original_forward = unet.forward
|
13 |
+
|
14 |
+
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
15 |
+
cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
|
16 |
+
c_concat = cross_attention_kwargs.pop('concat_conds')
|
17 |
+
kwargs['cross_attention_kwargs'] = cross_attention_kwargs
|
18 |
+
|
19 |
+
c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
|
20 |
+
new_sample = torch.cat([sample, c_concat], dim=1)
|
21 |
+
return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
|
22 |
+
|
23 |
+
unet.forward = hooked_unet_forward
|
24 |
+
return
|
diffusers_helper/code_cond.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
4 |
+
|
5 |
+
|
6 |
+
def unet_add_coded_conds(unet, added_number_count=1):
|
7 |
+
unet.add_time_proj = Timesteps(256, True, 0)
|
8 |
+
unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280)
|
9 |
+
|
10 |
+
def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
|
11 |
+
coded_conds = added_cond_kwargs.get("coded_conds")
|
12 |
+
batch_size = coded_conds.shape[0]
|
13 |
+
time_embeds = unet.add_time_proj(coded_conds.flatten())
|
14 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
15 |
+
time_embeds = time_embeds.to(emb)
|
16 |
+
aug_emb = unet.add_embedding(time_embeds)
|
17 |
+
return aug_emb
|
18 |
+
|
19 |
+
unet.get_aug_embed = get_aug_embed
|
20 |
+
|
21 |
+
unet_original_forward = unet.forward
|
22 |
+
|
23 |
+
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
24 |
+
cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
|
25 |
+
coded_conds = cross_attention_kwargs.pop('coded_conds')
|
26 |
+
kwargs['cross_attention_kwargs'] = cross_attention_kwargs
|
27 |
+
|
28 |
+
coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device)
|
29 |
+
kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds)
|
30 |
+
return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)
|
31 |
+
|
32 |
+
unet.forward = hooked_unet_forward
|
33 |
+
|
34 |
+
return
|
diffusers_helper/k_diffusion.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
@torch.no_grad()
|
8 |
+
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None):
|
9 |
+
"""DPM-Solver++(2M)."""
|
10 |
+
extra_args = {} if extra_args is None else extra_args
|
11 |
+
s_in = x.new_ones([x.shape[0]])
|
12 |
+
sigma_fn = lambda t: t.neg().exp()
|
13 |
+
t_fn = lambda sigma: sigma.log().neg()
|
14 |
+
old_denoised = None
|
15 |
+
|
16 |
+
bar = tqdm if progress_tqdm is None else progress_tqdm
|
17 |
+
|
18 |
+
for i in bar(range(len(sigmas) - 1)):
|
19 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
20 |
+
if callback is not None:
|
21 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
22 |
+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
23 |
+
h = t_next - t
|
24 |
+
if old_denoised is None or sigmas[i + 1] == 0:
|
25 |
+
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
26 |
+
else:
|
27 |
+
h_last = t - t_fn(sigmas[i - 1])
|
28 |
+
r = h_last / h
|
29 |
+
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
30 |
+
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
31 |
+
old_denoised = denoised
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class KModel:
|
36 |
+
def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False):
|
37 |
+
if linear:
|
38 |
+
betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64)
|
39 |
+
else:
|
40 |
+
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
|
41 |
+
|
42 |
+
alphas = 1. - betas
|
43 |
+
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
44 |
+
|
45 |
+
self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
46 |
+
self.log_sigmas = self.sigmas.log()
|
47 |
+
self.sigma_data = 1.0
|
48 |
+
self.unet = unet
|
49 |
+
return
|
50 |
+
|
51 |
+
@property
|
52 |
+
def sigma_min(self):
|
53 |
+
return self.sigmas[0]
|
54 |
+
|
55 |
+
@property
|
56 |
+
def sigma_max(self):
|
57 |
+
return self.sigmas[-1]
|
58 |
+
|
59 |
+
def timestep(self, sigma):
|
60 |
+
log_sigma = sigma.log()
|
61 |
+
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
62 |
+
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
|
63 |
+
|
64 |
+
def get_sigmas_karras(self, n, rho=7.):
|
65 |
+
ramp = torch.linspace(0, 1, n)
|
66 |
+
min_inv_rho = self.sigma_min ** (1 / rho)
|
67 |
+
max_inv_rho = self.sigma_max ** (1 / rho)
|
68 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
69 |
+
return torch.cat([sigmas, sigmas.new_zeros([1])])
|
70 |
+
|
71 |
+
def __call__(self, x, sigma, **extra_args):
|
72 |
+
x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
|
73 |
+
x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype)
|
74 |
+
t = self.timestep(sigma)
|
75 |
+
cfg_scale = extra_args['cfg_scale']
|
76 |
+
eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
|
77 |
+
eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
|
78 |
+
noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
|
79 |
+
return x - noise_pred * sigma[:, None, None, None]
|
80 |
+
|
81 |
+
|
82 |
+
class KDiffusionSampler:
|
83 |
+
def __init__(self, unet, **kwargs):
|
84 |
+
self.unet = unet
|
85 |
+
self.k_model = KModel(unet=unet, **kwargs)
|
86 |
+
|
87 |
+
@torch.inference_mode()
|
88 |
+
def __call__(
|
89 |
+
self,
|
90 |
+
initial_latent = None,
|
91 |
+
strength = 1.0,
|
92 |
+
num_inference_steps = 25,
|
93 |
+
guidance_scale = 5.0,
|
94 |
+
batch_size = 1,
|
95 |
+
generator = None,
|
96 |
+
prompt_embeds = None,
|
97 |
+
negative_prompt_embeds = None,
|
98 |
+
cross_attention_kwargs = None,
|
99 |
+
same_noise_in_batch = False,
|
100 |
+
progress_tqdm = None,
|
101 |
+
):
|
102 |
+
|
103 |
+
device = self.unet.device
|
104 |
+
|
105 |
+
# Sigmas
|
106 |
+
|
107 |
+
sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength))
|
108 |
+
sigmas = sigmas[-(num_inference_steps + 1):].to(device)
|
109 |
+
|
110 |
+
# Initial latents
|
111 |
+
|
112 |
+
if same_noise_in_batch:
|
113 |
+
noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1)
|
114 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
|
115 |
+
else:
|
116 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
|
117 |
+
noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)
|
118 |
+
|
119 |
+
latents = initial_latent + noise * sigmas[0].to(initial_latent)
|
120 |
+
|
121 |
+
# Batch
|
122 |
+
|
123 |
+
latents = latents.to(device)
|
124 |
+
prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
|
125 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
|
126 |
+
|
127 |
+
# Feeds
|
128 |
+
|
129 |
+
sampler_kwargs = dict(
|
130 |
+
cfg_scale=guidance_scale,
|
131 |
+
positive=dict(
|
132 |
+
encoder_hidden_states=prompt_embeds,
|
133 |
+
cross_attention_kwargs=cross_attention_kwargs
|
134 |
+
),
|
135 |
+
negative=dict(
|
136 |
+
encoder_hidden_states=negative_prompt_embeds,
|
137 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
|
141 |
+
# Sample
|
142 |
+
|
143 |
+
results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
|
144 |
+
|
145 |
+
return results
|
diffusers_helper/utils.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
import einops
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
import safetensors.torch as sf
|
10 |
+
|
11 |
+
|
12 |
+
def write_to_json(data, file_path):
|
13 |
+
temp_file_path = file_path + ".tmp"
|
14 |
+
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
|
15 |
+
json.dump(data, temp_file, indent=4)
|
16 |
+
os.replace(temp_file_path, file_path)
|
17 |
+
return
|
18 |
+
|
19 |
+
|
20 |
+
def read_from_json(file_path):
|
21 |
+
with open(file_path, 'rt', encoding='utf-8') as file:
|
22 |
+
data = json.load(file)
|
23 |
+
return data
|
24 |
+
|
25 |
+
|
26 |
+
def get_active_parameters(m):
|
27 |
+
return {k:v for k, v in m.named_parameters() if v.requires_grad}
|
28 |
+
|
29 |
+
|
30 |
+
def cast_training_params(m, dtype=torch.float32):
|
31 |
+
for param in m.parameters():
|
32 |
+
if param.requires_grad:
|
33 |
+
param.data = param.to(dtype)
|
34 |
+
return
|
35 |
+
|
36 |
+
|
37 |
+
def set_attr_recursive(obj, attr, value):
|
38 |
+
attrs = attr.split(".")
|
39 |
+
for name in attrs[:-1]:
|
40 |
+
obj = getattr(obj, name)
|
41 |
+
setattr(obj, attrs[-1], value)
|
42 |
+
return
|
43 |
+
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
|
47 |
+
assert a.shape == b.shape, "Tensors must have the same shape"
|
48 |
+
batch_size = a.size(0)
|
49 |
+
|
50 |
+
if mask_a is None:
|
51 |
+
mask_a = torch.rand(batch_size) < probability_a
|
52 |
+
|
53 |
+
mask_a = mask_a.to(a.device)
|
54 |
+
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
|
55 |
+
result = torch.where(mask_a, a, b)
|
56 |
+
return result
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def zero_module(module):
|
61 |
+
for p in module.parameters():
|
62 |
+
p.detach().zero_()
|
63 |
+
return module
|
64 |
+
|
65 |
+
|
66 |
+
def load_last_state(model, folder='accelerator_output'):
|
67 |
+
file_pattern = os.path.join(folder, '**', 'model.safetensors')
|
68 |
+
files = glob.glob(file_pattern, recursive=True)
|
69 |
+
|
70 |
+
if not files:
|
71 |
+
print("No model.safetensors files found in the specified folder.")
|
72 |
+
return
|
73 |
+
|
74 |
+
newest_file = max(files, key=os.path.getmtime)
|
75 |
+
state_dict = sf.load_file(newest_file)
|
76 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
77 |
+
|
78 |
+
if missing_keys:
|
79 |
+
print("Missing keys:", missing_keys)
|
80 |
+
if unexpected_keys:
|
81 |
+
print("Unexpected keys:", unexpected_keys)
|
82 |
+
|
83 |
+
print("Loaded model state from:", newest_file)
|
84 |
+
return
|
85 |
+
|
86 |
+
|
87 |
+
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
|
88 |
+
tags = tags_str.split(', ')
|
89 |
+
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
|
90 |
+
prompt = ', '.join(tags)
|
91 |
+
return prompt
|
92 |
+
|
93 |
+
|
94 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10):
|
95 |
+
b, c, t, h, w = x.shape
|
96 |
+
|
97 |
+
per_row = b
|
98 |
+
for p in [6, 5, 4, 3, 2]:
|
99 |
+
if b % p == 0:
|
100 |
+
per_row = p
|
101 |
+
break
|
102 |
+
|
103 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
104 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
105 |
+
x = x.detach().cpu().to(torch.uint8)
|
106 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
107 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
def save_bcthw_as_png(x, output_filename):
|
112 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
113 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
114 |
+
x = x.detach().cpu().to(torch.uint8)
|
115 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
116 |
+
torchvision.io.write_png(x, output_filename)
|
117 |
+
return output_filename
|
118 |
+
|
119 |
+
|
120 |
+
def add_tensors_with_padding(tensor1, tensor2):
|
121 |
+
if tensor1.shape == tensor2.shape:
|
122 |
+
return tensor1 + tensor2
|
123 |
+
|
124 |
+
shape1 = tensor1.shape
|
125 |
+
shape2 = tensor2.shape
|
126 |
+
|
127 |
+
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
|
128 |
+
|
129 |
+
padded_tensor1 = torch.zeros(new_shape)
|
130 |
+
padded_tensor2 = torch.zeros(new_shape)
|
131 |
+
|
132 |
+
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
|
133 |
+
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
|
134 |
+
|
135 |
+
result = padded_tensor1 + padded_tensor2
|
136 |
+
return result
|
diffusers_vdm/attention.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import xformers.ops
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from functools import partial
|
8 |
+
from diffusers_vdm.basics import zero_module, checkpoint, default, make_temporal_window
|
9 |
+
|
10 |
+
|
11 |
+
def sdp(q, k, v, heads):
|
12 |
+
b, _, C = q.shape
|
13 |
+
dim_head = C // heads
|
14 |
+
|
15 |
+
q, k, v = map(
|
16 |
+
lambda t: t.unsqueeze(3)
|
17 |
+
.reshape(b, t.shape[1], heads, dim_head)
|
18 |
+
.permute(0, 2, 1, 3)
|
19 |
+
.reshape(b * heads, t.shape[1], dim_head)
|
20 |
+
.contiguous(),
|
21 |
+
(q, k, v),
|
22 |
+
)
|
23 |
+
|
24 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
25 |
+
|
26 |
+
out = (
|
27 |
+
out.unsqueeze(0)
|
28 |
+
.reshape(b, heads, out.shape[1], dim_head)
|
29 |
+
.permute(0, 2, 1, 3)
|
30 |
+
.reshape(b, out.shape[1], heads * dim_head)
|
31 |
+
)
|
32 |
+
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
class RelativePosition(nn.Module):
|
37 |
+
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
|
38 |
+
|
39 |
+
def __init__(self, num_units, max_relative_position):
|
40 |
+
super().__init__()
|
41 |
+
self.num_units = num_units
|
42 |
+
self.max_relative_position = max_relative_position
|
43 |
+
self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
|
44 |
+
nn.init.xavier_uniform_(self.embeddings_table)
|
45 |
+
|
46 |
+
def forward(self, length_q, length_k):
|
47 |
+
device = self.embeddings_table.device
|
48 |
+
range_vec_q = torch.arange(length_q, device=device)
|
49 |
+
range_vec_k = torch.arange(length_k, device=device)
|
50 |
+
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
|
51 |
+
distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
|
52 |
+
final_mat = distance_mat_clipped + self.max_relative_position
|
53 |
+
final_mat = final_mat.long()
|
54 |
+
embeddings = self.embeddings_table[final_mat]
|
55 |
+
return embeddings
|
56 |
+
|
57 |
+
|
58 |
+
class CrossAttention(nn.Module):
|
59 |
+
|
60 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
|
61 |
+
relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False,
|
62 |
+
image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False,
|
63 |
+
text_context_len=77, temporal_window_for_spatial_self_attention=False):
|
64 |
+
super().__init__()
|
65 |
+
inner_dim = dim_head * heads
|
66 |
+
context_dim = default(context_dim, query_dim)
|
67 |
+
|
68 |
+
self.scale = dim_head**-0.5
|
69 |
+
self.heads = heads
|
70 |
+
self.dim_head = dim_head
|
71 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
72 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
73 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
74 |
+
|
75 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
76 |
+
|
77 |
+
self.is_temporal_attention = temporal_length is not None
|
78 |
+
|
79 |
+
self.relative_position = relative_position
|
80 |
+
if self.relative_position:
|
81 |
+
assert self.is_temporal_attention
|
82 |
+
self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
|
83 |
+
self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
|
84 |
+
|
85 |
+
self.video_length = video_length
|
86 |
+
self.temporal_window_for_spatial_self_attention = temporal_window_for_spatial_self_attention
|
87 |
+
self.temporal_window_type = 'prv'
|
88 |
+
|
89 |
+
self.image_cross_attention = image_cross_attention
|
90 |
+
self.image_cross_attention_scale = image_cross_attention_scale
|
91 |
+
self.text_context_len = text_context_len
|
92 |
+
self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
|
93 |
+
if self.image_cross_attention:
|
94 |
+
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
95 |
+
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
96 |
+
if image_cross_attention_scale_learnable:
|
97 |
+
self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) )
|
98 |
+
|
99 |
+
def forward(self, x, context=None, mask=None):
|
100 |
+
if self.is_temporal_attention:
|
101 |
+
return self.temporal_forward(x, context=context, mask=mask)
|
102 |
+
else:
|
103 |
+
return self.spatial_forward(x, context=context, mask=mask)
|
104 |
+
|
105 |
+
def temporal_forward(self, x, context=None, mask=None):
|
106 |
+
assert mask is None, 'Attention mask not implemented!'
|
107 |
+
assert context is None, 'Temporal attention only supports self attention!'
|
108 |
+
|
109 |
+
q = self.to_q(x)
|
110 |
+
k = self.to_k(x)
|
111 |
+
v = self.to_v(x)
|
112 |
+
|
113 |
+
out = sdp(q, k, v, self.heads)
|
114 |
+
|
115 |
+
return self.to_out(out)
|
116 |
+
|
117 |
+
def spatial_forward(self, x, context=None, mask=None):
|
118 |
+
assert mask is None, 'Attention mask not implemented!'
|
119 |
+
|
120 |
+
spatial_self_attn = (context is None)
|
121 |
+
k_ip, v_ip, out_ip = None, None, None
|
122 |
+
|
123 |
+
q = self.to_q(x)
|
124 |
+
context = default(context, x)
|
125 |
+
|
126 |
+
if spatial_self_attn:
|
127 |
+
k = self.to_k(context)
|
128 |
+
v = self.to_v(context)
|
129 |
+
|
130 |
+
if self.temporal_window_for_spatial_self_attention:
|
131 |
+
k = make_temporal_window(k, t=self.video_length, method=self.temporal_window_type)
|
132 |
+
v = make_temporal_window(v, t=self.video_length, method=self.temporal_window_type)
|
133 |
+
elif self.image_cross_attention:
|
134 |
+
context, context_image = context
|
135 |
+
k = self.to_k(context)
|
136 |
+
v = self.to_v(context)
|
137 |
+
k_ip = self.to_k_ip(context_image)
|
138 |
+
v_ip = self.to_v_ip(context_image)
|
139 |
+
else:
|
140 |
+
raise NotImplementedError('Traditional prompt-only attention without IP-Adapter is illegal now.')
|
141 |
+
|
142 |
+
out = sdp(q, k, v, self.heads)
|
143 |
+
|
144 |
+
if k_ip is not None:
|
145 |
+
out_ip = sdp(q, k_ip, v_ip, self.heads)
|
146 |
+
|
147 |
+
if self.image_cross_attention_scale_learnable:
|
148 |
+
out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha) + 1)
|
149 |
+
else:
|
150 |
+
out = out + self.image_cross_attention_scale * out_ip
|
151 |
+
|
152 |
+
return self.to_out(out)
|
153 |
+
|
154 |
+
|
155 |
+
class BasicTransformerBlock(nn.Module):
|
156 |
+
|
157 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
158 |
+
disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77):
|
159 |
+
super().__init__()
|
160 |
+
attn_cls = CrossAttention if attention_cls is None else attention_cls
|
161 |
+
self.disable_self_attn = disable_self_attn
|
162 |
+
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
163 |
+
context_dim=context_dim if self.disable_self_attn else None, video_length=video_length)
|
164 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
165 |
+
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len)
|
166 |
+
self.image_cross_attention = image_cross_attention
|
167 |
+
|
168 |
+
self.norm1 = nn.LayerNorm(dim)
|
169 |
+
self.norm2 = nn.LayerNorm(dim)
|
170 |
+
self.norm3 = nn.LayerNorm(dim)
|
171 |
+
self.checkpoint = checkpoint
|
172 |
+
|
173 |
+
|
174 |
+
def forward(self, x, context=None, mask=None, **kwargs):
|
175 |
+
## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
|
176 |
+
input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
|
177 |
+
if context is not None:
|
178 |
+
input_tuple = (x, context)
|
179 |
+
if mask is not None:
|
180 |
+
forward_mask = partial(self._forward, mask=mask)
|
181 |
+
return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
|
182 |
+
return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint)
|
183 |
+
|
184 |
+
|
185 |
+
def _forward(self, x, context=None, mask=None):
|
186 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
|
187 |
+
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
|
188 |
+
x = self.ff(self.norm3(x)) + x
|
189 |
+
return x
|
190 |
+
|
191 |
+
|
192 |
+
class SpatialTransformer(nn.Module):
|
193 |
+
"""
|
194 |
+
Transformer block for image-like data in spatial axis.
|
195 |
+
First, project the input (aka embedding)
|
196 |
+
and reshape to b, t, d.
|
197 |
+
Then apply standard transformer action.
|
198 |
+
Finally, reshape to image
|
199 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
200 |
+
"""
|
201 |
+
|
202 |
+
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
|
203 |
+
use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None,
|
204 |
+
image_cross_attention=False, image_cross_attention_scale_learnable=False):
|
205 |
+
super().__init__()
|
206 |
+
self.in_channels = in_channels
|
207 |
+
inner_dim = n_heads * d_head
|
208 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
209 |
+
if not use_linear:
|
210 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
211 |
+
else:
|
212 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
213 |
+
|
214 |
+
attention_cls = None
|
215 |
+
self.transformer_blocks = nn.ModuleList([
|
216 |
+
BasicTransformerBlock(
|
217 |
+
inner_dim,
|
218 |
+
n_heads,
|
219 |
+
d_head,
|
220 |
+
dropout=dropout,
|
221 |
+
context_dim=context_dim,
|
222 |
+
disable_self_attn=disable_self_attn,
|
223 |
+
checkpoint=use_checkpoint,
|
224 |
+
attention_cls=attention_cls,
|
225 |
+
video_length=video_length,
|
226 |
+
image_cross_attention=image_cross_attention,
|
227 |
+
image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
|
228 |
+
) for d in range(depth)
|
229 |
+
])
|
230 |
+
if not use_linear:
|
231 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
232 |
+
else:
|
233 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
234 |
+
self.use_linear = use_linear
|
235 |
+
|
236 |
+
|
237 |
+
def forward(self, x, context=None, **kwargs):
|
238 |
+
b, c, h, w = x.shape
|
239 |
+
x_in = x
|
240 |
+
x = self.norm(x)
|
241 |
+
if not self.use_linear:
|
242 |
+
x = self.proj_in(x)
|
243 |
+
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
244 |
+
if self.use_linear:
|
245 |
+
x = self.proj_in(x)
|
246 |
+
for i, block in enumerate(self.transformer_blocks):
|
247 |
+
x = block(x, context=context, **kwargs)
|
248 |
+
if self.use_linear:
|
249 |
+
x = self.proj_out(x)
|
250 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
251 |
+
if not self.use_linear:
|
252 |
+
x = self.proj_out(x)
|
253 |
+
return x + x_in
|
254 |
+
|
255 |
+
|
256 |
+
class TemporalTransformer(nn.Module):
|
257 |
+
"""
|
258 |
+
Transformer block for image-like data in temporal axis.
|
259 |
+
First, reshape to b, t, d.
|
260 |
+
Then apply standard transformer action.
|
261 |
+
Finally, reshape to image
|
262 |
+
"""
|
263 |
+
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
|
264 |
+
use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1,
|
265 |
+
relative_position=False, temporal_length=None):
|
266 |
+
super().__init__()
|
267 |
+
self.only_self_att = only_self_att
|
268 |
+
self.relative_position = relative_position
|
269 |
+
self.causal_attention = causal_attention
|
270 |
+
self.causal_block_size = causal_block_size
|
271 |
+
|
272 |
+
self.in_channels = in_channels
|
273 |
+
inner_dim = n_heads * d_head
|
274 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
275 |
+
self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
276 |
+
if not use_linear:
|
277 |
+
self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
278 |
+
else:
|
279 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
280 |
+
|
281 |
+
if relative_position:
|
282 |
+
assert(temporal_length is not None)
|
283 |
+
attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
|
284 |
+
else:
|
285 |
+
attention_cls = partial(CrossAttention, temporal_length=temporal_length)
|
286 |
+
if self.causal_attention:
|
287 |
+
assert(temporal_length is not None)
|
288 |
+
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
|
289 |
+
|
290 |
+
if self.only_self_att:
|
291 |
+
context_dim = None
|
292 |
+
self.transformer_blocks = nn.ModuleList([
|
293 |
+
BasicTransformerBlock(
|
294 |
+
inner_dim,
|
295 |
+
n_heads,
|
296 |
+
d_head,
|
297 |
+
dropout=dropout,
|
298 |
+
context_dim=context_dim,
|
299 |
+
attention_cls=attention_cls,
|
300 |
+
checkpoint=use_checkpoint) for d in range(depth)
|
301 |
+
])
|
302 |
+
if not use_linear:
|
303 |
+
self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
304 |
+
else:
|
305 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
306 |
+
self.use_linear = use_linear
|
307 |
+
|
308 |
+
def forward(self, x, context=None):
|
309 |
+
b, c, t, h, w = x.shape
|
310 |
+
x_in = x
|
311 |
+
x = self.norm(x)
|
312 |
+
x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
|
313 |
+
if not self.use_linear:
|
314 |
+
x = self.proj_in(x)
|
315 |
+
x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
|
316 |
+
if self.use_linear:
|
317 |
+
x = self.proj_in(x)
|
318 |
+
|
319 |
+
temp_mask = None
|
320 |
+
if self.causal_attention:
|
321 |
+
# slice the from mask map
|
322 |
+
temp_mask = self.mask[:,:t,:t].to(x.device)
|
323 |
+
|
324 |
+
if temp_mask is not None:
|
325 |
+
mask = temp_mask.to(x.device)
|
326 |
+
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
|
327 |
+
else:
|
328 |
+
mask = None
|
329 |
+
|
330 |
+
if self.only_self_att:
|
331 |
+
## note: if no context is given, cross-attention defaults to self-attention
|
332 |
+
for i, block in enumerate(self.transformer_blocks):
|
333 |
+
x = block(x, mask=mask)
|
334 |
+
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
335 |
+
else:
|
336 |
+
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
337 |
+
context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
|
338 |
+
for i, block in enumerate(self.transformer_blocks):
|
339 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
340 |
+
for j in range(b):
|
341 |
+
context_j = repeat(
|
342 |
+
context[j],
|
343 |
+
't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
|
344 |
+
## note: causal mask will not applied in cross-attention case
|
345 |
+
x[j] = block(x[j], context=context_j)
|
346 |
+
|
347 |
+
if self.use_linear:
|
348 |
+
x = self.proj_out(x)
|
349 |
+
x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
|
350 |
+
if not self.use_linear:
|
351 |
+
x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
|
352 |
+
x = self.proj_out(x)
|
353 |
+
x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
|
354 |
+
|
355 |
+
return x + x_in
|
356 |
+
|
357 |
+
|
358 |
+
class GEGLU(nn.Module):
|
359 |
+
def __init__(self, dim_in, dim_out):
|
360 |
+
super().__init__()
|
361 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
362 |
+
|
363 |
+
def forward(self, x):
|
364 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
365 |
+
return x * F.gelu(gate)
|
366 |
+
|
367 |
+
|
368 |
+
class FeedForward(nn.Module):
|
369 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
370 |
+
super().__init__()
|
371 |
+
inner_dim = int(dim * mult)
|
372 |
+
dim_out = default(dim_out, dim)
|
373 |
+
project_in = nn.Sequential(
|
374 |
+
nn.Linear(dim, inner_dim),
|
375 |
+
nn.GELU()
|
376 |
+
) if not glu else GEGLU(dim, inner_dim)
|
377 |
+
|
378 |
+
self.net = nn.Sequential(
|
379 |
+
project_in,
|
380 |
+
nn.Dropout(dropout),
|
381 |
+
nn.Linear(inner_dim, dim_out)
|
382 |
+
)
|
383 |
+
|
384 |
+
def forward(self, x):
|
385 |
+
return self.net(x)
|
diffusers_vdm/basics.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import einops
|
14 |
+
|
15 |
+
from inspect import isfunction
|
16 |
+
|
17 |
+
|
18 |
+
def zero_module(module):
|
19 |
+
"""
|
20 |
+
Zero out the parameters of a module and return it.
|
21 |
+
"""
|
22 |
+
for p in module.parameters():
|
23 |
+
p.detach().zero_()
|
24 |
+
return module
|
25 |
+
|
26 |
+
def scale_module(module, scale):
|
27 |
+
"""
|
28 |
+
Scale the parameters of a module and return it.
|
29 |
+
"""
|
30 |
+
for p in module.parameters():
|
31 |
+
p.detach().mul_(scale)
|
32 |
+
return module
|
33 |
+
|
34 |
+
|
35 |
+
def conv_nd(dims, *args, **kwargs):
|
36 |
+
"""
|
37 |
+
Create a 1D, 2D, or 3D convolution module.
|
38 |
+
"""
|
39 |
+
if dims == 1:
|
40 |
+
return nn.Conv1d(*args, **kwargs)
|
41 |
+
elif dims == 2:
|
42 |
+
return nn.Conv2d(*args, **kwargs)
|
43 |
+
elif dims == 3:
|
44 |
+
return nn.Conv3d(*args, **kwargs)
|
45 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
46 |
+
|
47 |
+
|
48 |
+
def linear(*args, **kwargs):
|
49 |
+
"""
|
50 |
+
Create a linear module.
|
51 |
+
"""
|
52 |
+
return nn.Linear(*args, **kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
56 |
+
"""
|
57 |
+
Create a 1D, 2D, or 3D average pooling module.
|
58 |
+
"""
|
59 |
+
if dims == 1:
|
60 |
+
return nn.AvgPool1d(*args, **kwargs)
|
61 |
+
elif dims == 2:
|
62 |
+
return nn.AvgPool2d(*args, **kwargs)
|
63 |
+
elif dims == 3:
|
64 |
+
return nn.AvgPool3d(*args, **kwargs)
|
65 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
66 |
+
|
67 |
+
|
68 |
+
def nonlinearity(type='silu'):
|
69 |
+
if type == 'silu':
|
70 |
+
return nn.SiLU()
|
71 |
+
elif type == 'leaky_relu':
|
72 |
+
return nn.LeakyReLU()
|
73 |
+
|
74 |
+
|
75 |
+
def normalization(channels, num_groups=32):
|
76 |
+
"""
|
77 |
+
Make a standard normalization layer.
|
78 |
+
:param channels: number of input channels.
|
79 |
+
:return: an nn.Module for normalization.
|
80 |
+
"""
|
81 |
+
return nn.GroupNorm(num_groups, channels)
|
82 |
+
|
83 |
+
|
84 |
+
def default(val, d):
|
85 |
+
if exists(val):
|
86 |
+
return val
|
87 |
+
return d() if isfunction(d) else d
|
88 |
+
|
89 |
+
|
90 |
+
def exists(val):
|
91 |
+
return val is not None
|
92 |
+
|
93 |
+
|
94 |
+
def extract_into_tensor(a, t, x_shape):
|
95 |
+
b, *_ = t.shape
|
96 |
+
out = a.gather(-1, t)
|
97 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
98 |
+
|
99 |
+
|
100 |
+
def make_temporal_window(x, t, method):
|
101 |
+
assert method in ['roll', 'prv', 'first']
|
102 |
+
|
103 |
+
if method == 'roll':
|
104 |
+
m = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
105 |
+
l = torch.roll(m, shifts=1, dims=1)
|
106 |
+
r = torch.roll(m, shifts=-1, dims=1)
|
107 |
+
|
108 |
+
recon = torch.cat([l, m, r], dim=2)
|
109 |
+
del l, m, r
|
110 |
+
|
111 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
112 |
+
return recon
|
113 |
+
|
114 |
+
if method == 'prv':
|
115 |
+
x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
116 |
+
prv = torch.cat([x[:, :1], x[:, :-1]], dim=1)
|
117 |
+
|
118 |
+
recon = torch.cat([x, prv], dim=2)
|
119 |
+
del x, prv
|
120 |
+
|
121 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
122 |
+
return recon
|
123 |
+
|
124 |
+
if method == 'first':
|
125 |
+
x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
126 |
+
prv = x[:, [0], :, :].repeat(1, t, 1, 1)
|
127 |
+
|
128 |
+
recon = torch.cat([x, prv], dim=2)
|
129 |
+
del x, prv
|
130 |
+
|
131 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
132 |
+
return recon
|
133 |
+
|
134 |
+
|
135 |
+
def checkpoint(func, inputs, params, flag):
|
136 |
+
"""
|
137 |
+
Evaluate a function without caching intermediate activations, allowing for
|
138 |
+
reduced memory at the expense of extra compute in the backward pass.
|
139 |
+
:param func: the function to evaluate.
|
140 |
+
:param inputs: the argument sequence to pass to `func`.
|
141 |
+
:param params: a sequence of parameters `func` depends on but does not
|
142 |
+
explicitly take as arguments.
|
143 |
+
:param flag: if False, disable gradient checkpointing.
|
144 |
+
"""
|
145 |
+
if flag:
|
146 |
+
return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
|
147 |
+
else:
|
148 |
+
return func(*inputs)
|
diffusers_vdm/dynamic_tsnr_sampler.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# everything that can improve v-prediction model
|
2 |
+
# dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
|
3 |
+
# written by lvmin at stanford 2024
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
from functools import partial
|
10 |
+
from diffusers_vdm.basics import extract_into_tensor
|
11 |
+
|
12 |
+
|
13 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
14 |
+
|
15 |
+
|
16 |
+
def rescale_zero_terminal_snr(betas):
|
17 |
+
# Convert betas to alphas_bar_sqrt
|
18 |
+
alphas = 1.0 - betas
|
19 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
20 |
+
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
|
21 |
+
|
22 |
+
# Store old values.
|
23 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
|
24 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
|
25 |
+
|
26 |
+
# Shift so the last timestep is zero.
|
27 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
28 |
+
|
29 |
+
# Scale so the first timestep is back to the old value.
|
30 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
31 |
+
|
32 |
+
# Convert alphas_bar_sqrt to betas
|
33 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
34 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
35 |
+
alphas = np.concatenate([alphas_bar[0:1], alphas])
|
36 |
+
betas = 1 - alphas
|
37 |
+
|
38 |
+
return betas
|
39 |
+
|
40 |
+
|
41 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
42 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
43 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
44 |
+
|
45 |
+
# rescale the results from guidance (fixes overexposure)
|
46 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
47 |
+
|
48 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
49 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
50 |
+
|
51 |
+
return noise_cfg
|
52 |
+
|
53 |
+
|
54 |
+
class SamplerDynamicTSNR(torch.nn.Module):
|
55 |
+
@torch.no_grad()
|
56 |
+
def __init__(self, unet, terminal_scale=0.7):
|
57 |
+
super().__init__()
|
58 |
+
self.unet = unet
|
59 |
+
|
60 |
+
self.is_v = True
|
61 |
+
self.n_timestep = 1000
|
62 |
+
self.guidance_rescale = 0.7
|
63 |
+
|
64 |
+
linear_start = 0.00085
|
65 |
+
linear_end = 0.012
|
66 |
+
|
67 |
+
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
|
68 |
+
betas = rescale_zero_terminal_snr(betas)
|
69 |
+
alphas = 1. - betas
|
70 |
+
|
71 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
72 |
+
|
73 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
|
74 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
|
75 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))
|
76 |
+
|
77 |
+
# Dynamic TSNR
|
78 |
+
turning_step = 400
|
79 |
+
scale_arr = np.concatenate([
|
80 |
+
np.linspace(1.0, terminal_scale, turning_step),
|
81 |
+
np.full(self.n_timestep - turning_step, terminal_scale)
|
82 |
+
])
|
83 |
+
self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))
|
84 |
+
|
85 |
+
def predict_eps_from_z_and_v(self, x_t, t, v):
|
86 |
+
return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t
|
87 |
+
|
88 |
+
def predict_start_from_z_and_v(self, x_t, t, v):
|
89 |
+
return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v
|
90 |
+
|
91 |
+
def q_sample(self, x0, t, noise):
|
92 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
93 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
94 |
+
|
95 |
+
def get_v(self, x0, t, noise):
|
96 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
|
97 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)
|
98 |
+
|
99 |
+
def dynamic_x0_rescale(self, x0, t):
|
100 |
+
return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
def get_ground_truth(self, x0, noise, t):
|
104 |
+
x0 = self.dynamic_x0_rescale(x0, t)
|
105 |
+
xt = self.q_sample(x0, t, noise)
|
106 |
+
target = self.get_v(x0, t, noise) if self.is_v else noise
|
107 |
+
return xt, target
|
108 |
+
|
109 |
+
def get_uniform_trailing_steps(self, steps):
|
110 |
+
c = self.n_timestep / steps
|
111 |
+
ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
|
112 |
+
steps_out = ddim_timesteps - 1
|
113 |
+
return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)
|
114 |
+
|
115 |
+
@torch.no_grad()
|
116 |
+
def forward(self, latent_shape, steps, extra_args, progress_tqdm=None):
|
117 |
+
bar = tqdm if progress_tqdm is None else progress_tqdm
|
118 |
+
|
119 |
+
eta = 1.0
|
120 |
+
|
121 |
+
timesteps = self.get_uniform_trailing_steps(steps)
|
122 |
+
timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
|
123 |
+
|
124 |
+
x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)
|
125 |
+
|
126 |
+
alphas = self.alphas_cumprod[timesteps]
|
127 |
+
alphas_prev = self.alphas_cumprod[timesteps_prev]
|
128 |
+
scale_arr = self.scale_arr[timesteps]
|
129 |
+
scale_arr_prev = self.scale_arr[timesteps_prev]
|
130 |
+
|
131 |
+
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
132 |
+
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
133 |
+
|
134 |
+
s_in = x.new_ones((x.shape[0]))
|
135 |
+
s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
|
136 |
+
for i in bar(range(len(timesteps))):
|
137 |
+
index = len(timesteps) - 1 - i
|
138 |
+
t = timesteps[index].item()
|
139 |
+
|
140 |
+
model_output = self.model_apply(x, t * s_in, **extra_args)
|
141 |
+
|
142 |
+
if self.is_v:
|
143 |
+
e_t = self.predict_eps_from_z_and_v(x, t, model_output)
|
144 |
+
else:
|
145 |
+
e_t = model_output
|
146 |
+
|
147 |
+
a_prev = alphas_prev[index].item() * s_x
|
148 |
+
sigma_t = sigmas[index].item() * s_x
|
149 |
+
|
150 |
+
if self.is_v:
|
151 |
+
pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
|
152 |
+
else:
|
153 |
+
a_t = alphas[index].item() * s_x
|
154 |
+
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
155 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
156 |
+
|
157 |
+
# dynamic rescale
|
158 |
+
scale_t = scale_arr[index].item() * s_x
|
159 |
+
prev_scale_t = scale_arr_prev[index].item() * s_x
|
160 |
+
rescale = (prev_scale_t / scale_t)
|
161 |
+
pred_x0 = pred_x0 * rescale
|
162 |
+
|
163 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
164 |
+
noise = sigma_t * torch.randn_like(x)
|
165 |
+
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
166 |
+
|
167 |
+
return x
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def model_apply(self, x, t, **extra_args):
|
171 |
+
x = x.to(device=self.unet.device, dtype=self.unet.dtype)
|
172 |
+
cfg_scale = extra_args['cfg_scale']
|
173 |
+
p = self.unet(x, t, **extra_args['positive'])
|
174 |
+
n = self.unet(x, t, **extra_args['negative'])
|
175 |
+
o = n + cfg_scale * (p - n)
|
176 |
+
o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
|
177 |
+
return o_better
|
diffusers_vdm/improved_clip_vision.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel
|
2 |
+
# The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import types
|
6 |
+
import einops
|
7 |
+
|
8 |
+
from abc import ABCMeta
|
9 |
+
from transformers import CLIPVisionModelWithProjection
|
10 |
+
|
11 |
+
|
12 |
+
def preprocess(image):
|
13 |
+
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
|
14 |
+
std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
|
15 |
+
|
16 |
+
scale = 16 / min(image.shape[2], image.shape[3])
|
17 |
+
image = torch.nn.functional.interpolate(
|
18 |
+
image,
|
19 |
+
size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
|
20 |
+
mode="bicubic",
|
21 |
+
antialias=True
|
22 |
+
)
|
23 |
+
|
24 |
+
return (image - mean) / std
|
25 |
+
|
26 |
+
|
27 |
+
def arbitrary_positional_encoding(p, H, W):
|
28 |
+
weight = p.weight
|
29 |
+
cls = weight[:1]
|
30 |
+
pos = weight[1:]
|
31 |
+
pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
|
32 |
+
pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
|
33 |
+
pos = einops.rearrange(pos, '1 C H W -> (H W) C')
|
34 |
+
weight = torch.cat([cls, pos])[None]
|
35 |
+
return weight
|
36 |
+
|
37 |
+
|
38 |
+
def improved_clipvision_embedding_forward(self, pixel_values):
|
39 |
+
pixel_values = pixel_values * 0.5 + 0.5
|
40 |
+
pixel_values = preprocess(pixel_values)
|
41 |
+
batch_size = pixel_values.shape[0]
|
42 |
+
target_dtype = self.patch_embedding.weight.dtype
|
43 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
44 |
+
B, C, H, W = patch_embeds.shape
|
45 |
+
patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
|
46 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
47 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
48 |
+
embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
|
49 |
+
return embeddings
|
50 |
+
|
51 |
+
|
52 |
+
class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
|
53 |
+
def __init__(self, config):
|
54 |
+
super().__init__(config)
|
55 |
+
self.vision_model.embeddings.forward = types.MethodType(
|
56 |
+
improved_clipvision_embedding_forward,
|
57 |
+
self.vision_model.embeddings
|
58 |
+
)
|
diffusers_vdm/pipeline.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import einops
|
4 |
+
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
7 |
+
from huggingface_hub import snapshot_download
|
8 |
+
from diffusers_vdm.vae import VideoAutoencoderKL
|
9 |
+
from diffusers_vdm.projection import Resampler
|
10 |
+
from diffusers_vdm.unet import UNet3DModel
|
11 |
+
from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
|
12 |
+
from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR
|
13 |
+
|
14 |
+
|
15 |
+
class LatentVideoDiffusionPipeline(DiffusionPipeline):
|
16 |
+
def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.loading_components = dict(
|
20 |
+
vae=vae,
|
21 |
+
text_encoder=text_encoder,
|
22 |
+
tokenizer=tokenizer,
|
23 |
+
unet=unet,
|
24 |
+
image_encoder=image_encoder,
|
25 |
+
image_projection=image_projection
|
26 |
+
)
|
27 |
+
|
28 |
+
for k, v in self.loading_components.items():
|
29 |
+
setattr(self, k, v)
|
30 |
+
|
31 |
+
if fp16:
|
32 |
+
self.vae.half()
|
33 |
+
self.text_encoder.half()
|
34 |
+
self.unet.half()
|
35 |
+
self.image_encoder.half()
|
36 |
+
self.image_projection.half()
|
37 |
+
|
38 |
+
self.vae.requires_grad_(False)
|
39 |
+
self.text_encoder.requires_grad_(False)
|
40 |
+
self.image_encoder.requires_grad_(False)
|
41 |
+
|
42 |
+
self.vae.eval()
|
43 |
+
self.text_encoder.eval()
|
44 |
+
self.image_encoder.eval()
|
45 |
+
|
46 |
+
if eval:
|
47 |
+
self.unet.eval()
|
48 |
+
self.image_projection.eval()
|
49 |
+
else:
|
50 |
+
self.unet.train()
|
51 |
+
self.image_projection.train()
|
52 |
+
|
53 |
+
def to(self, *args, **kwargs):
|
54 |
+
for k, v in self.loading_components.items():
|
55 |
+
if hasattr(v, 'to'):
|
56 |
+
v.to(*args, **kwargs)
|
57 |
+
return self
|
58 |
+
|
59 |
+
def save_pretrained(self, save_directory, **kwargs):
|
60 |
+
for k, v in self.loading_components.items():
|
61 |
+
folder = os.path.join(save_directory, k)
|
62 |
+
os.makedirs(folder, exist_ok=True)
|
63 |
+
v.save_pretrained(folder)
|
64 |
+
return
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
|
68 |
+
local_folder = snapshot_download(repo_id=repo_id, token=token)
|
69 |
+
return cls(
|
70 |
+
tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
|
71 |
+
text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
|
72 |
+
image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
|
73 |
+
vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
|
74 |
+
image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
|
75 |
+
unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
|
76 |
+
fp16=fp16,
|
77 |
+
eval=eval
|
78 |
+
)
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
def encode_cropped_prompt_77tokens(self, prompt: str):
|
82 |
+
cond_ids = self.tokenizer(prompt,
|
83 |
+
padding="max_length",
|
84 |
+
max_length=self.tokenizer.model_max_length,
|
85 |
+
truncation=True,
|
86 |
+
return_tensors="pt").input_ids.to(self.text_encoder.device)
|
87 |
+
cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
|
88 |
+
return cond
|
89 |
+
|
90 |
+
@torch.inference_mode()
|
91 |
+
def encode_clip_vision(self, frames):
|
92 |
+
b, c, t, h, w = frames.shape
|
93 |
+
frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
|
94 |
+
clipvision_embed = self.image_encoder(frames).last_hidden_state
|
95 |
+
clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
|
96 |
+
return clipvision_embed
|
97 |
+
|
98 |
+
@torch.inference_mode()
|
99 |
+
def encode_latents(self, videos, return_hidden_states=True):
|
100 |
+
b, c, t, h, w = videos.shape
|
101 |
+
x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
|
102 |
+
encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
|
103 |
+
z = encoder_posterior.mode() * self.vae.scale_factor
|
104 |
+
z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
105 |
+
|
106 |
+
if not return_hidden_states:
|
107 |
+
return z
|
108 |
+
|
109 |
+
hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
|
110 |
+
hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states] # only need first and last
|
111 |
+
|
112 |
+
return z, hidden_states
|
113 |
+
|
114 |
+
@torch.inference_mode()
|
115 |
+
def decode_latents(self, latents, hidden_states):
|
116 |
+
B, C, T, H, W = latents.shape
|
117 |
+
latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
|
118 |
+
latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
|
119 |
+
pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
|
120 |
+
pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
|
121 |
+
return pixels
|
122 |
+
|
123 |
+
@torch.inference_mode()
|
124 |
+
def __call__(
|
125 |
+
self,
|
126 |
+
batch_size: int = 1,
|
127 |
+
steps: int = 50,
|
128 |
+
guidance_scale: float = 5.0,
|
129 |
+
positive_text_cond = None,
|
130 |
+
negative_text_cond = None,
|
131 |
+
positive_image_cond = None,
|
132 |
+
negative_image_cond = None,
|
133 |
+
concat_cond = None,
|
134 |
+
fs = 3,
|
135 |
+
progress_tqdm = None,
|
136 |
+
):
|
137 |
+
unet_is_training = self.unet.training
|
138 |
+
|
139 |
+
if unet_is_training:
|
140 |
+
self.unet.eval()
|
141 |
+
|
142 |
+
device = self.unet.device
|
143 |
+
dtype = self.unet.dtype
|
144 |
+
dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)
|
145 |
+
|
146 |
+
# Batch
|
147 |
+
|
148 |
+
concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype) # b, c, t, h, w
|
149 |
+
positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
|
150 |
+
negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
|
151 |
+
positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) # b, t, l, c
|
152 |
+
negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)
|
153 |
+
|
154 |
+
if isinstance(fs, torch.Tensor):
|
155 |
+
fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device) # b
|
156 |
+
else:
|
157 |
+
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # b
|
158 |
+
|
159 |
+
# Initial latents
|
160 |
+
|
161 |
+
latent_shape = concat_cond.shape
|
162 |
+
|
163 |
+
# Feeds
|
164 |
+
|
165 |
+
sampler_kwargs = dict(
|
166 |
+
cfg_scale=guidance_scale,
|
167 |
+
positive=dict(
|
168 |
+
context_text=positive_text_cond,
|
169 |
+
context_img=positive_image_cond,
|
170 |
+
fs=fs,
|
171 |
+
concat_cond=concat_cond
|
172 |
+
),
|
173 |
+
negative=dict(
|
174 |
+
context_text=negative_text_cond,
|
175 |
+
context_img=negative_image_cond,
|
176 |
+
fs=fs,
|
177 |
+
concat_cond=concat_cond
|
178 |
+
)
|
179 |
+
)
|
180 |
+
|
181 |
+
# Sample
|
182 |
+
|
183 |
+
results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
|
184 |
+
|
185 |
+
if unet_is_training:
|
186 |
+
self.unet.train()
|
187 |
+
|
188 |
+
return results
|
diffusers_vdm/projection.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
4 |
+
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import einops
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from huggingface_hub import PyTorchModelHubMixin
|
12 |
+
|
13 |
+
|
14 |
+
class ImageProjModel(nn.Module):
|
15 |
+
"""Projection Model"""
|
16 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
17 |
+
super().__init__()
|
18 |
+
self.cross_attention_dim = cross_attention_dim
|
19 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
20 |
+
self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
21 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
22 |
+
|
23 |
+
def forward(self, image_embeds):
|
24 |
+
#embeds = image_embeds
|
25 |
+
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
|
26 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
27 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
28 |
+
return clip_extra_context_tokens
|
29 |
+
|
30 |
+
|
31 |
+
# FFN
|
32 |
+
def FeedForward(dim, mult=4):
|
33 |
+
inner_dim = int(dim * mult)
|
34 |
+
return nn.Sequential(
|
35 |
+
nn.LayerNorm(dim),
|
36 |
+
nn.Linear(dim, inner_dim, bias=False),
|
37 |
+
nn.GELU(),
|
38 |
+
nn.Linear(inner_dim, dim, bias=False),
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def reshape_tensor(x, heads):
|
43 |
+
bs, length, width = x.shape
|
44 |
+
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
45 |
+
x = x.view(bs, length, heads, -1)
|
46 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
47 |
+
x = x.transpose(1, 2)
|
48 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
49 |
+
x = x.reshape(bs, heads, length, -1)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class PerceiverAttention(nn.Module):
|
54 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
55 |
+
super().__init__()
|
56 |
+
self.scale = dim_head**-0.5
|
57 |
+
self.dim_head = dim_head
|
58 |
+
self.heads = heads
|
59 |
+
inner_dim = dim_head * heads
|
60 |
+
|
61 |
+
self.norm1 = nn.LayerNorm(dim)
|
62 |
+
self.norm2 = nn.LayerNorm(dim)
|
63 |
+
|
64 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
65 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
66 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
67 |
+
|
68 |
+
|
69 |
+
def forward(self, x, latents):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
x (torch.Tensor): image features
|
73 |
+
shape (b, n1, D)
|
74 |
+
latent (torch.Tensor): latent features
|
75 |
+
shape (b, n2, D)
|
76 |
+
"""
|
77 |
+
x = self.norm1(x)
|
78 |
+
latents = self.norm2(latents)
|
79 |
+
|
80 |
+
b, l, _ = latents.shape
|
81 |
+
|
82 |
+
q = self.to_q(latents)
|
83 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
84 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
85 |
+
|
86 |
+
q = reshape_tensor(q, self.heads)
|
87 |
+
k = reshape_tensor(k, self.heads)
|
88 |
+
v = reshape_tensor(v, self.heads)
|
89 |
+
|
90 |
+
# attention
|
91 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
92 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
93 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
94 |
+
out = weight @ v
|
95 |
+
|
96 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
97 |
+
|
98 |
+
return self.to_out(out)
|
99 |
+
|
100 |
+
|
101 |
+
class Resampler(nn.Module, PyTorchModelHubMixin):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
dim=1024,
|
105 |
+
depth=8,
|
106 |
+
dim_head=64,
|
107 |
+
heads=16,
|
108 |
+
num_queries=8,
|
109 |
+
embedding_dim=768,
|
110 |
+
output_dim=1024,
|
111 |
+
ff_mult=4,
|
112 |
+
video_length=16,
|
113 |
+
input_frames_length=2,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.num_queries = num_queries
|
117 |
+
self.video_length = video_length
|
118 |
+
|
119 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries * video_length, dim) / dim**0.5)
|
120 |
+
self.input_pos = nn.Parameter(torch.zeros(1, input_frames_length, 1, embedding_dim))
|
121 |
+
|
122 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
123 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
124 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
125 |
+
|
126 |
+
self.layers = nn.ModuleList([])
|
127 |
+
for _ in range(depth):
|
128 |
+
self.layers.append(
|
129 |
+
nn.ModuleList(
|
130 |
+
[
|
131 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
132 |
+
FeedForward(dim=dim, mult=ff_mult),
|
133 |
+
]
|
134 |
+
)
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
139 |
+
|
140 |
+
x = x + self.input_pos
|
141 |
+
x = einops.rearrange(x, 'b ti d c -> b (ti d) c')
|
142 |
+
x = self.proj_in(x)
|
143 |
+
|
144 |
+
for attn, ff in self.layers:
|
145 |
+
latents = attn(x, latents) + latents
|
146 |
+
latents = ff(latents) + latents
|
147 |
+
|
148 |
+
latents = self.proj_out(latents)
|
149 |
+
latents = self.norm_out(latents)
|
150 |
+
|
151 |
+
latents = einops.rearrange(latents, 'b (to l) c -> b to l c', to=self.video_length)
|
152 |
+
return latents
|
153 |
+
|
154 |
+
@property
|
155 |
+
def device(self):
|
156 |
+
return next(self.parameters()).device
|
157 |
+
|
158 |
+
@property
|
159 |
+
def dtype(self):
|
160 |
+
return next(self.parameters()).dtype
|
diffusers_vdm/unet.py
ADDED
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/AILab-CVC/VideoCrafter
|
2 |
+
# https://github.com/Doubiiu/DynamiCrafter
|
3 |
+
# https://github.com/ToonCrafter/ToonCrafter
|
4 |
+
# Then edited by lllyasviel
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from abc import abstractmethod
|
8 |
+
import torch
|
9 |
+
import math
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from diffusers_vdm.basics import checkpoint
|
14 |
+
from diffusers_vdm.basics import (
|
15 |
+
zero_module,
|
16 |
+
conv_nd,
|
17 |
+
linear,
|
18 |
+
avg_pool_nd,
|
19 |
+
normalization
|
20 |
+
)
|
21 |
+
from diffusers_vdm.attention import SpatialTransformer, TemporalTransformer
|
22 |
+
from huggingface_hub import PyTorchModelHubMixin
|
23 |
+
|
24 |
+
|
25 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
26 |
+
"""
|
27 |
+
Create sinusoidal timestep embeddings.
|
28 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
29 |
+
These may be fractional.
|
30 |
+
:param dim: the dimension of the output.
|
31 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
32 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
33 |
+
"""
|
34 |
+
if not repeat_only:
|
35 |
+
half = dim // 2
|
36 |
+
freqs = torch.exp(
|
37 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
38 |
+
).to(device=timesteps.device)
|
39 |
+
args = timesteps[:, None].float() * freqs[None]
|
40 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
41 |
+
if dim % 2:
|
42 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
43 |
+
else:
|
44 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
45 |
+
return embedding
|
46 |
+
|
47 |
+
|
48 |
+
class TimestepBlock(nn.Module):
|
49 |
+
"""
|
50 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
51 |
+
"""
|
52 |
+
|
53 |
+
@abstractmethod
|
54 |
+
def forward(self, x, emb):
|
55 |
+
"""
|
56 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
57 |
+
"""
|
58 |
+
|
59 |
+
|
60 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
61 |
+
"""
|
62 |
+
A sequential module that passes timestep embeddings to the children that
|
63 |
+
support it as an extra input.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def forward(self, x, emb, context=None, batch_size=None):
|
67 |
+
for layer in self:
|
68 |
+
if isinstance(layer, TimestepBlock):
|
69 |
+
x = layer(x, emb, batch_size=batch_size)
|
70 |
+
elif isinstance(layer, SpatialTransformer):
|
71 |
+
x = layer(x, context)
|
72 |
+
elif isinstance(layer, TemporalTransformer):
|
73 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
|
74 |
+
x = layer(x, context)
|
75 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
76 |
+
else:
|
77 |
+
x = layer(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class Downsample(nn.Module):
|
82 |
+
"""
|
83 |
+
A downsampling layer with an optional convolution.
|
84 |
+
:param channels: channels in the inputs and outputs.
|
85 |
+
:param use_conv: a bool determining if a convolution is applied.
|
86 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
87 |
+
downsampling occurs in the inner-two dimensions.
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
91 |
+
super().__init__()
|
92 |
+
self.channels = channels
|
93 |
+
self.out_channels = out_channels or channels
|
94 |
+
self.use_conv = use_conv
|
95 |
+
self.dims = dims
|
96 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
97 |
+
if use_conv:
|
98 |
+
self.op = conv_nd(
|
99 |
+
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
assert self.channels == self.out_channels
|
103 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
assert x.shape[1] == self.channels
|
107 |
+
return self.op(x)
|
108 |
+
|
109 |
+
|
110 |
+
class Upsample(nn.Module):
|
111 |
+
"""
|
112 |
+
An upsampling layer with an optional convolution.
|
113 |
+
:param channels: channels in the inputs and outputs.
|
114 |
+
:param use_conv: a bool determining if a convolution is applied.
|
115 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
116 |
+
upsampling occurs in the inner-two dimensions.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
120 |
+
super().__init__()
|
121 |
+
self.channels = channels
|
122 |
+
self.out_channels = out_channels or channels
|
123 |
+
self.use_conv = use_conv
|
124 |
+
self.dims = dims
|
125 |
+
if use_conv:
|
126 |
+
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
assert x.shape[1] == self.channels
|
130 |
+
if self.dims == 3:
|
131 |
+
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
|
132 |
+
else:
|
133 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
134 |
+
if self.use_conv:
|
135 |
+
x = self.conv(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class ResBlock(TimestepBlock):
|
140 |
+
"""
|
141 |
+
A residual block that can optionally change the number of channels.
|
142 |
+
:param channels: the number of input channels.
|
143 |
+
:param emb_channels: the number of timestep embedding channels.
|
144 |
+
:param dropout: the rate of dropout.
|
145 |
+
:param out_channels: if specified, the number of out channels.
|
146 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
147 |
+
convolution instead of a smaller 1x1 convolution to change the
|
148 |
+
channels in the skip connection.
|
149 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
150 |
+
:param up: if True, use this block for upsampling.
|
151 |
+
:param down: if True, use this block for downsampling.
|
152 |
+
:param use_temporal_conv: if True, use the temporal convolution.
|
153 |
+
:param use_image_dataset: if True, the temporal parameters will not be optimized.
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
channels,
|
159 |
+
emb_channels,
|
160 |
+
dropout,
|
161 |
+
out_channels=None,
|
162 |
+
use_scale_shift_norm=False,
|
163 |
+
dims=2,
|
164 |
+
use_checkpoint=False,
|
165 |
+
use_conv=False,
|
166 |
+
up=False,
|
167 |
+
down=False,
|
168 |
+
use_temporal_conv=False,
|
169 |
+
tempspatial_aware=False
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
self.channels = channels
|
173 |
+
self.emb_channels = emb_channels
|
174 |
+
self.dropout = dropout
|
175 |
+
self.out_channels = out_channels or channels
|
176 |
+
self.use_conv = use_conv
|
177 |
+
self.use_checkpoint = use_checkpoint
|
178 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
179 |
+
self.use_temporal_conv = use_temporal_conv
|
180 |
+
|
181 |
+
self.in_layers = nn.Sequential(
|
182 |
+
normalization(channels),
|
183 |
+
nn.SiLU(),
|
184 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
185 |
+
)
|
186 |
+
|
187 |
+
self.updown = up or down
|
188 |
+
|
189 |
+
if up:
|
190 |
+
self.h_upd = Upsample(channels, False, dims)
|
191 |
+
self.x_upd = Upsample(channels, False, dims)
|
192 |
+
elif down:
|
193 |
+
self.h_upd = Downsample(channels, False, dims)
|
194 |
+
self.x_upd = Downsample(channels, False, dims)
|
195 |
+
else:
|
196 |
+
self.h_upd = self.x_upd = nn.Identity()
|
197 |
+
|
198 |
+
self.emb_layers = nn.Sequential(
|
199 |
+
nn.SiLU(),
|
200 |
+
nn.Linear(
|
201 |
+
emb_channels,
|
202 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
203 |
+
),
|
204 |
+
)
|
205 |
+
self.out_layers = nn.Sequential(
|
206 |
+
normalization(self.out_channels),
|
207 |
+
nn.SiLU(),
|
208 |
+
nn.Dropout(p=dropout),
|
209 |
+
zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
|
210 |
+
)
|
211 |
+
|
212 |
+
if self.out_channels == channels:
|
213 |
+
self.skip_connection = nn.Identity()
|
214 |
+
elif use_conv:
|
215 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
|
216 |
+
else:
|
217 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
218 |
+
|
219 |
+
if self.use_temporal_conv:
|
220 |
+
self.temopral_conv = TemporalConvBlock(
|
221 |
+
self.out_channels,
|
222 |
+
self.out_channels,
|
223 |
+
dropout=0.1,
|
224 |
+
spatial_aware=tempspatial_aware
|
225 |
+
)
|
226 |
+
|
227 |
+
def forward(self, x, emb, batch_size=None):
|
228 |
+
"""
|
229 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
230 |
+
:param x: an [N x C x ...] Tensor of features.
|
231 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
232 |
+
:return: an [N x C x ...] Tensor of outputs.
|
233 |
+
"""
|
234 |
+
input_tuple = (x, emb)
|
235 |
+
if batch_size:
|
236 |
+
forward_batchsize = partial(self._forward, batch_size=batch_size)
|
237 |
+
return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
|
238 |
+
return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
|
239 |
+
|
240 |
+
def _forward(self, x, emb, batch_size=None):
|
241 |
+
if self.updown:
|
242 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
243 |
+
h = in_rest(x)
|
244 |
+
h = self.h_upd(h)
|
245 |
+
x = self.x_upd(x)
|
246 |
+
h = in_conv(h)
|
247 |
+
else:
|
248 |
+
h = self.in_layers(x)
|
249 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
250 |
+
while len(emb_out.shape) < len(h.shape):
|
251 |
+
emb_out = emb_out[..., None]
|
252 |
+
if self.use_scale_shift_norm:
|
253 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
254 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
255 |
+
h = out_norm(h) * (1 + scale) + shift
|
256 |
+
h = out_rest(h)
|
257 |
+
else:
|
258 |
+
h = h + emb_out
|
259 |
+
h = self.out_layers(h)
|
260 |
+
h = self.skip_connection(x) + h
|
261 |
+
|
262 |
+
if self.use_temporal_conv and batch_size:
|
263 |
+
h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
|
264 |
+
h = self.temopral_conv(h)
|
265 |
+
h = rearrange(h, 'b c t h w -> (b t) c h w')
|
266 |
+
return h
|
267 |
+
|
268 |
+
|
269 |
+
class TemporalConvBlock(nn.Module):
|
270 |
+
"""
|
271 |
+
Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
|
275 |
+
super(TemporalConvBlock, self).__init__()
|
276 |
+
if out_channels is None:
|
277 |
+
out_channels = in_channels
|
278 |
+
self.in_channels = in_channels
|
279 |
+
self.out_channels = out_channels
|
280 |
+
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
|
281 |
+
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
|
282 |
+
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
|
283 |
+
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
|
284 |
+
|
285 |
+
# conv layers
|
286 |
+
self.conv1 = nn.Sequential(
|
287 |
+
nn.GroupNorm(32, in_channels), nn.SiLU(),
|
288 |
+
nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
|
289 |
+
self.conv2 = nn.Sequential(
|
290 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
291 |
+
nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
|
292 |
+
self.conv3 = nn.Sequential(
|
293 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
294 |
+
nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
|
295 |
+
self.conv4 = nn.Sequential(
|
296 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
297 |
+
nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
|
298 |
+
|
299 |
+
# zero out the last layer params,so the conv block is identity
|
300 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
301 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
identity = x
|
305 |
+
x = self.conv1(x)
|
306 |
+
x = self.conv2(x)
|
307 |
+
x = self.conv3(x)
|
308 |
+
x = self.conv4(x)
|
309 |
+
|
310 |
+
return identity + x
|
311 |
+
|
312 |
+
|
313 |
+
class UNet3DModel(nn.Module, PyTorchModelHubMixin):
|
314 |
+
"""
|
315 |
+
The full UNet model with attention and timestep embedding.
|
316 |
+
:param in_channels: in_channels in the input Tensor.
|
317 |
+
:param model_channels: base channel count for the model.
|
318 |
+
:param out_channels: channels in the output Tensor.
|
319 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
320 |
+
:param attention_resolutions: a collection of downsample rates at which
|
321 |
+
attention will take place. May be a set, list, or tuple.
|
322 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
323 |
+
will be used.
|
324 |
+
:param dropout: the dropout probability.
|
325 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
326 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
327 |
+
downsampling.
|
328 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
329 |
+
:param num_classes: if specified (as an int), then this model will be
|
330 |
+
class-conditional with `num_classes` classes.
|
331 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
332 |
+
:param num_heads: the number of attention heads in each attention layer.
|
333 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
334 |
+
a fixed channel width per attention head.
|
335 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
336 |
+
of heads for upsampling. Deprecated.
|
337 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
338 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
339 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
340 |
+
increased efficiency.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(self,
|
344 |
+
in_channels,
|
345 |
+
model_channels,
|
346 |
+
out_channels,
|
347 |
+
num_res_blocks,
|
348 |
+
attention_resolutions,
|
349 |
+
dropout=0.0,
|
350 |
+
channel_mult=(1, 2, 4, 8),
|
351 |
+
conv_resample=True,
|
352 |
+
dims=2,
|
353 |
+
context_dim=None,
|
354 |
+
use_scale_shift_norm=False,
|
355 |
+
resblock_updown=False,
|
356 |
+
num_heads=-1,
|
357 |
+
num_head_channels=-1,
|
358 |
+
transformer_depth=1,
|
359 |
+
use_linear=False,
|
360 |
+
temporal_conv=False,
|
361 |
+
tempspatial_aware=False,
|
362 |
+
temporal_attention=True,
|
363 |
+
use_relative_position=True,
|
364 |
+
use_causal_attention=False,
|
365 |
+
temporal_length=None,
|
366 |
+
addition_attention=False,
|
367 |
+
temporal_selfatt_only=True,
|
368 |
+
image_cross_attention=False,
|
369 |
+
image_cross_attention_scale_learnable=False,
|
370 |
+
default_fs=4,
|
371 |
+
fs_condition=False,
|
372 |
+
):
|
373 |
+
super(UNet3DModel, self).__init__()
|
374 |
+
if num_heads == -1:
|
375 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
376 |
+
if num_head_channels == -1:
|
377 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
378 |
+
|
379 |
+
self.in_channels = in_channels
|
380 |
+
self.model_channels = model_channels
|
381 |
+
self.out_channels = out_channels
|
382 |
+
self.num_res_blocks = num_res_blocks
|
383 |
+
self.attention_resolutions = attention_resolutions
|
384 |
+
self.dropout = dropout
|
385 |
+
self.channel_mult = channel_mult
|
386 |
+
self.conv_resample = conv_resample
|
387 |
+
self.temporal_attention = temporal_attention
|
388 |
+
time_embed_dim = model_channels * 4
|
389 |
+
self.use_checkpoint = use_checkpoint = False # moved to self.enable_gradient_checkpointing()
|
390 |
+
temporal_self_att_only = True
|
391 |
+
self.addition_attention = addition_attention
|
392 |
+
self.temporal_length = temporal_length
|
393 |
+
self.image_cross_attention = image_cross_attention
|
394 |
+
self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
|
395 |
+
self.default_fs = default_fs
|
396 |
+
self.fs_condition = fs_condition
|
397 |
+
|
398 |
+
## Time embedding blocks
|
399 |
+
self.time_embed = nn.Sequential(
|
400 |
+
linear(model_channels, time_embed_dim),
|
401 |
+
nn.SiLU(),
|
402 |
+
linear(time_embed_dim, time_embed_dim),
|
403 |
+
)
|
404 |
+
if fs_condition:
|
405 |
+
self.fps_embedding = nn.Sequential(
|
406 |
+
linear(model_channels, time_embed_dim),
|
407 |
+
nn.SiLU(),
|
408 |
+
linear(time_embed_dim, time_embed_dim),
|
409 |
+
)
|
410 |
+
nn.init.zeros_(self.fps_embedding[-1].weight)
|
411 |
+
nn.init.zeros_(self.fps_embedding[-1].bias)
|
412 |
+
## Input Block
|
413 |
+
self.input_blocks = nn.ModuleList(
|
414 |
+
[
|
415 |
+
TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
416 |
+
]
|
417 |
+
)
|
418 |
+
if self.addition_attention:
|
419 |
+
self.init_attn = TimestepEmbedSequential(
|
420 |
+
TemporalTransformer(
|
421 |
+
model_channels,
|
422 |
+
n_heads=8,
|
423 |
+
d_head=num_head_channels,
|
424 |
+
depth=transformer_depth,
|
425 |
+
context_dim=context_dim,
|
426 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
|
427 |
+
causal_attention=False, relative_position=use_relative_position,
|
428 |
+
temporal_length=temporal_length))
|
429 |
+
|
430 |
+
input_block_chans = [model_channels]
|
431 |
+
ch = model_channels
|
432 |
+
ds = 1
|
433 |
+
for level, mult in enumerate(channel_mult):
|
434 |
+
for _ in range(num_res_blocks):
|
435 |
+
layers = [
|
436 |
+
ResBlock(ch, time_embed_dim, dropout,
|
437 |
+
out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
|
438 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
439 |
+
use_temporal_conv=temporal_conv
|
440 |
+
)
|
441 |
+
]
|
442 |
+
ch = mult * model_channels
|
443 |
+
if ds in attention_resolutions:
|
444 |
+
if num_head_channels == -1:
|
445 |
+
dim_head = ch // num_heads
|
446 |
+
else:
|
447 |
+
num_heads = ch // num_head_channels
|
448 |
+
dim_head = num_head_channels
|
449 |
+
layers.append(
|
450 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
451 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
452 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False,
|
453 |
+
video_length=temporal_length,
|
454 |
+
image_cross_attention=self.image_cross_attention,
|
455 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
|
456 |
+
)
|
457 |
+
)
|
458 |
+
if self.temporal_attention:
|
459 |
+
layers.append(
|
460 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
461 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
462 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
463 |
+
causal_attention=use_causal_attention,
|
464 |
+
relative_position=use_relative_position,
|
465 |
+
temporal_length=temporal_length
|
466 |
+
)
|
467 |
+
)
|
468 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
469 |
+
input_block_chans.append(ch)
|
470 |
+
if level != len(channel_mult) - 1:
|
471 |
+
out_ch = ch
|
472 |
+
self.input_blocks.append(
|
473 |
+
TimestepEmbedSequential(
|
474 |
+
ResBlock(ch, time_embed_dim, dropout,
|
475 |
+
out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
|
476 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
477 |
+
down=True
|
478 |
+
)
|
479 |
+
if resblock_updown
|
480 |
+
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
481 |
+
)
|
482 |
+
)
|
483 |
+
ch = out_ch
|
484 |
+
input_block_chans.append(ch)
|
485 |
+
ds *= 2
|
486 |
+
|
487 |
+
if num_head_channels == -1:
|
488 |
+
dim_head = ch // num_heads
|
489 |
+
else:
|
490 |
+
num_heads = ch // num_head_channels
|
491 |
+
dim_head = num_head_channels
|
492 |
+
layers = [
|
493 |
+
ResBlock(ch, time_embed_dim, dropout,
|
494 |
+
dims=dims, use_checkpoint=use_checkpoint,
|
495 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
496 |
+
use_temporal_conv=temporal_conv
|
497 |
+
),
|
498 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
499 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
500 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
|
501 |
+
image_cross_attention=self.image_cross_attention,
|
502 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
|
503 |
+
)
|
504 |
+
]
|
505 |
+
if self.temporal_attention:
|
506 |
+
layers.append(
|
507 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
508 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
509 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
510 |
+
causal_attention=use_causal_attention, relative_position=use_relative_position,
|
511 |
+
temporal_length=temporal_length
|
512 |
+
)
|
513 |
+
)
|
514 |
+
layers.append(
|
515 |
+
ResBlock(ch, time_embed_dim, dropout,
|
516 |
+
dims=dims, use_checkpoint=use_checkpoint,
|
517 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
518 |
+
use_temporal_conv=temporal_conv
|
519 |
+
)
|
520 |
+
)
|
521 |
+
|
522 |
+
## Middle Block
|
523 |
+
self.middle_block = TimestepEmbedSequential(*layers)
|
524 |
+
|
525 |
+
## Output Block
|
526 |
+
self.output_blocks = nn.ModuleList([])
|
527 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
528 |
+
for i in range(num_res_blocks + 1):
|
529 |
+
ich = input_block_chans.pop()
|
530 |
+
layers = [
|
531 |
+
ResBlock(ch + ich, time_embed_dim, dropout,
|
532 |
+
out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
|
533 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
534 |
+
use_temporal_conv=temporal_conv
|
535 |
+
)
|
536 |
+
]
|
537 |
+
ch = model_channels * mult
|
538 |
+
if ds in attention_resolutions:
|
539 |
+
if num_head_channels == -1:
|
540 |
+
dim_head = ch // num_heads
|
541 |
+
else:
|
542 |
+
num_heads = ch // num_head_channels
|
543 |
+
dim_head = num_head_channels
|
544 |
+
layers.append(
|
545 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
546 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
547 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False,
|
548 |
+
video_length=temporal_length,
|
549 |
+
image_cross_attention=self.image_cross_attention,
|
550 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
|
551 |
+
)
|
552 |
+
)
|
553 |
+
if self.temporal_attention:
|
554 |
+
layers.append(
|
555 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
556 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
557 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
558 |
+
causal_attention=use_causal_attention,
|
559 |
+
relative_position=use_relative_position,
|
560 |
+
temporal_length=temporal_length
|
561 |
+
)
|
562 |
+
)
|
563 |
+
if level and i == num_res_blocks:
|
564 |
+
out_ch = ch
|
565 |
+
layers.append(
|
566 |
+
ResBlock(ch, time_embed_dim, dropout,
|
567 |
+
out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
|
568 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
569 |
+
up=True
|
570 |
+
)
|
571 |
+
if resblock_updown
|
572 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
573 |
+
)
|
574 |
+
ds //= 2
|
575 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
576 |
+
|
577 |
+
self.out = nn.Sequential(
|
578 |
+
normalization(ch),
|
579 |
+
nn.SiLU(),
|
580 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
581 |
+
)
|
582 |
+
|
583 |
+
@property
|
584 |
+
def device(self):
|
585 |
+
return next(self.parameters()).device
|
586 |
+
|
587 |
+
@property
|
588 |
+
def dtype(self):
|
589 |
+
return next(self.parameters()).dtype
|
590 |
+
|
591 |
+
def forward(self, x, timesteps, context_text=None, context_img=None, concat_cond=None, fs=None, **kwargs):
|
592 |
+
b, _, t, _, _ = x.shape
|
593 |
+
|
594 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
|
595 |
+
emb = self.time_embed(t_emb)
|
596 |
+
|
597 |
+
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
598 |
+
context_img = rearrange(context_img, 'b t l c -> (b t) l c')
|
599 |
+
|
600 |
+
context = (context_text, context_img)
|
601 |
+
|
602 |
+
emb = emb.repeat_interleave(repeats=t, dim=0)
|
603 |
+
|
604 |
+
if concat_cond is not None:
|
605 |
+
x = torch.cat([x, concat_cond], dim=1)
|
606 |
+
|
607 |
+
## always in shape (b t) c h w, except for temporal layer
|
608 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
609 |
+
|
610 |
+
## combine emb
|
611 |
+
if self.fs_condition:
|
612 |
+
if fs is None:
|
613 |
+
fs = torch.tensor(
|
614 |
+
[self.default_fs] * b, dtype=torch.long, device=x.device)
|
615 |
+
fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
|
616 |
+
|
617 |
+
fs_embed = self.fps_embedding(fs_emb)
|
618 |
+
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
619 |
+
emb = emb + fs_embed
|
620 |
+
|
621 |
+
h = x
|
622 |
+
hs = []
|
623 |
+
for id, module in enumerate(self.input_blocks):
|
624 |
+
h = module(h, emb, context=context, batch_size=b)
|
625 |
+
if id == 0 and self.addition_attention:
|
626 |
+
h = self.init_attn(h, emb, context=context, batch_size=b)
|
627 |
+
hs.append(h)
|
628 |
+
|
629 |
+
h = self.middle_block(h, emb, context=context, batch_size=b)
|
630 |
+
|
631 |
+
for module in self.output_blocks:
|
632 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
633 |
+
h = module(h, emb, context=context, batch_size=b)
|
634 |
+
h = h.type(x.dtype)
|
635 |
+
y = self.out(h)
|
636 |
+
|
637 |
+
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
638 |
+
return y
|
639 |
+
|
640 |
+
def enable_gradient_checkpointing(self, enable=True, verbose=False):
|
641 |
+
for k, v in self.named_modules():
|
642 |
+
if hasattr(v, 'checkpoint'):
|
643 |
+
v.checkpoint = enable
|
644 |
+
if verbose:
|
645 |
+
print(f'{k}.checkpoint = {enable}')
|
646 |
+
if hasattr(v, 'use_checkpoint'):
|
647 |
+
v.use_checkpoint = enable
|
648 |
+
if verbose:
|
649 |
+
print(f'{k}.use_checkpoint = {enable}')
|
650 |
+
return
|
diffusers_vdm/utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import einops
|
5 |
+
import torchvision
|
6 |
+
|
7 |
+
|
8 |
+
def resize_and_center_crop(image, target_width, target_height, interpolation=cv2.INTER_AREA):
|
9 |
+
original_height, original_width = image.shape[:2]
|
10 |
+
k = max(target_height / original_height, target_width / original_width)
|
11 |
+
new_width = int(round(original_width * k))
|
12 |
+
new_height = int(round(original_height * k))
|
13 |
+
resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
|
14 |
+
x_start = (new_width - target_width) // 2
|
15 |
+
y_start = (new_height - target_height) // 2
|
16 |
+
cropped_image = resized_image[y_start:y_start + target_height, x_start:x_start + target_width]
|
17 |
+
return cropped_image
|
18 |
+
|
19 |
+
|
20 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10):
|
21 |
+
b, c, t, h, w = x.shape
|
22 |
+
|
23 |
+
per_row = b
|
24 |
+
for p in [6, 5, 4, 3, 2]:
|
25 |
+
if b % p == 0:
|
26 |
+
per_row = p
|
27 |
+
break
|
28 |
+
|
29 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
30 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
31 |
+
x = x.detach().cpu().to(torch.uint8)
|
32 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
33 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '1'})
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
def save_bcthw_as_png(x, output_filename):
|
38 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
39 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
40 |
+
x = x.detach().cpu().to(torch.uint8)
|
41 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
42 |
+
torchvision.io.write_png(x, output_filename)
|
43 |
+
return output_filename
|
diffusers_vdm/vae.py
ADDED
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# video VAE with many components from lots of repos
|
2 |
+
# collected by lvmin
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import xformers.ops
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from diffusers_vdm.basics import default, exists, zero_module, conv_nd, linear, normalization
|
11 |
+
from diffusers_vdm.unet import Upsample, Downsample
|
12 |
+
from huggingface_hub import PyTorchModelHubMixin
|
13 |
+
|
14 |
+
|
15 |
+
def chunked_attention(q, k, v, batch_chunk=0):
|
16 |
+
# if batch_chunk > 0 and not torch.is_grad_enabled():
|
17 |
+
# batch_size = q.size(0)
|
18 |
+
# chunks = [slice(i, i + batch_chunk) for i in range(0, batch_size, batch_chunk)]
|
19 |
+
#
|
20 |
+
# out_chunks = []
|
21 |
+
# for chunk in chunks:
|
22 |
+
# q_chunk = q[chunk]
|
23 |
+
# k_chunk = k[chunk]
|
24 |
+
# v_chunk = v[chunk]
|
25 |
+
#
|
26 |
+
# out_chunk = torch.nn.functional.scaled_dot_product_attention(
|
27 |
+
# q_chunk, k_chunk, v_chunk, attn_mask=None
|
28 |
+
# )
|
29 |
+
# out_chunks.append(out_chunk)
|
30 |
+
#
|
31 |
+
# out = torch.cat(out_chunks, dim=0)
|
32 |
+
# else:
|
33 |
+
# out = torch.nn.functional.scaled_dot_product_attention(
|
34 |
+
# q, k, v, attn_mask=None
|
35 |
+
# )
|
36 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
37 |
+
return out
|
38 |
+
|
39 |
+
|
40 |
+
def nonlinearity(x):
|
41 |
+
return x * torch.sigmoid(x)
|
42 |
+
|
43 |
+
|
44 |
+
def GroupNorm(in_channels, num_groups=32):
|
45 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
46 |
+
|
47 |
+
|
48 |
+
class DiagonalGaussianDistribution:
|
49 |
+
def __init__(self, parameters, deterministic=False):
|
50 |
+
self.parameters = parameters
|
51 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
52 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
53 |
+
self.deterministic = deterministic
|
54 |
+
self.std = torch.exp(0.5 * self.logvar)
|
55 |
+
self.var = torch.exp(self.logvar)
|
56 |
+
if self.deterministic:
|
57 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
58 |
+
|
59 |
+
def sample(self, noise=None):
|
60 |
+
if noise is None:
|
61 |
+
noise = torch.randn(self.mean.shape)
|
62 |
+
|
63 |
+
x = self.mean + self.std * noise.to(device=self.parameters.device)
|
64 |
+
return x
|
65 |
+
|
66 |
+
def mode(self):
|
67 |
+
return self.mean
|
68 |
+
|
69 |
+
|
70 |
+
class EncoderDownSampleBlock(nn.Module):
|
71 |
+
def __init__(self, in_channels, with_conv):
|
72 |
+
super().__init__()
|
73 |
+
self.with_conv = with_conv
|
74 |
+
self.in_channels = in_channels
|
75 |
+
if self.with_conv:
|
76 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
77 |
+
in_channels,
|
78 |
+
kernel_size=3,
|
79 |
+
stride=2,
|
80 |
+
padding=0)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
if self.with_conv:
|
84 |
+
pad = (0, 1, 0, 1)
|
85 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
86 |
+
x = self.conv(x)
|
87 |
+
else:
|
88 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
class ResnetBlock(nn.Module):
|
93 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
94 |
+
dropout, temb_channels=512):
|
95 |
+
super().__init__()
|
96 |
+
self.in_channels = in_channels
|
97 |
+
out_channels = in_channels if out_channels is None else out_channels
|
98 |
+
self.out_channels = out_channels
|
99 |
+
self.use_conv_shortcut = conv_shortcut
|
100 |
+
|
101 |
+
self.norm1 = GroupNorm(in_channels)
|
102 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
if temb_channels > 0:
|
108 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
109 |
+
out_channels)
|
110 |
+
self.norm2 = GroupNorm(out_channels)
|
111 |
+
self.dropout = torch.nn.Dropout(dropout)
|
112 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
113 |
+
out_channels,
|
114 |
+
kernel_size=3,
|
115 |
+
stride=1,
|
116 |
+
padding=1)
|
117 |
+
if self.in_channels != self.out_channels:
|
118 |
+
if self.use_conv_shortcut:
|
119 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
120 |
+
out_channels,
|
121 |
+
kernel_size=3,
|
122 |
+
stride=1,
|
123 |
+
padding=1)
|
124 |
+
else:
|
125 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
126 |
+
out_channels,
|
127 |
+
kernel_size=1,
|
128 |
+
stride=1,
|
129 |
+
padding=0)
|
130 |
+
|
131 |
+
def forward(self, x, temb):
|
132 |
+
h = x
|
133 |
+
h = self.norm1(h)
|
134 |
+
h = nonlinearity(h)
|
135 |
+
h = self.conv1(h)
|
136 |
+
|
137 |
+
if temb is not None:
|
138 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
139 |
+
|
140 |
+
h = self.norm2(h)
|
141 |
+
h = nonlinearity(h)
|
142 |
+
h = self.dropout(h)
|
143 |
+
h = self.conv2(h)
|
144 |
+
|
145 |
+
if self.in_channels != self.out_channels:
|
146 |
+
if self.use_conv_shortcut:
|
147 |
+
x = self.conv_shortcut(x)
|
148 |
+
else:
|
149 |
+
x = self.nin_shortcut(x)
|
150 |
+
|
151 |
+
return x + h
|
152 |
+
|
153 |
+
|
154 |
+
class Encoder(nn.Module):
|
155 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
|
156 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
157 |
+
resolution, z_channels, double_z=True, **kwargs):
|
158 |
+
super().__init__()
|
159 |
+
self.ch = ch
|
160 |
+
self.temb_ch = 0
|
161 |
+
self.num_resolutions = len(ch_mult)
|
162 |
+
self.num_res_blocks = num_res_blocks
|
163 |
+
self.resolution = resolution
|
164 |
+
self.in_channels = in_channels
|
165 |
+
|
166 |
+
# downsampling
|
167 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
168 |
+
self.ch,
|
169 |
+
kernel_size=3,
|
170 |
+
stride=1,
|
171 |
+
padding=1)
|
172 |
+
|
173 |
+
curr_res = resolution
|
174 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
175 |
+
self.in_ch_mult = in_ch_mult
|
176 |
+
self.down = nn.ModuleList()
|
177 |
+
for i_level in range(self.num_resolutions):
|
178 |
+
block = nn.ModuleList()
|
179 |
+
attn = nn.ModuleList()
|
180 |
+
block_in = ch * in_ch_mult[i_level]
|
181 |
+
block_out = ch * ch_mult[i_level]
|
182 |
+
for i_block in range(self.num_res_blocks):
|
183 |
+
block.append(ResnetBlock(in_channels=block_in,
|
184 |
+
out_channels=block_out,
|
185 |
+
temb_channels=self.temb_ch,
|
186 |
+
dropout=dropout))
|
187 |
+
block_in = block_out
|
188 |
+
if curr_res in attn_resolutions:
|
189 |
+
attn.append(Attention(block_in))
|
190 |
+
down = nn.Module()
|
191 |
+
down.block = block
|
192 |
+
down.attn = attn
|
193 |
+
if i_level != self.num_resolutions - 1:
|
194 |
+
down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv)
|
195 |
+
curr_res = curr_res // 2
|
196 |
+
self.down.append(down)
|
197 |
+
|
198 |
+
# middle
|
199 |
+
self.mid = nn.Module()
|
200 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
201 |
+
out_channels=block_in,
|
202 |
+
temb_channels=self.temb_ch,
|
203 |
+
dropout=dropout)
|
204 |
+
self.mid.attn_1 = Attention(block_in)
|
205 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
206 |
+
out_channels=block_in,
|
207 |
+
temb_channels=self.temb_ch,
|
208 |
+
dropout=dropout)
|
209 |
+
|
210 |
+
# end
|
211 |
+
self.norm_out = GroupNorm(block_in)
|
212 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
213 |
+
2 * z_channels if double_z else z_channels,
|
214 |
+
kernel_size=3,
|
215 |
+
stride=1,
|
216 |
+
padding=1)
|
217 |
+
|
218 |
+
def forward(self, x, return_hidden_states=False):
|
219 |
+
# timestep embedding
|
220 |
+
temb = None
|
221 |
+
|
222 |
+
# print(f'encoder-input={x.shape}')
|
223 |
+
# downsampling
|
224 |
+
hs = [self.conv_in(x)]
|
225 |
+
|
226 |
+
## if we return hidden states for decoder usage, we will store them in a list
|
227 |
+
if return_hidden_states:
|
228 |
+
hidden_states = []
|
229 |
+
# print(f'encoder-conv in feat={hs[0].shape}')
|
230 |
+
for i_level in range(self.num_resolutions):
|
231 |
+
for i_block in range(self.num_res_blocks):
|
232 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
233 |
+
# print(f'encoder-down feat={h.shape}')
|
234 |
+
if len(self.down[i_level].attn) > 0:
|
235 |
+
h = self.down[i_level].attn[i_block](h)
|
236 |
+
hs.append(h)
|
237 |
+
if return_hidden_states:
|
238 |
+
hidden_states.append(h)
|
239 |
+
if i_level != self.num_resolutions - 1:
|
240 |
+
# print(f'encoder-downsample (input)={hs[-1].shape}')
|
241 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
242 |
+
# print(f'encoder-downsample (output)={hs[-1].shape}')
|
243 |
+
if return_hidden_states:
|
244 |
+
hidden_states.append(hs[0])
|
245 |
+
# middle
|
246 |
+
h = hs[-1]
|
247 |
+
h = self.mid.block_1(h, temb)
|
248 |
+
# print(f'encoder-mid1 feat={h.shape}')
|
249 |
+
h = self.mid.attn_1(h)
|
250 |
+
h = self.mid.block_2(h, temb)
|
251 |
+
# print(f'encoder-mid2 feat={h.shape}')
|
252 |
+
|
253 |
+
# end
|
254 |
+
h = self.norm_out(h)
|
255 |
+
h = nonlinearity(h)
|
256 |
+
h = self.conv_out(h)
|
257 |
+
# print(f'end feat={h.shape}')
|
258 |
+
if return_hidden_states:
|
259 |
+
return h, hidden_states
|
260 |
+
else:
|
261 |
+
return h
|
262 |
+
|
263 |
+
|
264 |
+
class ConvCombiner(nn.Module):
|
265 |
+
def __init__(self, ch):
|
266 |
+
super().__init__()
|
267 |
+
self.conv = nn.Conv2d(ch, ch, 1, padding=0)
|
268 |
+
|
269 |
+
nn.init.zeros_(self.conv.weight)
|
270 |
+
nn.init.zeros_(self.conv.bias)
|
271 |
+
|
272 |
+
def forward(self, x, context):
|
273 |
+
## x: b c h w, context: b c 2 h w
|
274 |
+
b, c, l, h, w = context.shape
|
275 |
+
bt, c, h, w = x.shape
|
276 |
+
context = rearrange(context, "b c l h w -> (b l) c h w")
|
277 |
+
context = self.conv(context)
|
278 |
+
context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
|
279 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=bt // b)
|
280 |
+
x[:, :, 0] = x[:, :, 0] + context[:, :, 0]
|
281 |
+
x[:, :, -1] = x[:, :, -1] + context[:, :, -1]
|
282 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
283 |
+
return x
|
284 |
+
|
285 |
+
|
286 |
+
class AttentionCombiner(nn.Module):
|
287 |
+
def __init__(
|
288 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
289 |
+
):
|
290 |
+
super().__init__()
|
291 |
+
|
292 |
+
inner_dim = dim_head * heads
|
293 |
+
context_dim = default(context_dim, query_dim)
|
294 |
+
|
295 |
+
self.heads = heads
|
296 |
+
self.dim_head = dim_head
|
297 |
+
|
298 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
299 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
300 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
301 |
+
|
302 |
+
self.to_out = nn.Sequential(
|
303 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
304 |
+
)
|
305 |
+
self.attention_op = None
|
306 |
+
|
307 |
+
self.norm = GroupNorm(query_dim)
|
308 |
+
nn.init.zeros_(self.to_out[0].weight)
|
309 |
+
nn.init.zeros_(self.to_out[0].bias)
|
310 |
+
|
311 |
+
def forward(
|
312 |
+
self,
|
313 |
+
x,
|
314 |
+
context=None,
|
315 |
+
mask=None,
|
316 |
+
):
|
317 |
+
bt, c, h, w = x.shape
|
318 |
+
h_ = self.norm(x)
|
319 |
+
h_ = rearrange(h_, "b c h w -> b (h w) c")
|
320 |
+
q = self.to_q(h_)
|
321 |
+
|
322 |
+
b, c, l, h, w = context.shape
|
323 |
+
context = rearrange(context, "b c l h w -> (b l) (h w) c")
|
324 |
+
k = self.to_k(context)
|
325 |
+
v = self.to_v(context)
|
326 |
+
|
327 |
+
t = bt // b
|
328 |
+
k = repeat(k, "(b l) d c -> (b t) (l d) c", l=l, t=t)
|
329 |
+
v = repeat(v, "(b l) d c -> (b t) (l d) c", l=l, t=t)
|
330 |
+
|
331 |
+
b, _, _ = q.shape
|
332 |
+
q, k, v = map(
|
333 |
+
lambda t: t.unsqueeze(3)
|
334 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
335 |
+
.permute(0, 2, 1, 3)
|
336 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
337 |
+
.contiguous(),
|
338 |
+
(q, k, v),
|
339 |
+
)
|
340 |
+
|
341 |
+
out = chunked_attention(
|
342 |
+
q, k, v, batch_chunk=1
|
343 |
+
)
|
344 |
+
|
345 |
+
if exists(mask):
|
346 |
+
raise NotImplementedError
|
347 |
+
|
348 |
+
out = (
|
349 |
+
out.unsqueeze(0)
|
350 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
351 |
+
.permute(0, 2, 1, 3)
|
352 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
353 |
+
)
|
354 |
+
out = self.to_out(out)
|
355 |
+
out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
|
356 |
+
return x + out
|
357 |
+
|
358 |
+
|
359 |
+
class Attention(nn.Module):
|
360 |
+
def __init__(self, in_channels):
|
361 |
+
super().__init__()
|
362 |
+
self.in_channels = in_channels
|
363 |
+
|
364 |
+
self.norm = GroupNorm(in_channels)
|
365 |
+
self.q = torch.nn.Conv2d(
|
366 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
367 |
+
)
|
368 |
+
self.k = torch.nn.Conv2d(
|
369 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
370 |
+
)
|
371 |
+
self.v = torch.nn.Conv2d(
|
372 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
373 |
+
)
|
374 |
+
self.proj_out = torch.nn.Conv2d(
|
375 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
376 |
+
)
|
377 |
+
|
378 |
+
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
379 |
+
h_ = self.norm(h_)
|
380 |
+
q = self.q(h_)
|
381 |
+
k = self.k(h_)
|
382 |
+
v = self.v(h_)
|
383 |
+
|
384 |
+
# compute attention
|
385 |
+
B, C, H, W = q.shape
|
386 |
+
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
387 |
+
|
388 |
+
q, k, v = map(
|
389 |
+
lambda t: t.unsqueeze(3)
|
390 |
+
.reshape(B, t.shape[1], 1, C)
|
391 |
+
.permute(0, 2, 1, 3)
|
392 |
+
.reshape(B * 1, t.shape[1], C)
|
393 |
+
.contiguous(),
|
394 |
+
(q, k, v),
|
395 |
+
)
|
396 |
+
|
397 |
+
out = chunked_attention(
|
398 |
+
q, k, v, batch_chunk=1
|
399 |
+
)
|
400 |
+
|
401 |
+
out = (
|
402 |
+
out.unsqueeze(0)
|
403 |
+
.reshape(B, 1, out.shape[1], C)
|
404 |
+
.permute(0, 2, 1, 3)
|
405 |
+
.reshape(B, out.shape[1], C)
|
406 |
+
)
|
407 |
+
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
408 |
+
|
409 |
+
def forward(self, x, **kwargs):
|
410 |
+
h_ = x
|
411 |
+
h_ = self.attention(h_)
|
412 |
+
h_ = self.proj_out(h_)
|
413 |
+
return x + h_
|
414 |
+
|
415 |
+
|
416 |
+
class VideoDecoder(nn.Module):
|
417 |
+
def __init__(
|
418 |
+
self,
|
419 |
+
*,
|
420 |
+
ch,
|
421 |
+
out_ch,
|
422 |
+
ch_mult=(1, 2, 4, 8),
|
423 |
+
num_res_blocks,
|
424 |
+
attn_resolutions,
|
425 |
+
dropout=0.0,
|
426 |
+
resamp_with_conv=True,
|
427 |
+
in_channels,
|
428 |
+
resolution,
|
429 |
+
z_channels,
|
430 |
+
give_pre_end=False,
|
431 |
+
tanh_out=False,
|
432 |
+
use_linear_attn=False,
|
433 |
+
attn_level=[2, 3],
|
434 |
+
video_kernel_size=[3, 1, 1],
|
435 |
+
alpha: float = 0.0,
|
436 |
+
merge_strategy: str = "learned",
|
437 |
+
**kwargs,
|
438 |
+
):
|
439 |
+
super().__init__()
|
440 |
+
self.video_kernel_size = video_kernel_size
|
441 |
+
self.alpha = alpha
|
442 |
+
self.merge_strategy = merge_strategy
|
443 |
+
self.ch = ch
|
444 |
+
self.temb_ch = 0
|
445 |
+
self.num_resolutions = len(ch_mult)
|
446 |
+
self.num_res_blocks = num_res_blocks
|
447 |
+
self.resolution = resolution
|
448 |
+
self.in_channels = in_channels
|
449 |
+
self.give_pre_end = give_pre_end
|
450 |
+
self.tanh_out = tanh_out
|
451 |
+
self.attn_level = attn_level
|
452 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
453 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
454 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
455 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
456 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
457 |
+
|
458 |
+
# z to block_in
|
459 |
+
self.conv_in = torch.nn.Conv2d(
|
460 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
461 |
+
)
|
462 |
+
|
463 |
+
# middle
|
464 |
+
self.mid = nn.Module()
|
465 |
+
self.mid.block_1 = VideoResBlock(
|
466 |
+
in_channels=block_in,
|
467 |
+
out_channels=block_in,
|
468 |
+
temb_channels=self.temb_ch,
|
469 |
+
dropout=dropout,
|
470 |
+
video_kernel_size=self.video_kernel_size,
|
471 |
+
alpha=self.alpha,
|
472 |
+
merge_strategy=self.merge_strategy,
|
473 |
+
)
|
474 |
+
self.mid.attn_1 = Attention(block_in)
|
475 |
+
self.mid.block_2 = VideoResBlock(
|
476 |
+
in_channels=block_in,
|
477 |
+
out_channels=block_in,
|
478 |
+
temb_channels=self.temb_ch,
|
479 |
+
dropout=dropout,
|
480 |
+
video_kernel_size=self.video_kernel_size,
|
481 |
+
alpha=self.alpha,
|
482 |
+
merge_strategy=self.merge_strategy,
|
483 |
+
)
|
484 |
+
|
485 |
+
# upsampling
|
486 |
+
self.up = nn.ModuleList()
|
487 |
+
self.attn_refinement = nn.ModuleList()
|
488 |
+
for i_level in reversed(range(self.num_resolutions)):
|
489 |
+
block = nn.ModuleList()
|
490 |
+
attn = nn.ModuleList()
|
491 |
+
block_out = ch * ch_mult[i_level]
|
492 |
+
for i_block in range(self.num_res_blocks + 1):
|
493 |
+
block.append(
|
494 |
+
VideoResBlock(
|
495 |
+
in_channels=block_in,
|
496 |
+
out_channels=block_out,
|
497 |
+
temb_channels=self.temb_ch,
|
498 |
+
dropout=dropout,
|
499 |
+
video_kernel_size=self.video_kernel_size,
|
500 |
+
alpha=self.alpha,
|
501 |
+
merge_strategy=self.merge_strategy,
|
502 |
+
)
|
503 |
+
)
|
504 |
+
block_in = block_out
|
505 |
+
if curr_res in attn_resolutions:
|
506 |
+
attn.append(Attention(block_in))
|
507 |
+
up = nn.Module()
|
508 |
+
up.block = block
|
509 |
+
up.attn = attn
|
510 |
+
if i_level != 0:
|
511 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
512 |
+
curr_res = curr_res * 2
|
513 |
+
self.up.insert(0, up) # prepend to get consistent order
|
514 |
+
|
515 |
+
if i_level in self.attn_level:
|
516 |
+
self.attn_refinement.insert(0, AttentionCombiner(block_in))
|
517 |
+
else:
|
518 |
+
self.attn_refinement.insert(0, ConvCombiner(block_in))
|
519 |
+
# end
|
520 |
+
self.norm_out = GroupNorm(block_in)
|
521 |
+
self.attn_refinement.append(ConvCombiner(block_in))
|
522 |
+
self.conv_out = DecoderConv3D(
|
523 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1, video_kernel_size=self.video_kernel_size
|
524 |
+
)
|
525 |
+
|
526 |
+
def forward(self, z, ref_context=None, **kwargs):
|
527 |
+
## ref_context: b c 2 h w, 2 means starting and ending frame
|
528 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
529 |
+
self.last_z_shape = z.shape
|
530 |
+
# timestep embedding
|
531 |
+
temb = None
|
532 |
+
|
533 |
+
# z to block_in
|
534 |
+
h = self.conv_in(z)
|
535 |
+
|
536 |
+
# middle
|
537 |
+
h = self.mid.block_1(h, temb, **kwargs)
|
538 |
+
h = self.mid.attn_1(h, **kwargs)
|
539 |
+
h = self.mid.block_2(h, temb, **kwargs)
|
540 |
+
|
541 |
+
# upsampling
|
542 |
+
for i_level in reversed(range(self.num_resolutions)):
|
543 |
+
for i_block in range(self.num_res_blocks + 1):
|
544 |
+
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
545 |
+
if len(self.up[i_level].attn) > 0:
|
546 |
+
h = self.up[i_level].attn[i_block](h, **kwargs)
|
547 |
+
if ref_context:
|
548 |
+
h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
|
549 |
+
if i_level != 0:
|
550 |
+
h = self.up[i_level].upsample(h)
|
551 |
+
|
552 |
+
# end
|
553 |
+
if self.give_pre_end:
|
554 |
+
return h
|
555 |
+
|
556 |
+
h = self.norm_out(h)
|
557 |
+
h = nonlinearity(h)
|
558 |
+
if ref_context:
|
559 |
+
# print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
|
560 |
+
h = self.attn_refinement[-1](x=h, context=ref_context[-1])
|
561 |
+
h = self.conv_out(h, **kwargs)
|
562 |
+
if self.tanh_out:
|
563 |
+
h = torch.tanh(h)
|
564 |
+
return h
|
565 |
+
|
566 |
+
|
567 |
+
class TimeStackBlock(torch.nn.Module):
|
568 |
+
def __init__(
|
569 |
+
self,
|
570 |
+
channels: int,
|
571 |
+
emb_channels: int,
|
572 |
+
dropout: float,
|
573 |
+
out_channels: int = None,
|
574 |
+
use_conv: bool = False,
|
575 |
+
use_scale_shift_norm: bool = False,
|
576 |
+
dims: int = 2,
|
577 |
+
use_checkpoint: bool = False,
|
578 |
+
up: bool = False,
|
579 |
+
down: bool = False,
|
580 |
+
kernel_size: int = 3,
|
581 |
+
exchange_temb_dims: bool = False,
|
582 |
+
skip_t_emb: bool = False,
|
583 |
+
):
|
584 |
+
super().__init__()
|
585 |
+
self.channels = channels
|
586 |
+
self.emb_channels = emb_channels
|
587 |
+
self.dropout = dropout
|
588 |
+
self.out_channels = out_channels or channels
|
589 |
+
self.use_conv = use_conv
|
590 |
+
self.use_checkpoint = use_checkpoint
|
591 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
592 |
+
self.exchange_temb_dims = exchange_temb_dims
|
593 |
+
|
594 |
+
if isinstance(kernel_size, list):
|
595 |
+
padding = [k // 2 for k in kernel_size]
|
596 |
+
else:
|
597 |
+
padding = kernel_size // 2
|
598 |
+
|
599 |
+
self.in_layers = nn.Sequential(
|
600 |
+
normalization(channels),
|
601 |
+
nn.SiLU(),
|
602 |
+
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
603 |
+
)
|
604 |
+
|
605 |
+
self.updown = up or down
|
606 |
+
|
607 |
+
if up:
|
608 |
+
self.h_upd = Upsample(channels, False, dims)
|
609 |
+
self.x_upd = Upsample(channels, False, dims)
|
610 |
+
elif down:
|
611 |
+
self.h_upd = Downsample(channels, False, dims)
|
612 |
+
self.x_upd = Downsample(channels, False, dims)
|
613 |
+
else:
|
614 |
+
self.h_upd = self.x_upd = nn.Identity()
|
615 |
+
|
616 |
+
self.skip_t_emb = skip_t_emb
|
617 |
+
self.emb_out_channels = (
|
618 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels
|
619 |
+
)
|
620 |
+
if self.skip_t_emb:
|
621 |
+
# print(f"Skipping timestep embedding in {self.__class__.__name__}")
|
622 |
+
assert not self.use_scale_shift_norm
|
623 |
+
self.emb_layers = None
|
624 |
+
self.exchange_temb_dims = False
|
625 |
+
else:
|
626 |
+
self.emb_layers = nn.Sequential(
|
627 |
+
nn.SiLU(),
|
628 |
+
linear(
|
629 |
+
emb_channels,
|
630 |
+
self.emb_out_channels,
|
631 |
+
),
|
632 |
+
)
|
633 |
+
|
634 |
+
self.out_layers = nn.Sequential(
|
635 |
+
normalization(self.out_channels),
|
636 |
+
nn.SiLU(),
|
637 |
+
nn.Dropout(p=dropout),
|
638 |
+
zero_module(
|
639 |
+
conv_nd(
|
640 |
+
dims,
|
641 |
+
self.out_channels,
|
642 |
+
self.out_channels,
|
643 |
+
kernel_size,
|
644 |
+
padding=padding,
|
645 |
+
)
|
646 |
+
),
|
647 |
+
)
|
648 |
+
|
649 |
+
if self.out_channels == channels:
|
650 |
+
self.skip_connection = nn.Identity()
|
651 |
+
elif use_conv:
|
652 |
+
self.skip_connection = conv_nd(
|
653 |
+
dims, channels, self.out_channels, kernel_size, padding=padding
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
657 |
+
|
658 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
659 |
+
if self.updown:
|
660 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
661 |
+
h = in_rest(x)
|
662 |
+
h = self.h_upd(h)
|
663 |
+
x = self.x_upd(x)
|
664 |
+
h = in_conv(h)
|
665 |
+
else:
|
666 |
+
h = self.in_layers(x)
|
667 |
+
|
668 |
+
if self.skip_t_emb:
|
669 |
+
emb_out = torch.zeros_like(h)
|
670 |
+
else:
|
671 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
672 |
+
while len(emb_out.shape) < len(h.shape):
|
673 |
+
emb_out = emb_out[..., None]
|
674 |
+
if self.use_scale_shift_norm:
|
675 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
676 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
677 |
+
h = out_norm(h) * (1 + scale) + shift
|
678 |
+
h = out_rest(h)
|
679 |
+
else:
|
680 |
+
if self.exchange_temb_dims:
|
681 |
+
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
682 |
+
h = h + emb_out
|
683 |
+
h = self.out_layers(h)
|
684 |
+
return self.skip_connection(x) + h
|
685 |
+
|
686 |
+
|
687 |
+
class VideoResBlock(ResnetBlock):
|
688 |
+
def __init__(
|
689 |
+
self,
|
690 |
+
out_channels,
|
691 |
+
*args,
|
692 |
+
dropout=0.0,
|
693 |
+
video_kernel_size=3,
|
694 |
+
alpha=0.0,
|
695 |
+
merge_strategy="learned",
|
696 |
+
**kwargs,
|
697 |
+
):
|
698 |
+
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
699 |
+
if video_kernel_size is None:
|
700 |
+
video_kernel_size = [3, 1, 1]
|
701 |
+
self.time_stack = TimeStackBlock(
|
702 |
+
channels=out_channels,
|
703 |
+
emb_channels=0,
|
704 |
+
dropout=dropout,
|
705 |
+
dims=3,
|
706 |
+
use_scale_shift_norm=False,
|
707 |
+
use_conv=False,
|
708 |
+
up=False,
|
709 |
+
down=False,
|
710 |
+
kernel_size=video_kernel_size,
|
711 |
+
use_checkpoint=True,
|
712 |
+
skip_t_emb=True,
|
713 |
+
)
|
714 |
+
|
715 |
+
self.merge_strategy = merge_strategy
|
716 |
+
if self.merge_strategy == "fixed":
|
717 |
+
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
718 |
+
elif self.merge_strategy == "learned":
|
719 |
+
self.register_parameter(
|
720 |
+
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
724 |
+
|
725 |
+
def get_alpha(self, bs):
|
726 |
+
if self.merge_strategy == "fixed":
|
727 |
+
return self.mix_factor
|
728 |
+
elif self.merge_strategy == "learned":
|
729 |
+
return torch.sigmoid(self.mix_factor)
|
730 |
+
else:
|
731 |
+
raise NotImplementedError()
|
732 |
+
|
733 |
+
def forward(self, x, temb, skip_video=False, timesteps=None):
|
734 |
+
assert isinstance(timesteps, int)
|
735 |
+
|
736 |
+
b, c, h, w = x.shape
|
737 |
+
|
738 |
+
x = super().forward(x, temb)
|
739 |
+
|
740 |
+
if not skip_video:
|
741 |
+
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
742 |
+
|
743 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
744 |
+
|
745 |
+
x = self.time_stack(x, temb)
|
746 |
+
|
747 |
+
alpha = self.get_alpha(bs=b // timesteps)
|
748 |
+
x = alpha * x + (1.0 - alpha) * x_mix
|
749 |
+
|
750 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
751 |
+
return x
|
752 |
+
|
753 |
+
|
754 |
+
class DecoderConv3D(torch.nn.Conv2d):
|
755 |
+
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
756 |
+
super().__init__(in_channels, out_channels, *args, **kwargs)
|
757 |
+
if isinstance(video_kernel_size, list):
|
758 |
+
padding = [int(k // 2) for k in video_kernel_size]
|
759 |
+
else:
|
760 |
+
padding = int(video_kernel_size // 2)
|
761 |
+
|
762 |
+
self.time_mix_conv = torch.nn.Conv3d(
|
763 |
+
in_channels=out_channels,
|
764 |
+
out_channels=out_channels,
|
765 |
+
kernel_size=video_kernel_size,
|
766 |
+
padding=padding,
|
767 |
+
)
|
768 |
+
|
769 |
+
def forward(self, input, timesteps, skip_video=False):
|
770 |
+
x = super().forward(input)
|
771 |
+
if skip_video:
|
772 |
+
return x
|
773 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
774 |
+
x = self.time_mix_conv(x)
|
775 |
+
return rearrange(x, "b c t h w -> (b t) c h w")
|
776 |
+
|
777 |
+
|
778 |
+
class VideoAutoencoderKL(torch.nn.Module, PyTorchModelHubMixin):
|
779 |
+
def __init__(self,
|
780 |
+
double_z=True,
|
781 |
+
z_channels=4,
|
782 |
+
resolution=256,
|
783 |
+
in_channels=3,
|
784 |
+
out_ch=3,
|
785 |
+
ch=128,
|
786 |
+
ch_mult=[],
|
787 |
+
num_res_blocks=2,
|
788 |
+
attn_resolutions=[],
|
789 |
+
dropout=0.0,
|
790 |
+
):
|
791 |
+
super().__init__()
|
792 |
+
self.encoder = Encoder(double_z=double_z, z_channels=z_channels, resolution=resolution, in_channels=in_channels,
|
793 |
+
out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
|
794 |
+
attn_resolutions=attn_resolutions, dropout=dropout)
|
795 |
+
self.decoder = VideoDecoder(double_z=double_z, z_channels=z_channels, resolution=resolution,
|
796 |
+
in_channels=in_channels, out_ch=out_ch, ch=ch, ch_mult=ch_mult,
|
797 |
+
num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout)
|
798 |
+
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
|
799 |
+
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
800 |
+
self.scale_factor = 0.18215
|
801 |
+
|
802 |
+
def encode(self, x, return_hidden_states=False, **kwargs):
|
803 |
+
if return_hidden_states:
|
804 |
+
h, hidden = self.encoder(x, return_hidden_states)
|
805 |
+
moments = self.quant_conv(h)
|
806 |
+
posterior = DiagonalGaussianDistribution(moments)
|
807 |
+
return posterior, hidden
|
808 |
+
else:
|
809 |
+
h = self.encoder(x)
|
810 |
+
moments = self.quant_conv(h)
|
811 |
+
posterior = DiagonalGaussianDistribution(moments)
|
812 |
+
return posterior, None
|
813 |
+
|
814 |
+
def decode(self, z, **kwargs):
|
815 |
+
if len(kwargs) == 0:
|
816 |
+
z = self.post_quant_conv(z)
|
817 |
+
dec = self.decoder(z, **kwargs)
|
818 |
+
return dec
|
819 |
+
|
820 |
+
@property
|
821 |
+
def device(self):
|
822 |
+
return next(self.parameters()).device
|
823 |
+
|
824 |
+
@property
|
825 |
+
def dtype(self):
|
826 |
+
return next(self.parameters()).dtype
|
gradio_app.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
|
4 |
+
result_dir = os.path.join('./', 'results')
|
5 |
+
os.makedirs(result_dir, exist_ok=True)
|
6 |
+
|
7 |
+
|
8 |
+
import functools
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import gradio as gr
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import wd14tagger
|
15 |
+
import memory_management
|
16 |
+
import uuid
|
17 |
+
|
18 |
+
from PIL import Image
|
19 |
+
from diffusers_helper.code_cond import unet_add_coded_conds
|
20 |
+
from diffusers_helper.cat_cond import unet_add_concat_conds
|
21 |
+
from diffusers_helper.k_diffusion import KDiffusionSampler
|
22 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel
|
23 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
24 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
25 |
+
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
|
26 |
+
from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
|
27 |
+
|
28 |
+
|
29 |
+
class ModifiedUNet(UNet2DConditionModel):
|
30 |
+
@classmethod
|
31 |
+
def from_config(cls, *args, **kwargs):
|
32 |
+
m = super().from_config(*args, **kwargs)
|
33 |
+
unet_add_concat_conds(unet=m, new_channels=4)
|
34 |
+
unet_add_coded_conds(unet=m, added_number_count=1)
|
35 |
+
return m
|
36 |
+
|
37 |
+
|
38 |
+
model_name = 'lllyasviel/paints_undo_single_frame'
|
39 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
40 |
+
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
|
41 |
+
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae
|
42 |
+
unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
|
43 |
+
|
44 |
+
unet.set_attn_processor(AttnProcessor2_0())
|
45 |
+
vae.set_attn_processor(AttnProcessor2_0())
|
46 |
+
|
47 |
+
video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
|
48 |
+
'lllyasviel/paints_undo_multi_frame',
|
49 |
+
fp16=True
|
50 |
+
)
|
51 |
+
|
52 |
+
memory_management.unload_all_models([
|
53 |
+
video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
|
54 |
+
unet, vae, text_encoder
|
55 |
+
])
|
56 |
+
|
57 |
+
k_sampler = KDiffusionSampler(
|
58 |
+
unet=unet,
|
59 |
+
timesteps=1000,
|
60 |
+
linear_start=0.00085,
|
61 |
+
linear_end=0.020,
|
62 |
+
linear=True
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def find_best_bucket(h, w, options):
|
67 |
+
min_metric = float('inf')
|
68 |
+
best_bucket = None
|
69 |
+
for (bucket_h, bucket_w) in options:
|
70 |
+
metric = abs(h * bucket_w - w * bucket_h)
|
71 |
+
if metric <= min_metric:
|
72 |
+
min_metric = metric
|
73 |
+
best_bucket = (bucket_h, bucket_w)
|
74 |
+
return best_bucket
|
75 |
+
|
76 |
+
|
77 |
+
@torch.inference_mode()
|
78 |
+
def encode_cropped_prompt_77tokens(txt: str):
|
79 |
+
memory_management.load_models_to_gpu(text_encoder)
|
80 |
+
cond_ids = tokenizer(txt,
|
81 |
+
padding="max_length",
|
82 |
+
max_length=tokenizer.model_max_length,
|
83 |
+
truncation=True,
|
84 |
+
return_tensors="pt").input_ids.to(device=text_encoder.device)
|
85 |
+
text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
|
86 |
+
return text_cond
|
87 |
+
|
88 |
+
|
89 |
+
@torch.inference_mode()
|
90 |
+
def pytorch2numpy(imgs):
|
91 |
+
results = []
|
92 |
+
for x in imgs:
|
93 |
+
y = x.movedim(0, -1)
|
94 |
+
y = y * 127.5 + 127.5
|
95 |
+
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
|
96 |
+
results.append(y)
|
97 |
+
return results
|
98 |
+
|
99 |
+
|
100 |
+
@torch.inference_mode()
|
101 |
+
def numpy2pytorch(imgs):
|
102 |
+
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
|
103 |
+
h = h.movedim(-1, 1)
|
104 |
+
return h
|
105 |
+
|
106 |
+
|
107 |
+
def resize_without_crop(image, target_width, target_height):
|
108 |
+
pil_image = Image.fromarray(image)
|
109 |
+
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
110 |
+
return np.array(resized_image)
|
111 |
+
|
112 |
+
|
113 |
+
@torch.inference_mode()
|
114 |
+
def interrogator_process(x):
|
115 |
+
return wd14tagger.default_interrogator(x)
|
116 |
+
|
117 |
+
|
118 |
+
@torch.inference_mode()
|
119 |
+
def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
|
120 |
+
progress=gr.Progress()):
|
121 |
+
rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
|
122 |
+
|
123 |
+
memory_management.load_models_to_gpu(vae)
|
124 |
+
fg = resize_and_center_crop(input_fg, image_width, image_height)
|
125 |
+
concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
|
126 |
+
concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
|
127 |
+
|
128 |
+
memory_management.load_models_to_gpu(text_encoder)
|
129 |
+
conds = encode_cropped_prompt_77tokens(prompt)
|
130 |
+
unconds = encode_cropped_prompt_77tokens(n_prompt)
|
131 |
+
|
132 |
+
memory_management.load_models_to_gpu(unet)
|
133 |
+
fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
|
134 |
+
initial_latents = torch.zeros_like(concat_conds)
|
135 |
+
concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
|
136 |
+
latents = k_sampler(
|
137 |
+
initial_latent=initial_latents,
|
138 |
+
strength=1.0,
|
139 |
+
num_inference_steps=steps,
|
140 |
+
guidance_scale=cfg,
|
141 |
+
batch_size=len(input_undo_steps),
|
142 |
+
generator=rng,
|
143 |
+
prompt_embeds=conds,
|
144 |
+
negative_prompt_embeds=unconds,
|
145 |
+
cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
|
146 |
+
same_noise_in_batch=True,
|
147 |
+
progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
|
148 |
+
).to(vae.dtype) / vae.config.scaling_factor
|
149 |
+
|
150 |
+
memory_management.load_models_to_gpu(vae)
|
151 |
+
pixels = vae.decode(latents).sample
|
152 |
+
pixels = pytorch2numpy(pixels)
|
153 |
+
pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
|
154 |
+
|
155 |
+
return pixels
|
156 |
+
|
157 |
+
|
158 |
+
@torch.inference_mode()
|
159 |
+
def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
|
160 |
+
random.seed(seed)
|
161 |
+
np.random.seed(seed)
|
162 |
+
torch.manual_seed(seed)
|
163 |
+
torch.cuda.manual_seed_all(seed)
|
164 |
+
|
165 |
+
frames = 16
|
166 |
+
|
167 |
+
target_height, target_width = find_best_bucket(
|
168 |
+
image_1.shape[0], image_1.shape[1],
|
169 |
+
options=[(320, 512), (384, 448), (448, 384), (512, 320)]
|
170 |
+
)
|
171 |
+
|
172 |
+
image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
|
173 |
+
image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
|
174 |
+
input_frames = numpy2pytorch([image_1, image_2])
|
175 |
+
input_frames = input_frames.unsqueeze(0).movedim(1, 2)
|
176 |
+
|
177 |
+
memory_management.load_models_to_gpu(video_pipe.text_encoder)
|
178 |
+
positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
|
179 |
+
negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
|
180 |
+
|
181 |
+
memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
|
182 |
+
input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
|
183 |
+
positive_image_cond = video_pipe.encode_clip_vision(input_frames)
|
184 |
+
positive_image_cond = video_pipe.image_projection(positive_image_cond)
|
185 |
+
negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
|
186 |
+
negative_image_cond = video_pipe.image_projection(negative_image_cond)
|
187 |
+
|
188 |
+
memory_management.load_models_to_gpu([video_pipe.vae])
|
189 |
+
input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
|
190 |
+
input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
|
191 |
+
first_frame = input_frame_latents[:, :, 0]
|
192 |
+
last_frame = input_frame_latents[:, :, 1]
|
193 |
+
concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
|
194 |
+
|
195 |
+
memory_management.load_models_to_gpu([video_pipe.unet])
|
196 |
+
latents = video_pipe(
|
197 |
+
batch_size=1,
|
198 |
+
steps=int(steps),
|
199 |
+
guidance_scale=cfg_scale,
|
200 |
+
positive_text_cond=positive_text_cond,
|
201 |
+
negative_text_cond=negative_text_cond,
|
202 |
+
positive_image_cond=positive_image_cond,
|
203 |
+
negative_image_cond=negative_image_cond,
|
204 |
+
concat_cond=concat_cond,
|
205 |
+
fs=fs,
|
206 |
+
progress_tqdm=progress_tqdm
|
207 |
+
)
|
208 |
+
|
209 |
+
memory_management.load_models_to_gpu([video_pipe.vae])
|
210 |
+
video = video_pipe.decode_latents(latents, vae_hidden_states)
|
211 |
+
return video, image_1, image_2
|
212 |
+
|
213 |
+
|
214 |
+
@torch.inference_mode()
|
215 |
+
def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
|
216 |
+
result_frames = []
|
217 |
+
cropped_images = []
|
218 |
+
|
219 |
+
for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
|
220 |
+
im1 = np.array(Image.open(im1[0]))
|
221 |
+
im2 = np.array(Image.open(im2[0]))
|
222 |
+
frames, im1, im2 = process_video_inner(
|
223 |
+
im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
|
224 |
+
progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
|
225 |
+
)
|
226 |
+
result_frames.append(frames[:, :, :-1, :, :])
|
227 |
+
cropped_images.append([im1, im2])
|
228 |
+
|
229 |
+
video = torch.cat(result_frames, dim=2)
|
230 |
+
video = torch.flip(video, dims=[2])
|
231 |
+
|
232 |
+
uuid_name = str(uuid.uuid4())
|
233 |
+
output_filename = os.path.join(result_dir, uuid_name + '.mp4')
|
234 |
+
Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
|
235 |
+
video = save_bcthw_as_mp4(video, output_filename, fps=fps)
|
236 |
+
video = [x.cpu().numpy() for x in video]
|
237 |
+
return output_filename, video
|
238 |
+
|
239 |
+
|
240 |
+
block = gr.Blocks().queue()
|
241 |
+
with block:
|
242 |
+
gr.Markdown('# Paints-Undo')
|
243 |
+
|
244 |
+
with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True):
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column():
|
247 |
+
input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512)
|
248 |
+
with gr.Column():
|
249 |
+
prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False)
|
250 |
+
prompt = gr.Textbox(label="Output Prompt", interactive=True)
|
251 |
+
|
252 |
+
with gr.Accordion(label='Step 2: Generate Key Frames', open=True):
|
253 |
+
with gr.Row():
|
254 |
+
with gr.Column():
|
255 |
+
input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999],
|
256 |
+
choices=list(range(1000)), multiselect=True)
|
257 |
+
seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345)
|
258 |
+
image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
|
259 |
+
image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
|
260 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
|
261 |
+
cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01)
|
262 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
263 |
+
value='lowres, bad anatomy, bad hands, cropped, worst quality')
|
264 |
+
|
265 |
+
with gr.Column():
|
266 |
+
key_gen_button = gr.Button(value="Generate Key Frames", interactive=False)
|
267 |
+
result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4)
|
268 |
+
|
269 |
+
with gr.Accordion(label='Step 3: Generate All Videos', open=True):
|
270 |
+
with gr.Row():
|
271 |
+
with gr.Column():
|
272 |
+
i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality')
|
273 |
+
i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123)
|
274 |
+
i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5,
|
275 |
+
elem_id="i2v_cfg_scale")
|
276 |
+
i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps",
|
277 |
+
label="Sampling steps", value=50)
|
278 |
+
i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4)
|
279 |
+
with gr.Column():
|
280 |
+
i2v_end_btn = gr.Button("Generate Video", interactive=False)
|
281 |
+
i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
|
282 |
+
show_share_button=True, height=512)
|
283 |
+
with gr.Row():
|
284 |
+
i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8)
|
285 |
+
|
286 |
+
input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)],
|
287 |
+
outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn])
|
288 |
+
|
289 |
+
prompt_gen_button.click(
|
290 |
+
fn=interrogator_process,
|
291 |
+
inputs=[input_fg],
|
292 |
+
outputs=[prompt]
|
293 |
+
).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
|
294 |
+
outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
|
295 |
+
|
296 |
+
key_gen_button.click(
|
297 |
+
fn=process,
|
298 |
+
inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg],
|
299 |
+
outputs=[result_gallery]
|
300 |
+
).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
|
301 |
+
outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
|
302 |
+
|
303 |
+
i2v_end_btn.click(
|
304 |
+
inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed],
|
305 |
+
outputs=[i2v_output_video, i2v_output_images],
|
306 |
+
fn=process_video
|
307 |
+
)
|
308 |
+
|
309 |
+
dbs = [
|
310 |
+
['./imgs/1.jpg', 12345, 123],
|
311 |
+
['./imgs/2.jpg', 37000, 12345],
|
312 |
+
['./imgs/3.jpg', 3000, 3000],
|
313 |
+
]
|
314 |
+
|
315 |
+
gr.Examples(
|
316 |
+
examples=dbs,
|
317 |
+
inputs=[input_fg, seed, i2v_seed],
|
318 |
+
examples_per_page=1024
|
319 |
+
)
|
320 |
+
|
321 |
+
block.queue().launch(server_name='0.0.0.0')
|
imgs/1.jpg
ADDED
![]() |
imgs/2.jpg
ADDED
![]() |
imgs/3.jpg
ADDED
![]() |
memory_management.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from contextlib import contextmanager
|
3 |
+
|
4 |
+
|
5 |
+
high_vram = False
|
6 |
+
gpu = torch.device('cuda')
|
7 |
+
cpu = torch.device('cpu')
|
8 |
+
|
9 |
+
torch.zeros((1, 1)).to(gpu, torch.float32)
|
10 |
+
torch.cuda.empty_cache()
|
11 |
+
|
12 |
+
models_in_gpu = []
|
13 |
+
|
14 |
+
|
15 |
+
@contextmanager
|
16 |
+
def movable_bnb_model(m):
|
17 |
+
if hasattr(m, 'quantization_method'):
|
18 |
+
m.quantization_method_backup = m.quantization_method
|
19 |
+
del m.quantization_method
|
20 |
+
try:
|
21 |
+
yield None
|
22 |
+
finally:
|
23 |
+
if hasattr(m, 'quantization_method_backup'):
|
24 |
+
m.quantization_method = m.quantization_method_backup
|
25 |
+
del m.quantization_method_backup
|
26 |
+
return
|
27 |
+
|
28 |
+
|
29 |
+
def load_models_to_gpu(models):
|
30 |
+
global models_in_gpu
|
31 |
+
|
32 |
+
if not isinstance(models, (tuple, list)):
|
33 |
+
models = [models]
|
34 |
+
|
35 |
+
models_to_remain = [m for m in set(models) if m in models_in_gpu]
|
36 |
+
models_to_load = [m for m in set(models) if m not in models_in_gpu]
|
37 |
+
models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
|
38 |
+
|
39 |
+
if not high_vram:
|
40 |
+
for m in models_to_unload:
|
41 |
+
with movable_bnb_model(m):
|
42 |
+
m.to(cpu)
|
43 |
+
print('Unload to CPU:', m.__class__.__name__)
|
44 |
+
models_in_gpu = models_to_remain
|
45 |
+
|
46 |
+
for m in models_to_load:
|
47 |
+
with movable_bnb_model(m):
|
48 |
+
m.to(gpu)
|
49 |
+
print('Load to GPU:', m.__class__.__name__)
|
50 |
+
|
51 |
+
models_in_gpu = list(set(models_in_gpu + models))
|
52 |
+
torch.cuda.empty_cache()
|
53 |
+
return
|
54 |
+
|
55 |
+
|
56 |
+
def unload_all_models(extra_models=None):
|
57 |
+
global models_in_gpu
|
58 |
+
|
59 |
+
if extra_models is None:
|
60 |
+
extra_models = []
|
61 |
+
|
62 |
+
if not isinstance(extra_models, (tuple, list)):
|
63 |
+
extra_models = [extra_models]
|
64 |
+
|
65 |
+
models_in_gpu = list(set(models_in_gpu + extra_models))
|
66 |
+
|
67 |
+
return load_models_to_gpu([])
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.28.0
|
2 |
+
transformers==4.41.1
|
3 |
+
gradio==4.31.5
|
4 |
+
bitsandbytes==0.43.1
|
5 |
+
accelerate==0.30.1
|
6 |
+
protobuf==3.20
|
7 |
+
opencv-python
|
8 |
+
tensorboardX
|
9 |
+
safetensors
|
10 |
+
pillow
|
11 |
+
einops
|
12 |
+
torch
|
13 |
+
peft
|
14 |
+
xformers
|
15 |
+
onnxruntime
|
16 |
+
av
|
wd14tagger.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
import csv
|
6 |
+
import numpy as np
|
7 |
+
import onnxruntime as ort
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from onnxruntime import InferenceSession
|
11 |
+
from torch.hub import download_url_to_file
|
12 |
+
|
13 |
+
|
14 |
+
global_model = None
|
15 |
+
global_csv = None
|
16 |
+
|
17 |
+
|
18 |
+
def download_model(url, local_path):
|
19 |
+
if os.path.exists(local_path):
|
20 |
+
return local_path
|
21 |
+
|
22 |
+
temp_path = local_path + '.tmp'
|
23 |
+
download_url_to_file(url=url, dst=temp_path)
|
24 |
+
os.rename(temp_path, local_path)
|
25 |
+
return local_path
|
26 |
+
|
27 |
+
|
28 |
+
def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""):
|
29 |
+
global global_model, global_csv
|
30 |
+
|
31 |
+
model_name = "wd-v1-4-moat-tagger-v2"
|
32 |
+
|
33 |
+
model_onnx_filename = download_model(
|
34 |
+
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
|
35 |
+
local_path=f'./{model_name}.onnx',
|
36 |
+
)
|
37 |
+
|
38 |
+
model_csv_filename = download_model(
|
39 |
+
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
|
40 |
+
local_path=f'./{model_name}.csv',
|
41 |
+
)
|
42 |
+
|
43 |
+
if global_model is not None:
|
44 |
+
model = global_model
|
45 |
+
else:
|
46 |
+
# assert 'CUDAExecutionProvider' in ort.get_available_providers(), 'CUDA Install Failed!'
|
47 |
+
# model = InferenceSession(model_onnx_filename, providers=['CUDAExecutionProvider'])
|
48 |
+
model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider'])
|
49 |
+
global_model = model
|
50 |
+
|
51 |
+
input = model.get_inputs()[0]
|
52 |
+
height = input.shape[1]
|
53 |
+
|
54 |
+
if isinstance(image, str):
|
55 |
+
image = Image.open(image) # RGB
|
56 |
+
elif isinstance(image, np.ndarray):
|
57 |
+
image = Image.fromarray(image)
|
58 |
+
else:
|
59 |
+
image = image
|
60 |
+
|
61 |
+
ratio = float(height) / max(image.size)
|
62 |
+
new_size = tuple([int(x*ratio) for x in image.size])
|
63 |
+
image = image.resize(new_size, Image.LANCZOS)
|
64 |
+
square = Image.new("RGB", (height, height), (255, 255, 255))
|
65 |
+
square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
|
66 |
+
|
67 |
+
image = np.array(square).astype(np.float32)
|
68 |
+
image = image[:, :, ::-1] # RGB -> BGR
|
69 |
+
image = np.expand_dims(image, 0)
|
70 |
+
|
71 |
+
if global_csv is not None:
|
72 |
+
csv_lines = global_csv
|
73 |
+
else:
|
74 |
+
csv_lines = []
|
75 |
+
with open(model_csv_filename) as f:
|
76 |
+
reader = csv.reader(f)
|
77 |
+
next(reader)
|
78 |
+
for row in reader:
|
79 |
+
csv_lines.append(row)
|
80 |
+
global_csv = csv_lines
|
81 |
+
|
82 |
+
tags = []
|
83 |
+
general_index = None
|
84 |
+
character_index = None
|
85 |
+
for line_num, row in enumerate(csv_lines):
|
86 |
+
if general_index is None and row[2] == "0":
|
87 |
+
general_index = line_num
|
88 |
+
elif character_index is None and row[2] == "4":
|
89 |
+
character_index = line_num
|
90 |
+
tags.append(row[1])
|
91 |
+
|
92 |
+
label_name = model.get_outputs()[0].name
|
93 |
+
probs = model.run([label_name], {input.name: image})[0]
|
94 |
+
|
95 |
+
result = list(zip(tags, probs[0]))
|
96 |
+
|
97 |
+
general = [item for item in result[general_index:character_index] if item[1] > threshold]
|
98 |
+
character = [item for item in result[character_index:] if item[1] > character_threshold]
|
99 |
+
|
100 |
+
all = character + general
|
101 |
+
remove = [s.strip() for s in exclude_tags.lower().split(",")]
|
102 |
+
all = [tag for tag in all if tag[0] not in remove]
|
103 |
+
|
104 |
+
res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ')
|
105 |
+
return res
|