Spaces:
Running
on
Zero
Running
on
Zero
update requirements.
Browse files- .gitattributes +15 -0
- LICENSE +233 -0
- README.md +4 -4
- app.py +1066 -0
- model/__init__.py +0 -0
- model/dera.py +195 -0
- model/dit.py +1090 -0
- model/image_encoder.py +903 -0
- model/prompter.py +107 -0
- model/text_encoder.py +269 -0
- model/vae.py +809 -0
- pipeline/__init__.py +0 -0
- pipeline/i2v_pipeline.py +511 -0
- requirements.txt +148 -0
- samples/1_image1.png +3 -0
- samples/1_out.mp4 +3 -0
- samples/1_prompt.txt +1 -0
- samples/1_sketch1.jpg +3 -0
- samples/1_sketch2.jpg +3 -0
- samples/1_sketch3.jpg +3 -0
- samples/2_image1.jpg +3 -0
- samples/2_out.mp4 +3 -0
- samples/2_prompt.txt +1 -0
- samples/2_sketch1.jpg +3 -0
- samples/2_sketch2.jpg +3 -0
- samples/3_image1.png +3 -0
- samples/3_out.mp4 +3 -0
- samples/3_prompt.txt +1 -0
- samples/3_sketch1.jpg +3 -0
- samples/ToonComposer-Icon.png +3 -0
- samples/ToonComposer-Method.jpg +3 -0
- samples/ToonComposer-TLDR.jpg +3 -0
- scheduler/__init__.py +0 -0
- scheduler/flow_match.py +78 -0
- tooncomposer.py +234 -0
- util/model_util.py +241 -0
- util/optical_flow.py +140 -0
- util/stylesheets.py +0 -0
- util/training_util.py +317 -0
.gitattributes
CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
samples/1_out.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
samples/2_out.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
samples/3_out.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
samples/1_image1.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
samples/3_image1.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
samples/ToonComposer-Icon.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
samples/1_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
samples/1_sketch3.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
samples/2_image1.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
samples/1_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
samples/2_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
|
47 |
+
samples/2_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
|
48 |
+
samples/3_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
|
49 |
+
samples/ToonComposer-Method.jpg filter=lfs diff=lfs merge=lfs -text
|
50 |
+
samples/ToonComposer-TLDR.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Tencent is pleased to support the open source community by making ToonComposer available.
|
2 |
+
|
3 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
4 |
+
|
5 |
+
ToonComposer is licensed under the MIT License except for the third-party components listed below, which is licensed under different terms. ToonComposer does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
6 |
+
|
7 |
+
For avoidance of doubts, ToonComposer refers to the inference code, parameters and weights made publicly available by Tencent in accordance with the MIT License in this repository.
|
8 |
+
|
9 |
+
Terms of the MIT License:
|
10 |
+
--------------------------------------------------------------------
|
11 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
12 |
+
|
13 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the " Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
The above copyright notice and this permission notice (including the next paragraph) shall be included in all copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
18 |
+
|
19 |
+
|
20 |
+
The ToonComposer model was developed by Tencent based on the following Open Models.
|
21 |
+
The ToonComposer inference code was developed by Tencent based on the code of the following Open Models.The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
22 |
+
|
23 |
+
Open Models Licensed under the Apache-2.0 License:
|
24 |
+
|
25 |
+
--------------------------------------------------------------------
|
26 |
+
1.Wan2.1
|
27 |
+
Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
28 |
+
The code of this model was modified by Tencent.
|
29 |
+
|
30 |
+
--------------------------------------------------------------------
|
31 |
+
Terms of the Apache-2.0 License:
|
32 |
+
--------------------------------------------------------------------
|
33 |
+
Apache License
|
34 |
+
Version 2.0, January 2004
|
35 |
+
http://www.apache.org/licenses/
|
36 |
+
|
37 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
38 |
+
|
39 |
+
1. Definitions.
|
40 |
+
|
41 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
42 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
43 |
+
|
44 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
45 |
+
the copyright owner that is granting the License.
|
46 |
+
|
47 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
48 |
+
other entities that control, are controlled by, or are under common
|
49 |
+
control with that entity. For the purposes of this definition,
|
50 |
+
"control" means (i) the power, direct or indirect, to cause the
|
51 |
+
direction or management of such entity, whether by contract or
|
52 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
53 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
54 |
+
|
55 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
56 |
+
exercising permissions granted by this License.
|
57 |
+
|
58 |
+
"Source" form shall mean the preferred form for making modifications,
|
59 |
+
including but not limited to software source code, documentation
|
60 |
+
source, and configuration files.
|
61 |
+
|
62 |
+
"Object" form shall mean any form resulting from mechanical
|
63 |
+
transformation or translation of a Source form, including but
|
64 |
+
not limited to compiled object code, generated documentation,
|
65 |
+
and conversions to other media types.
|
66 |
+
|
67 |
+
"Work" shall mean the work of authorship, whether in Source or
|
68 |
+
Object form, made available under the License, as indicated by a
|
69 |
+
copyright notice that is included in or attached to the work
|
70 |
+
(an example is provided in the Appendix below).
|
71 |
+
|
72 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
73 |
+
form, that is based on (or derived from) the Work and for which the
|
74 |
+
editorial revisions, annotations, elaborations, or other modifications
|
75 |
+
represent, as a whole, an original work of authorship. For the purposes
|
76 |
+
of this License, Derivative Works shall not include works that remain
|
77 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
78 |
+
the Work and Derivative Works thereof.
|
79 |
+
|
80 |
+
"Contribution" shall mean any work of authorship, including
|
81 |
+
the original version of the Work and any modifications or additions
|
82 |
+
to that Work or Derivative Works thereof, that is intentionally
|
83 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
84 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
85 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
86 |
+
means any form of electronic, verbal, or written communication sent
|
87 |
+
to the Licensor or its representatives, including but not limited to
|
88 |
+
communication on electronic mailing lists, source code control systems,
|
89 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
90 |
+
Licensor for the purpose of discussing and improving the Work, but
|
91 |
+
excluding communication that is conspicuously marked or otherwise
|
92 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
93 |
+
|
94 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
95 |
+
on behalf of whom a Contribution has been received by Licensor and
|
96 |
+
subsequently incorporated within the Work.
|
97 |
+
|
98 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
99 |
+
this License, each Contributor hereby grants to You a perpetual,
|
100 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
101 |
+
copyright license to reproduce, prepare Derivative Works of,
|
102 |
+
publicly display, publicly perform, sublicense, and distribute the
|
103 |
+
Work and such Derivative Works in Source or Object form.
|
104 |
+
|
105 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
106 |
+
this License, each Contributor hereby grants to You a perpetual,
|
107 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
108 |
+
(except as stated in this section) patent license to make, have made,
|
109 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
110 |
+
where such license applies only to those patent claims licensable
|
111 |
+
by such Contributor that are necessarily infringed by their
|
112 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
113 |
+
with the Work to which such Contribution(s) was submitted. If You
|
114 |
+
institute patent litigation against any entity (including a
|
115 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
116 |
+
or a Contribution incorporated within the Work constitutes direct
|
117 |
+
or contributory patent infringement, then any patent licenses
|
118 |
+
granted to You under this License for that Work shall terminate
|
119 |
+
as of the date such litigation is filed.
|
120 |
+
|
121 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
122 |
+
Work or Derivative Works thereof in any medium, with or without
|
123 |
+
modifications, and in Source or Object form, provided that You
|
124 |
+
meet the following conditions:
|
125 |
+
|
126 |
+
(a) You must give any other recipients of the Work or
|
127 |
+
Derivative Works a copy of this License; and
|
128 |
+
|
129 |
+
(b) You must cause any modified files to carry prominent notices
|
130 |
+
stating that You changed the files; and
|
131 |
+
|
132 |
+
(c) You must retain, in the Source form of any Derivative Works
|
133 |
+
that You distribute, all copyright, patent, trademark, and
|
134 |
+
attribution notices from the Source form of the Work,
|
135 |
+
excluding those notices that do not pertain to any part of
|
136 |
+
the Derivative Works; and
|
137 |
+
|
138 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
139 |
+
distribution, then any Derivative Works that You distribute must
|
140 |
+
include a readable copy of the attribution notices contained
|
141 |
+
within such NOTICE file, excluding those notices that do not
|
142 |
+
pertain to any part of the Derivative Works, in at least one
|
143 |
+
of the following places: within a NOTICE text file distributed
|
144 |
+
as part of the Derivative Works; within the Source form or
|
145 |
+
documentation, if provided along with the Derivative Works; or,
|
146 |
+
within a display generated by the Derivative Works, if and
|
147 |
+
wherever such third-party notices normally appear. The contents
|
148 |
+
of the NOTICE file are for informational purposes only and
|
149 |
+
do not modify the License. You may add Your own attribution
|
150 |
+
notices within Derivative Works that You distribute, alongside
|
151 |
+
or as an addendum to the NOTICE text from the Work, provided
|
152 |
+
that such additional attribution notices cannot be construed
|
153 |
+
as modifying the License.
|
154 |
+
|
155 |
+
You may add Your own copyright statement to Your modifications and
|
156 |
+
may provide additional or different license terms and conditions
|
157 |
+
for use, reproduction, or distribution of Your modifications, or
|
158 |
+
for any such Derivative Works as a whole, provided Your use,
|
159 |
+
reproduction, and distribution of the Work otherwise complies with
|
160 |
+
the conditions stated in this License.
|
161 |
+
|
162 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
163 |
+
any Contribution intentionally submitted for inclusion in the Work
|
164 |
+
by You to the Licensor shall be under the terms and conditions of
|
165 |
+
this License, without any additional terms or conditions.
|
166 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
167 |
+
the terms of any separate license agreement you may have executed
|
168 |
+
with Licensor regarding such Contributions.
|
169 |
+
|
170 |
+
6. Trademarks. This License does not grant permission to use the trade
|
171 |
+
names, trademarks, service marks, or product names of the Licensor,
|
172 |
+
except as required for reasonable and customary use in describing the
|
173 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
174 |
+
|
175 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
176 |
+
agreed to in writing, Licensor provides the Work (and each
|
177 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
178 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
179 |
+
implied, including, without limitation, any warranties or conditions
|
180 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
181 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
182 |
+
appropriateness of using or redistributing the Work and assume any
|
183 |
+
risks associated with Your exercise of permissions under this License.
|
184 |
+
|
185 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
186 |
+
whether in tort (including negligence), contract, or otherwise,
|
187 |
+
unless required by applicable law (such as deliberate and grossly
|
188 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
189 |
+
liable to You for damages, including any direct, indirect, special,
|
190 |
+
incidental, or consequential damages of any character arising as a
|
191 |
+
result of this License or out of the use or inability to use the
|
192 |
+
Work (including but not limited to damages for loss of goodwill,
|
193 |
+
work stoppage, computer failure or malfunction, or any and all
|
194 |
+
other commercial damages or losses), even if such Contributor
|
195 |
+
has been advised of the possibility of such damages.
|
196 |
+
|
197 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
198 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
199 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
200 |
+
or other liability obligations and/or rights consistent with this
|
201 |
+
License. However, in accepting such obligations, You may act only
|
202 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
203 |
+
of any other Contributor, and only if You agree to indemnify,
|
204 |
+
defend, and hold each Contributor harmless for any liability
|
205 |
+
incurred by, or claims asserted against, such Contributor by reason
|
206 |
+
of your accepting any such warranty or additional liability.
|
207 |
+
|
208 |
+
END OF TERMS AND CONDITIONS
|
209 |
+
|
210 |
+
APPENDIX: How to apply the Apache License to your work.
|
211 |
+
|
212 |
+
To apply the Apache License to your work, attach the following
|
213 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
214 |
+
replaced with your own identifying information. (Don't include
|
215 |
+
the brackets!) The text should be enclosed in the appropriate
|
216 |
+
comment syntax for the file format. We also recommend that a
|
217 |
+
file or class name and description of purpose be included on the
|
218 |
+
same "printed page" as the copyright notice for easier
|
219 |
+
identification within third-party archives.
|
220 |
+
|
221 |
+
Copyright [yyyy] [name of copyright owner]
|
222 |
+
|
223 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
224 |
+
you may not use this file except in compliance with the License.
|
225 |
+
You may obtain a copy of the License at
|
226 |
+
|
227 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
228 |
+
|
229 |
+
Unless required by applicable law or agreed to in writing, software
|
230 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
231 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
232 |
+
See the License for the specific language governing permissions and
|
233 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
title: ToonComposer
|
3 |
-
emoji:
|
4 |
colorFrom: gray
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: ToonComposer
|
3 |
+
emoji: 🎨
|
4 |
colorFrom: gray
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.25.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,1066 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from tooncomposer import ToonComposer, get_base_model_paths
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
from util.training_util import extract_img_to_sketch
|
8 |
+
import os
|
9 |
+
import tempfile
|
10 |
+
import cv2
|
11 |
+
import gradio as gr
|
12 |
+
from einops import rearrange
|
13 |
+
from datetime import datetime
|
14 |
+
from typing import Optional, List, Dict
|
15 |
+
from huggingface_hub import snapshot_download
|
16 |
+
|
17 |
+
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
|
18 |
+
|
19 |
+
# -----------------------------------------------------------------------------
|
20 |
+
# Weights resolution and download helpers
|
21 |
+
# -----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
WAN_REPO_ID = "Wan-AI/Wan2.1-I2V-14B-480P"
|
24 |
+
TOONCOMPOSER_REPO_ID = "TencentARC/ToonComposer"
|
25 |
+
|
26 |
+
def _path_is_dir_with_files(dir_path: str, required_files: List[str]) -> bool:
|
27 |
+
if not dir_path or not os.path.isdir(dir_path):
|
28 |
+
return False
|
29 |
+
for f in required_files:
|
30 |
+
if not os.path.exists(os.path.join(dir_path, f)):
|
31 |
+
return False
|
32 |
+
return True
|
33 |
+
|
34 |
+
def resolve_wan_model_root(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
|
35 |
+
"""Return a directory containing Wan2.1-I2V-14B-480P weights.
|
36 |
+
|
37 |
+
Resolution order:
|
38 |
+
1) preferred_dir arg (if valid)
|
39 |
+
2) WAN21_I2V_DIR env var (if valid)
|
40 |
+
3) HF local cache (no download) via snapshot_download(local_files_only=True)
|
41 |
+
4) HF download to cache via snapshot_download()
|
42 |
+
"""
|
43 |
+
# Required filenames relative to the model root
|
44 |
+
expected = get_base_model_paths("Wan2.1-I2V-14B-480P", format='dict', model_root=".")
|
45 |
+
required_files = []
|
46 |
+
required_files.extend([os.path.basename(p) for p in expected["dit"]])
|
47 |
+
required_files.append(os.path.basename(expected["image_encoder"]))
|
48 |
+
required_files.append(os.path.basename(expected["text_encoder"]))
|
49 |
+
required_files.append(os.path.basename(expected["vae"]))
|
50 |
+
|
51 |
+
# 1) preferred_dir arg
|
52 |
+
if _path_is_dir_with_files(preferred_dir or "", required_files):
|
53 |
+
return os.path.abspath(preferred_dir)
|
54 |
+
|
55 |
+
# 2) environment variable
|
56 |
+
env_dir = os.environ.get("WAN21_I2V_DIR")
|
57 |
+
if _path_is_dir_with_files(env_dir or "", required_files):
|
58 |
+
return os.path.abspath(env_dir)
|
59 |
+
|
60 |
+
# 3) try local cache without network
|
61 |
+
try:
|
62 |
+
cached_dir = snapshot_download(repo_id=WAN_REPO_ID, local_files_only=True)
|
63 |
+
return cached_dir
|
64 |
+
except Exception:
|
65 |
+
pass
|
66 |
+
|
67 |
+
# 4) download (may be large)
|
68 |
+
cached_dir = snapshot_download(repo_id=WAN_REPO_ID, token=hf_token)
|
69 |
+
return cached_dir
|
70 |
+
|
71 |
+
def resolve_tooncomposer_repo_dir(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
|
72 |
+
"""Return a directory containing ToonComposer repo with 480p/608p subdirs."""
|
73 |
+
# Quick validity check: ensure either a subdir 480p or 608p exists with required files
|
74 |
+
def has_resolution_dirs(base_dir: str) -> bool:
|
75 |
+
if not base_dir or not os.path.isdir(base_dir):
|
76 |
+
return False
|
77 |
+
ok = False
|
78 |
+
for res in ["480p", "608p"]:
|
79 |
+
d = os.path.join(base_dir, res)
|
80 |
+
if os.path.isdir(d):
|
81 |
+
ckpt = os.path.join(d, "tooncomposer.ckpt")
|
82 |
+
cfg = os.path.join(d, "config.json")
|
83 |
+
if os.path.exists(ckpt) and os.path.exists(cfg):
|
84 |
+
ok = True
|
85 |
+
return ok
|
86 |
+
|
87 |
+
# 1) preferred_dir arg
|
88 |
+
if has_resolution_dirs(preferred_dir or ""):
|
89 |
+
return os.path.abspath(preferred_dir)
|
90 |
+
|
91 |
+
# 2) environment variable
|
92 |
+
env_dir = os.environ.get("TOONCOMPOSER_DIR")
|
93 |
+
if has_resolution_dirs(env_dir or ""):
|
94 |
+
return os.path.abspath(env_dir)
|
95 |
+
|
96 |
+
# 3) try local cache first
|
97 |
+
try:
|
98 |
+
cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, local_files_only=True)
|
99 |
+
return cached_dir
|
100 |
+
except Exception:
|
101 |
+
pass
|
102 |
+
|
103 |
+
# 4) download repo to cache
|
104 |
+
cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, token=hf_token)
|
105 |
+
return cached_dir
|
106 |
+
|
107 |
+
def build_checkpoints_by_resolution(tooncomposer_base_dir: str) -> Dict[str, Dict[str, object]]:
|
108 |
+
"""Construct resolution mapping from a base repo dir that contains 480p/608p.
|
109 |
+
|
110 |
+
The ToonComposer HF repo stores, inside each resolution dir:
|
111 |
+
- tooncomposer.ckpt
|
112 |
+
- config.json (model configuration)
|
113 |
+
"""
|
114 |
+
mapping = {}
|
115 |
+
# Known target sizes
|
116 |
+
res_to_hw = {
|
117 |
+
"480p": (480, 832),
|
118 |
+
"608p": (608, 1088),
|
119 |
+
}
|
120 |
+
for res, (h, w) in res_to_hw.items():
|
121 |
+
res_dir = os.path.join(tooncomposer_base_dir, res)
|
122 |
+
mapping[res] = {
|
123 |
+
"target_height": h,
|
124 |
+
"target_width": w,
|
125 |
+
"snapshot_args_path": os.path.join(res_dir, "config.json"),
|
126 |
+
"checkpoint_path": os.path.join(res_dir, "tooncomposer.ckpt"),
|
127 |
+
}
|
128 |
+
return mapping
|
129 |
+
|
130 |
+
# Will be populated in main() after resolving ToonComposer repo directory
|
131 |
+
checkpoints_by_resolution = {}
|
132 |
+
|
133 |
+
def tensor2video(frames):
|
134 |
+
frames = rearrange(frames, "C T H W -> T H W C")
|
135 |
+
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
136 |
+
frames = [Image.fromarray(frame) for frame in frames]
|
137 |
+
return frames
|
138 |
+
|
139 |
+
def _load_model_config(config_path: str) -> Dict[str, object]:
|
140 |
+
with open(config_path, "r") as f:
|
141 |
+
data = json.load(f)
|
142 |
+
return data
|
143 |
+
|
144 |
+
def _merge_with_defaults(cfg: Dict[str, object]) -> Dict[str, object]:
|
145 |
+
# Provide safe defaults for optional fields used at inference-time
|
146 |
+
defaults = {
|
147 |
+
"base_model_name": "Wan2.1-I2V-14B-480P",
|
148 |
+
"learning_rate": 1e-5,
|
149 |
+
"train_architecture": "lora",
|
150 |
+
"lora_rank": 4,
|
151 |
+
"lora_alpha": 4,
|
152 |
+
"lora_target_modules": "q,k,v,o,ffn.0,ffn.2",
|
153 |
+
"init_lora_weights": "kaiming",
|
154 |
+
"use_gradient_checkpointing": True,
|
155 |
+
"tiled": False,
|
156 |
+
"tile_size_height": 34,
|
157 |
+
"tile_size_width": 34,
|
158 |
+
"tile_stride_height": 18,
|
159 |
+
"tile_stride_width": 16,
|
160 |
+
"output_path": "./",
|
161 |
+
"use_local_lora": False,
|
162 |
+
"use_dera": False,
|
163 |
+
"dera_rank": None,
|
164 |
+
"use_dera_spatial": True,
|
165 |
+
"use_dera_temporal": True,
|
166 |
+
"use_sequence_cond": True,
|
167 |
+
"sequence_cond_mode": "sparse",
|
168 |
+
"use_channel_cond": False,
|
169 |
+
"use_sequence_cond_position_aware_residual": True,
|
170 |
+
"use_sequence_cond_loss": False,
|
171 |
+
"fast_dev": False,
|
172 |
+
"max_num_cond_images": 1,
|
173 |
+
"max_num_cond_sketches": 2,
|
174 |
+
"visualize_attention": False,
|
175 |
+
"random_spaced_cond_frames": False,
|
176 |
+
"use_sketch_mask": True,
|
177 |
+
"sketch_mask_ratio": 0.2,
|
178 |
+
"no_first_sketch": False,
|
179 |
+
}
|
180 |
+
merged = defaults.copy()
|
181 |
+
merged.update(cfg)
|
182 |
+
return merged
|
183 |
+
|
184 |
+
def initialize_model(resolution="480p", fast_dev=False, device="cuda:0", dtype=torch.bfloat16,
|
185 |
+
wan_model_dir: Optional[str] = None, tooncomposer_dir: Optional[str] = None,
|
186 |
+
hf_token: Optional[str] = None):
|
187 |
+
# Initialize model components
|
188 |
+
if resolution not in checkpoints_by_resolution:
|
189 |
+
raise ValueError(f"Resolution '{resolution}' is not available. Found: {list(checkpoints_by_resolution.keys())}")
|
190 |
+
|
191 |
+
# 1) resolve config and checkpoint from ToonComposer repo (local or HF)
|
192 |
+
snapshot_args_path = checkpoints_by_resolution[resolution]["snapshot_args_path"]
|
193 |
+
checkpoint_path = checkpoints_by_resolution[resolution]["checkpoint_path"]
|
194 |
+
|
195 |
+
# 2) load model config
|
196 |
+
snapshot_args_raw = _load_model_config(snapshot_args_path)
|
197 |
+
snapshot_args = _merge_with_defaults(snapshot_args_raw)
|
198 |
+
snapshot_args["checkpoint_path"] = checkpoint_path
|
199 |
+
|
200 |
+
# 3) resolve Wan2.1 model root
|
201 |
+
snapshot_args["model_root"] = resolve_wan_model_root(preferred_dir=wan_model_dir, hf_token=hf_token)
|
202 |
+
|
203 |
+
# Backward-compat fields
|
204 |
+
if "training_max_frame_stride" not in snapshot_args:
|
205 |
+
snapshot_args["training_max_frame_stride"] = 4
|
206 |
+
snapshot_args["random_spaced_cond_frames"] = False
|
207 |
+
args = argparse.Namespace(**snapshot_args)
|
208 |
+
if not fast_dev:
|
209 |
+
model = ToonComposer(
|
210 |
+
base_model_name=args.base_model_name,
|
211 |
+
model_root=args.model_root,
|
212 |
+
learning_rate=args.learning_rate,
|
213 |
+
train_architecture=args.train_architecture,
|
214 |
+
lora_rank=args.lora_rank,
|
215 |
+
lora_alpha=args.lora_alpha,
|
216 |
+
lora_target_modules=args.lora_target_modules,
|
217 |
+
init_lora_weights=args.init_lora_weights,
|
218 |
+
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
219 |
+
checkpoint_path=args.checkpoint_path,
|
220 |
+
tiled=args.tiled,
|
221 |
+
tile_size=(args.tile_size_height, args.tile_size_width),
|
222 |
+
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
223 |
+
output_path=args.output_path,
|
224 |
+
use_local_lora=args.use_local_lora,
|
225 |
+
use_dera=args.use_dera,
|
226 |
+
dera_rank=args.dera_rank,
|
227 |
+
use_dera_spatial=args.use_dera_spatial,
|
228 |
+
use_dera_temporal=args.use_dera_temporal,
|
229 |
+
use_sequence_cond=args.use_sequence_cond,
|
230 |
+
sequence_cond_mode=args.sequence_cond_mode,
|
231 |
+
use_channel_cond=args.use_channel_cond,
|
232 |
+
use_sequence_cond_position_aware_residual=args.use_sequence_cond_position_aware_residual,
|
233 |
+
use_sequence_cond_loss=args.use_sequence_cond_loss,
|
234 |
+
fast_dev=args.fast_dev,
|
235 |
+
max_num_cond_images=args.max_num_cond_images,
|
236 |
+
max_num_cond_sketches=args.max_num_cond_sketches,
|
237 |
+
visualize_attention=args.visualize_attention,
|
238 |
+
random_spaced_cond_frames=args.random_spaced_cond_frames,
|
239 |
+
use_sketch_mask=args.use_sketch_mask,
|
240 |
+
sketch_mask_ratio=args.sketch_mask_ratio,
|
241 |
+
no_first_sketch=args.no_first_sketch,
|
242 |
+
)
|
243 |
+
model = model.to(device, dtype=dtype).eval()
|
244 |
+
else:
|
245 |
+
print("Fast dev mode. Models will not be loaded.")
|
246 |
+
model = None
|
247 |
+
print("Models initialized.")
|
248 |
+
return model, device, dtype
|
249 |
+
|
250 |
+
# -----------------------------------------------------------------------------
|
251 |
+
# CLI args and global initialization
|
252 |
+
# -----------------------------------------------------------------------------
|
253 |
+
|
254 |
+
def _parse_args():
|
255 |
+
parser = argparse.ArgumentParser()
|
256 |
+
parser.add_argument("--resolution", type=str, default=os.environ.get("TOONCOMPOSER_RESOLUTION", "480p"), choices=["480p", "608p"], help="Target resolution to load by default.")
|
257 |
+
parser.add_argument("--device", type=str, default=os.environ.get("DEVICE", "cuda"))
|
258 |
+
parser.add_argument("--dtype", type=str, default=os.environ.get("DTYPE", "bfloat16"), choices=["bfloat16", "float32"])
|
259 |
+
parser.add_argument("--wan_model_dir", type=str, default=os.environ.get("WAN21_I2V_DIR"), help="Local directory containing Wan2.1 model files. If not provided, will try HF cache and download if needed.")
|
260 |
+
parser.add_argument("--tooncomposer_dir", type=str, default=os.environ.get("TOONCOMPOSER_DIR"), help="Local directory containing ToonComposer weights with 480p/608p subdirectories. If not provided, will try HF cache and download if needed.")
|
261 |
+
parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="Hugging Face token (if needed for gated models).")
|
262 |
+
parser.add_argument("--fast_dev", action="store_true", help="Run in fast dev mode without loading heavy models.")
|
263 |
+
return parser.parse_args()
|
264 |
+
|
265 |
+
_cli_args = _parse_args()
|
266 |
+
|
267 |
+
# Resolve ToonComposer repo dir and build resolution mapping
|
268 |
+
_toon_dir = resolve_tooncomposer_repo_dir(preferred_dir=_cli_args.tooncomposer_dir, hf_token=_cli_args.hf_token)
|
269 |
+
checkpoints_by_resolution = build_checkpoints_by_resolution(_toon_dir)
|
270 |
+
|
271 |
+
_dtype_map = {
|
272 |
+
"bfloat16": torch.bfloat16,
|
273 |
+
"float32": torch.float32,
|
274 |
+
}
|
275 |
+
fast_dev = bool(_cli_args.fast_dev)
|
276 |
+
model, device, dtype = initialize_model(
|
277 |
+
resolution=_cli_args.resolution,
|
278 |
+
fast_dev=fast_dev,
|
279 |
+
device=_cli_args.device,
|
280 |
+
dtype=_dtype_map[_cli_args.dtype],
|
281 |
+
wan_model_dir=_cli_args.wan_model_dir,
|
282 |
+
tooncomposer_dir=_cli_args.tooncomposer_dir,
|
283 |
+
hf_token=_cli_args.hf_token,
|
284 |
+
)
|
285 |
+
|
286 |
+
def process_conditions(num_items, item_inputs, num_frames, is_sketch=False, target_height=480, target_width=832):
|
287 |
+
"""Process condition images/sketches into masked video tensor and mask"""
|
288 |
+
# Create empty tensors filled with -1
|
289 |
+
video = torch.zeros((1, 3, num_frames, target_height, target_width), device=device)
|
290 |
+
mask = torch.zeros((1, num_frames), device=device)
|
291 |
+
|
292 |
+
for i in range(num_items):
|
293 |
+
img, frame_idx = item_inputs[i]
|
294 |
+
if img is None or frame_idx is None:
|
295 |
+
continue
|
296 |
+
|
297 |
+
# Convert PIL image to tensor
|
298 |
+
img_tensor = torch.from_numpy(np.array(img)).permute(2,0,1).float() / 127.5 - 1.0
|
299 |
+
if is_sketch:
|
300 |
+
img_tensor = -img_tensor
|
301 |
+
img_tensor = img_tensor.unsqueeze(0).to(device)
|
302 |
+
|
303 |
+
# Resize to model's expected resolution while preserving aspect ratio
|
304 |
+
# Get original dimensions
|
305 |
+
_, _, h, w = img_tensor.shape
|
306 |
+
|
307 |
+
# Resize based on short edge while maintaining aspect ratio
|
308 |
+
if h/w < target_height/target_width:
|
309 |
+
new_h = target_height
|
310 |
+
new_w = int(w * (new_h / h))
|
311 |
+
else: # Width is the short edge
|
312 |
+
new_w = target_width
|
313 |
+
new_h = int(h * (new_w / w))
|
314 |
+
|
315 |
+
# Resize with the calculated dimensions
|
316 |
+
img_tensor = torch.nn.functional.interpolate(img_tensor, size=(new_h, new_w), mode="bilinear")
|
317 |
+
|
318 |
+
# Center crop to target resolution if needed
|
319 |
+
if new_h > target_height or new_w > target_width:
|
320 |
+
# Calculate starting positions for crop
|
321 |
+
start_h = max(0, (new_h - target_height) // 2)
|
322 |
+
start_w = max(0, (new_w - target_width) // 2)
|
323 |
+
# Crop
|
324 |
+
img_tensor = img_tensor[:, :, start_h:start_h+target_height, start_w:start_w+target_width]
|
325 |
+
|
326 |
+
# Place in video tensor
|
327 |
+
frame_idx = min(max(int(frame_idx), 0), num_frames-1)
|
328 |
+
if is_sketch:
|
329 |
+
video[:, :, frame_idx] = img_tensor[:, :3] # Handle RGBA sketches
|
330 |
+
else:
|
331 |
+
video[:, :, frame_idx] = img_tensor
|
332 |
+
mask[:, frame_idx] = 1.0
|
333 |
+
return video, mask
|
334 |
+
|
335 |
+
def process_sketch_masks(num_sketch_masks, sketch_mask_inputs, num_frames, target_height=480, target_width=832):
|
336 |
+
"""Process sketch masks into a single tensor"""
|
337 |
+
# Create empty tensor filled with 1s (1 means no mask, keep original)
|
338 |
+
sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
|
339 |
+
|
340 |
+
for i in range(num_sketch_masks):
|
341 |
+
editor_value, frame_idx = sketch_mask_inputs[i]
|
342 |
+
if editor_value is None or frame_idx is None:
|
343 |
+
continue
|
344 |
+
|
345 |
+
# For ImageMask, we need to extract the mask from the editor_value dictionary
|
346 |
+
# editor_value is a dict with 'background', 'layers', and 'composite' keys from ImageEditor
|
347 |
+
if isinstance(editor_value, dict):
|
348 |
+
if "composite" in editor_value and editor_value["composite"] is not None:
|
349 |
+
# The 'composite' is the image with mask drawn on it
|
350 |
+
# Since we're using ImageMask with fixed black brush, the black areas are the mask
|
351 |
+
# Convert the composite to a binary mask (0=masked, 1=not masked)
|
352 |
+
# sketch = editor_value["background"] # This is the sketch
|
353 |
+
mask = editor_value["layers"][0] if editor_value["layers"] else None # This is the mask layer
|
354 |
+
if mask is not None:
|
355 |
+
# Convert mask to tensor and normalize
|
356 |
+
mask_array = np.array(mask)
|
357 |
+
mask_array = np.max(mask_array, axis=2)
|
358 |
+
|
359 |
+
# Convert to tensor, normalize to [0, 1]
|
360 |
+
mask_tensor = torch.from_numpy(mask_array).float()
|
361 |
+
if mask_tensor.max() > 1.0:
|
362 |
+
mask_tensor = mask_tensor / 255.0
|
363 |
+
|
364 |
+
# Resize to model's expected resolution
|
365 |
+
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, h, w]
|
366 |
+
mask_tensor = torch.nn.functional.interpolate(mask_tensor, size=(target_height, target_width), mode="nearest")
|
367 |
+
|
368 |
+
# Invert the mask: black (0) = masked area, white (1) = keep original
|
369 |
+
# We need to invert because in the UI black means "masked"
|
370 |
+
mask_tensor = 1.0 - mask_tensor
|
371 |
+
|
372 |
+
# Place in sketch_local_mask tensor
|
373 |
+
frame_idx = min(max(int(frame_idx), 0), num_frames-1)
|
374 |
+
sketch_local_mask[:, :, frame_idx] = mask_tensor
|
375 |
+
|
376 |
+
sketch_mask_vis = torch.ones((1, 3, num_frames, target_height, target_width), device=device)
|
377 |
+
for t in range(sketch_local_mask.shape[2]):
|
378 |
+
for c in range(3):
|
379 |
+
sketch_mask_vis[0, c, t, :, :] = torch.where(
|
380 |
+
sketch_local_mask[0, 0, t] > 0.5,
|
381 |
+
1.0, # White for unmasked areas
|
382 |
+
-1.0 # Black for masked areas
|
383 |
+
)
|
384 |
+
return sketch_local_mask
|
385 |
+
|
386 |
+
|
387 |
+
def invert_sketch(image):
|
388 |
+
"""Invert the colors of an image (black to white, white to black)"""
|
389 |
+
if image is None:
|
390 |
+
return None
|
391 |
+
|
392 |
+
# Handle input from ImageMask component (EditorValue dictionary)
|
393 |
+
if isinstance(image, dict) and "background" in image:
|
394 |
+
# Extract the background image
|
395 |
+
bg_image = image["background"]
|
396 |
+
|
397 |
+
# Invert the background
|
398 |
+
inverted_bg = invert_sketch_internal(bg_image)
|
399 |
+
|
400 |
+
# Return updated editor value
|
401 |
+
return gr.update(value=inverted_bg)
|
402 |
+
|
403 |
+
# Original function for regular images
|
404 |
+
return invert_sketch_internal(image)
|
405 |
+
|
406 |
+
def invert_sketch_internal(image):
|
407 |
+
"""Internal function to invert an image"""
|
408 |
+
if image is None:
|
409 |
+
return None
|
410 |
+
|
411 |
+
# Convert to PIL image if needed
|
412 |
+
if isinstance(image, str): # If it's a filepath
|
413 |
+
image = Image.open(image)
|
414 |
+
elif isinstance(image, np.ndarray):
|
415 |
+
image = Image.fromarray(image)
|
416 |
+
|
417 |
+
# Ensure it's a PIL image now
|
418 |
+
if not isinstance(image, Image.Image):
|
419 |
+
try:
|
420 |
+
image = Image.fromarray(np.array(image))
|
421 |
+
except:
|
422 |
+
print(f"Warning: Could not convert image of type {type(image)} to PIL Image")
|
423 |
+
return image
|
424 |
+
|
425 |
+
# Invert the image
|
426 |
+
inverted = Image.fromarray(255 - np.array(image))
|
427 |
+
return inverted
|
428 |
+
|
429 |
+
def create_blank_mask(canvas_width=832, canvas_height=480):
|
430 |
+
"""Create a blank white mask image"""
|
431 |
+
return Image.new('RGB', (canvas_width, canvas_height), color='white')
|
432 |
+
|
433 |
+
def create_mask_with_sketch(sketch, canvas_width=832, canvas_height=480):
|
434 |
+
"""Create a mask image with sketch as background"""
|
435 |
+
if sketch is None:
|
436 |
+
return create_blank_mask(canvas_width, canvas_height)
|
437 |
+
|
438 |
+
# Convert sketch to PIL if needed
|
439 |
+
if not isinstance(sketch, Image.Image):
|
440 |
+
sketch = Image.fromarray(np.array(sketch))
|
441 |
+
|
442 |
+
# Resize sketch to fit the canvas
|
443 |
+
sketch = sketch.resize((canvas_width, canvas_height))
|
444 |
+
|
445 |
+
# Create a semi-transparent white layer over the sketch
|
446 |
+
overlay = Image.new('RGBA', (canvas_width, canvas_height), (255, 255, 255, 128))
|
447 |
+
|
448 |
+
# Ensure sketch has alpha channel
|
449 |
+
if sketch.mode != 'RGBA':
|
450 |
+
sketch = sketch.convert('RGBA')
|
451 |
+
|
452 |
+
# Overlay the semi-transparent white layer on the sketch
|
453 |
+
result = Image.alpha_composite(sketch, overlay)
|
454 |
+
|
455 |
+
# Convert back to RGB for Gradio
|
456 |
+
return result.convert('RGB')
|
457 |
+
|
458 |
+
def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args):
|
459 |
+
"""Validate user inputs and return error messages if any"""
|
460 |
+
errors = []
|
461 |
+
|
462 |
+
# Check text prompt
|
463 |
+
if not text_prompt or text_prompt.strip() == "":
|
464 |
+
errors.append("❌ Text prompt is required. Please enter a description for your video.")
|
465 |
+
|
466 |
+
# Check condition images
|
467 |
+
cond_images_count = 0
|
468 |
+
for i in range(int(num_cond_images)):
|
469 |
+
img = args[i*2]
|
470 |
+
frame_idx = args[i*2+1]
|
471 |
+
|
472 |
+
if img is None:
|
473 |
+
errors.append(f"❌ Image #{i+1} is missing. Please upload an image or reduce the number of keyframe images.")
|
474 |
+
else:
|
475 |
+
cond_images_count += 1
|
476 |
+
|
477 |
+
if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
|
478 |
+
errors.append(f"❌ Frame index for Image #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
|
479 |
+
|
480 |
+
# Check condition sketches
|
481 |
+
num_cond_sketches_index = 8 # Starting index for sketch inputs
|
482 |
+
cond_sketches_count = 0
|
483 |
+
sketch_frame_indices = []
|
484 |
+
|
485 |
+
for i in range(int(num_cond_sketches)):
|
486 |
+
sketch_idx = num_cond_sketches_index + i*2
|
487 |
+
frame_idx_idx = num_cond_sketches_index + 1 + i*2
|
488 |
+
|
489 |
+
if sketch_idx < len(args) and frame_idx_idx < len(args):
|
490 |
+
sketch = args[sketch_idx]
|
491 |
+
frame_idx = args[frame_idx_idx]
|
492 |
+
|
493 |
+
# Check if sketch is provided
|
494 |
+
if sketch is None:
|
495 |
+
errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch or reduce the number of keyframe sketches.")
|
496 |
+
else:
|
497 |
+
# For ImageMask components, check if background is provided
|
498 |
+
if isinstance(sketch, dict):
|
499 |
+
if "background" not in sketch or sketch["background"] is None:
|
500 |
+
errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch image.")
|
501 |
+
else:
|
502 |
+
cond_sketches_count += 1
|
503 |
+
else:
|
504 |
+
cond_sketches_count += 1
|
505 |
+
|
506 |
+
# Check frame index
|
507 |
+
if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
|
508 |
+
errors.append(f"❌ Frame index for Sketch #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
|
509 |
+
elif frame_idx is not None:
|
510 |
+
sketch_frame_indices.append(frame_idx)
|
511 |
+
|
512 |
+
# Check for duplicate frame indices
|
513 |
+
image_frame_indices = []
|
514 |
+
for i in range(int(num_cond_images)):
|
515 |
+
frame_idx = args[i*2+1]
|
516 |
+
if frame_idx is not None:
|
517 |
+
image_frame_indices.append(frame_idx)
|
518 |
+
|
519 |
+
all_frame_indices = image_frame_indices + sketch_frame_indices
|
520 |
+
if len(all_frame_indices) != len(set(all_frame_indices)):
|
521 |
+
errors.append("❌ Duplicate frame indices detected. Each image and sketch must be placed at a different frame.")
|
522 |
+
|
523 |
+
# Check minimum requirements
|
524 |
+
if cond_images_count == 0:
|
525 |
+
errors.append("❌ At least one input image is required.")
|
526 |
+
|
527 |
+
return errors
|
528 |
+
|
529 |
+
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
|
530 |
+
# Validate inputs first
|
531 |
+
validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
|
532 |
+
|
533 |
+
if validation_errors:
|
534 |
+
error_message = "\n".join(validation_errors)
|
535 |
+
return gr.update(value=None), error_message
|
536 |
+
|
537 |
+
try:
|
538 |
+
# Parse inputs
|
539 |
+
# Get the condition images
|
540 |
+
cond_images = []
|
541 |
+
for i in range(int(num_cond_images)):
|
542 |
+
img = args[i*2]
|
543 |
+
frame_idx = args[i*2+1]
|
544 |
+
if img is not None and frame_idx is not None:
|
545 |
+
cond_images.append((img, frame_idx))
|
546 |
+
|
547 |
+
# Get num_cond_sketches
|
548 |
+
if num_cond_sketches is None:
|
549 |
+
num_cond_sketches = 0
|
550 |
+
else:
|
551 |
+
num_cond_sketches = int(num_cond_sketches)
|
552 |
+
|
553 |
+
# Get condition sketches and masks
|
554 |
+
cond_sketches = []
|
555 |
+
sketch_masks = []
|
556 |
+
num_cond_sketches_index = 8 # Starting index for sketch inputs
|
557 |
+
|
558 |
+
for i in range(num_cond_sketches):
|
559 |
+
sketch_idx = num_cond_sketches_index + i*2
|
560 |
+
frame_idx_idx = num_cond_sketches_index + 1 + i*2
|
561 |
+
|
562 |
+
if sketch_idx < len(args) and frame_idx_idx < len(args):
|
563 |
+
editor_value = args[sketch_idx]
|
564 |
+
frame_idx = args[frame_idx_idx]
|
565 |
+
|
566 |
+
if editor_value is not None and frame_idx is not None:
|
567 |
+
# Extract the sketch from the background of the editor value
|
568 |
+
if isinstance(editor_value, dict) and "background" in editor_value:
|
569 |
+
sketch = editor_value["background"]
|
570 |
+
if sketch is not None:
|
571 |
+
cond_sketches.append((sketch, frame_idx))
|
572 |
+
# Also add to sketch_masks for mask processing
|
573 |
+
sketch_masks.append((editor_value, frame_idx))
|
574 |
+
else:
|
575 |
+
# For regular image inputs (first sketch)
|
576 |
+
if editor_value is not None:
|
577 |
+
cond_sketches.append((editor_value, frame_idx))
|
578 |
+
|
579 |
+
# Set target resolution based on selection
|
580 |
+
target_height, target_width = checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"]
|
581 |
+
|
582 |
+
# Update model resolution
|
583 |
+
if not fast_dev:
|
584 |
+
model.update_height_width(target_height, target_width)
|
585 |
+
|
586 |
+
# Process conditions
|
587 |
+
with torch.no_grad():
|
588 |
+
# Process image conditions
|
589 |
+
masked_cond_video, preserved_cond_mask = process_conditions(
|
590 |
+
num_cond_images, cond_images, num_frames, target_height=target_height, target_width=target_width
|
591 |
+
)
|
592 |
+
|
593 |
+
# Process sketch conditions
|
594 |
+
masked_cond_sketch, preserved_sketch_mask = process_conditions(
|
595 |
+
len(cond_sketches), cond_sketches, num_frames, is_sketch=True, target_height=target_height, target_width=target_width
|
596 |
+
)
|
597 |
+
|
598 |
+
# Process sketch masks (if any)
|
599 |
+
sketch_local_mask = None
|
600 |
+
if len(sketch_masks) > 0:
|
601 |
+
sketch_local_mask = process_sketch_masks(
|
602 |
+
len(sketch_masks), sketch_masks, num_frames, target_height=target_height, target_width=target_width
|
603 |
+
)
|
604 |
+
else:
|
605 |
+
sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
|
606 |
+
|
607 |
+
if fast_dev:
|
608 |
+
print("Fast dev mode, returning dummy video")
|
609 |
+
# Create a simple dummy video for testing
|
610 |
+
temp_dir = tempfile.mkdtemp()
|
611 |
+
video_path = os.path.join(temp_dir, "dummy_video.mp4")
|
612 |
+
|
613 |
+
# Create a simple test video
|
614 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
615 |
+
video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (target_width, target_height))
|
616 |
+
|
617 |
+
for i in range(30): # 30 frames
|
618 |
+
# Create a simple colored frame
|
619 |
+
frame = np.full((target_height, target_width, 3), (i * 8) % 255, dtype=np.uint8)
|
620 |
+
video_writer.write(frame)
|
621 |
+
|
622 |
+
video_writer.release()
|
623 |
+
return video_path, "✅ Dummy video generated successfully in fast dev mode!"
|
624 |
+
|
625 |
+
masked_cond_video = masked_cond_video.to(device=device, dtype=dtype)
|
626 |
+
preserved_cond_mask = preserved_cond_mask.to(device=device, dtype=dtype)
|
627 |
+
masked_cond_sketch = masked_cond_sketch.to(device=device, dtype=dtype)
|
628 |
+
preserved_sketch_mask = preserved_sketch_mask.to(device=device, dtype=dtype)
|
629 |
+
|
630 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(device).type):
|
631 |
+
# Generate video
|
632 |
+
model.pipe.device = device
|
633 |
+
generated_video = model.pipe(
|
634 |
+
prompt=[text_prompt],
|
635 |
+
negative_prompt=[model.negative_prompt],
|
636 |
+
input_image=None,
|
637 |
+
num_inference_steps=15,
|
638 |
+
num_frames=num_frames,
|
639 |
+
seed=42, tiled=True,
|
640 |
+
input_condition_video=masked_cond_video,
|
641 |
+
input_condition_preserved_mask=preserved_cond_mask,
|
642 |
+
input_condition_video_sketch=masked_cond_sketch,
|
643 |
+
input_condition_preserved_mask_sketch=preserved_sketch_mask,
|
644 |
+
sketch_local_mask=sketch_local_mask,
|
645 |
+
cfg_scale=cfg_scale,
|
646 |
+
sequence_cond_residual_scale=sequence_cond_residual_scale,
|
647 |
+
height=target_height,
|
648 |
+
width=target_width,
|
649 |
+
)
|
650 |
+
|
651 |
+
# Convert to PIL images
|
652 |
+
video_frames = model.pipe.tensor2video(generated_video[0].cpu())
|
653 |
+
|
654 |
+
# Convert PIL images to an MP4 video
|
655 |
+
temp_dir = tempfile.mkdtemp()
|
656 |
+
video_path = os.path.join(temp_dir, "generated_video.mp4")
|
657 |
+
|
658 |
+
width, height = video_frames[0].size
|
659 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for MP4 video
|
660 |
+
video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (width, height)) # 20 fps
|
661 |
+
|
662 |
+
for frame in video_frames:
|
663 |
+
# Convert PIL image to OpenCV BGR format
|
664 |
+
frame_bgr = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
|
665 |
+
video_writer.write(frame_bgr)
|
666 |
+
|
667 |
+
video_writer.release()
|
668 |
+
print(f"Generated video saved to {video_path}. Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
669 |
+
|
670 |
+
return video_path, f"✅ Video generated successfully! (with {len(cond_images)} keyframe images, {len(cond_sketches)} keyframe sketches)"
|
671 |
+
|
672 |
+
except Exception as e:
|
673 |
+
error_msg = f"❌ Error during generation: {str(e)}"
|
674 |
+
print(error_msg)
|
675 |
+
return gr.update(value=None), error_msg
|
676 |
+
|
677 |
+
def create_sample_gallery():
|
678 |
+
"""Create gallery items for samples"""
|
679 |
+
import os
|
680 |
+
|
681 |
+
gallery_items = []
|
682 |
+
sample_info = [
|
683 |
+
{
|
684 |
+
"id": 1,
|
685 |
+
"title": "Sample 1",
|
686 |
+
"description": "Man playing with blue fish underwater (3 sketches)",
|
687 |
+
"preview": "samples/1_image1.png"
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"id": 2,
|
691 |
+
"title": "Sample 2",
|
692 |
+
"description": "Girl and boy planting a growing flower (2 sketches)",
|
693 |
+
"preview": "samples/2_image1.jpg"
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"id": 3,
|
697 |
+
"title": "Sample 3",
|
698 |
+
"description": "Ancient Chinese boy giving apple to elder (1 sketch)",
|
699 |
+
"preview": "samples/3_image1.png"
|
700 |
+
}
|
701 |
+
]
|
702 |
+
|
703 |
+
for sample in sample_info:
|
704 |
+
if os.path.exists(sample["preview"]):
|
705 |
+
gallery_items.append((sample["preview"], f"{sample['title']}: {sample['description']}"))
|
706 |
+
|
707 |
+
return gallery_items
|
708 |
+
|
709 |
+
def handle_gallery_select(evt: gr.SelectData):
|
710 |
+
"""Handle gallery selection and load the corresponding sample"""
|
711 |
+
sample_id = evt.index + 1 # Gallery index starts from 0, sample IDs start from 1
|
712 |
+
return apply_sample_to_ui(sample_id)
|
713 |
+
|
714 |
+
def load_sample_data(sample_id):
|
715 |
+
"""Load sample data based on the selected sample"""
|
716 |
+
import os
|
717 |
+
|
718 |
+
samples_dir = "samples"
|
719 |
+
|
720 |
+
# Sample configurations
|
721 |
+
sample_configs = {
|
722 |
+
1: {
|
723 |
+
"prompt": "Underwater scene: A shirtless man plays with a spiraling blue fish. A whale follows a bag in the man's hand, swimming in circles as the man uses the bag to lure the blue fish forward. Anime. High quality.",
|
724 |
+
"num_sketches": 3,
|
725 |
+
"image_frame": 0,
|
726 |
+
"sketch_frames": [20, 40, 60],
|
727 |
+
"num_frames": 61
|
728 |
+
},
|
729 |
+
2: {
|
730 |
+
"prompt": "A girl and a silver-haired boy plant a huge flower. As the camera slowly moves up, the huge flower continues to grow and bloom. Anime. High quality.",
|
731 |
+
"num_sketches": 2,
|
732 |
+
"image_frame": 0,
|
733 |
+
"sketch_frames": [30, 60],
|
734 |
+
"num_frames": 61
|
735 |
+
},
|
736 |
+
3: {
|
737 |
+
"prompt": "An ancient Chinese boy holds an apple and smiles as he gives it to an elderly man nearby. Anime. High quality.",
|
738 |
+
"num_sketches": 1,
|
739 |
+
"image_frame": 0,
|
740 |
+
"sketch_frames": [30],
|
741 |
+
"num_frames": 33
|
742 |
+
}
|
743 |
+
}
|
744 |
+
|
745 |
+
if sample_id not in sample_configs:
|
746 |
+
return None
|
747 |
+
|
748 |
+
config = sample_configs[sample_id]
|
749 |
+
|
750 |
+
# Load image
|
751 |
+
image_path = os.path.join(samples_dir, f"{sample_id}_image1.png")
|
752 |
+
if not os.path.exists(image_path):
|
753 |
+
image_path = os.path.join(samples_dir, f"{sample_id}_image1.jpg")
|
754 |
+
|
755 |
+
# Load sketches
|
756 |
+
sketches = []
|
757 |
+
for i in range(config["num_sketches"]):
|
758 |
+
sketch_path = os.path.join(samples_dir, f"{sample_id}_sketch{i+1}.jpg")
|
759 |
+
if os.path.exists(sketch_path):
|
760 |
+
sketches.append(sketch_path)
|
761 |
+
|
762 |
+
# Load output video
|
763 |
+
output_path = os.path.join(samples_dir, f"{sample_id}_out.mp4")
|
764 |
+
|
765 |
+
return {
|
766 |
+
"prompt": config["prompt"],
|
767 |
+
"image": image_path if os.path.exists(image_path) else None,
|
768 |
+
"sketches": sketches,
|
769 |
+
"image_frame": config["image_frame"],
|
770 |
+
"sketch_frames": config["sketch_frames"][:len(sketches)],
|
771 |
+
"output_video": output_path if os.path.exists(output_path) else None,
|
772 |
+
"num_sketches": len(sketches),
|
773 |
+
"num_frames": config["num_frames"]
|
774 |
+
}
|
775 |
+
|
776 |
+
def apply_sample_to_ui(sample_id):
|
777 |
+
"""Apply sample data to UI components"""
|
778 |
+
sample_data = load_sample_data(sample_id)
|
779 |
+
|
780 |
+
if not sample_data:
|
781 |
+
return [gr.update() for _ in range(20)] # Return no updates if sample not found
|
782 |
+
|
783 |
+
updates = [gr.update(value=sample_data["num_frames"])]
|
784 |
+
|
785 |
+
# Update prompt
|
786 |
+
updates.append(gr.update(value=sample_data["prompt"]))
|
787 |
+
|
788 |
+
# Update number of sketches
|
789 |
+
updates.append(gr.update(value=sample_data["num_sketches"]))
|
790 |
+
|
791 |
+
# Update condition image
|
792 |
+
updates.append(gr.update(value=sample_data["image"]))
|
793 |
+
updates.append(gr.update(value=sample_data["image_frame"]))
|
794 |
+
|
795 |
+
# Update sketches (up to 4)
|
796 |
+
for i in range(4):
|
797 |
+
if i < len(sample_data["sketches"]):
|
798 |
+
# Load sketch image
|
799 |
+
sketch_img = Image.open(sample_data["sketches"][i])
|
800 |
+
# Create ImageMask format
|
801 |
+
sketch_dict = {
|
802 |
+
"background": sketch_img,
|
803 |
+
"layers": [],
|
804 |
+
"composite": sketch_img
|
805 |
+
}
|
806 |
+
updates.append(gr.update(value=sketch_dict))
|
807 |
+
updates.append(gr.update(value=sample_data["sketch_frames"][i]))
|
808 |
+
else:
|
809 |
+
updates.append(gr.update(value=None))
|
810 |
+
updates.append(gr.update(value=30))
|
811 |
+
|
812 |
+
# Update output video
|
813 |
+
updates.append(gr.update(value=sample_data["output_video"]))
|
814 |
+
|
815 |
+
# Update status
|
816 |
+
updates.append(gr.update(value=f"✅ Loaded Sample {sample_id}: {sample_data['prompt'][:50]}..."))
|
817 |
+
|
818 |
+
return updates
|
819 |
+
|
820 |
+
if __name__ == "__main__":
|
821 |
+
from util.stylesheets import css, pre_js, banner_image
|
822 |
+
with gr.Blocks(title="🎨 ToonComposer Demo", css=css, js=pre_js) as iface:
|
823 |
+
with gr.Row():
|
824 |
+
with gr.Column(scale=1):
|
825 |
+
gr.HTML(banner_image)
|
826 |
+
with gr.Column(scale=1):
|
827 |
+
gr.Markdown("""
|
828 |
+
💡 **Quick Guide**
|
829 |
+
1. Set the promopt and number of target frames, input keyframe images/sketches, etc.
|
830 |
+
2. Upload keyframe image as the first frame (with index set to 0).
|
831 |
+
3. Upload sketches with optional motion masks for controlled generation at specified frame indices.
|
832 |
+
4. Click the *Generate* button to create your cartoon video.
|
833 |
+
""")
|
834 |
+
|
835 |
+
max_num_frames = 61
|
836 |
+
cond_images_inputs = []
|
837 |
+
cond_sketches_inputs = []
|
838 |
+
with gr.Row():
|
839 |
+
with gr.Column(scale=1):
|
840 |
+
with gr.Accordion("Video Settings", open=True):
|
841 |
+
num_frames = gr.Slider(
|
842 |
+
minimum=17, maximum=max_num_frames, value=max_num_frames, step=1, label="🎥 Number of Frames",
|
843 |
+
info="Select the total number of frames for the generated video. Should be 4N+"
|
844 |
+
)
|
845 |
+
|
846 |
+
resolution = gr.Radio(
|
847 |
+
choices=["480p", "608p"],
|
848 |
+
value="480p",
|
849 |
+
label="🎥 Resolution",
|
850 |
+
info="Select the resolution for the generated video."
|
851 |
+
)
|
852 |
+
|
853 |
+
text_prompt = gr.Textbox(
|
854 |
+
label="📝 Text Prompt",
|
855 |
+
placeholder="Enter a description for the video.",
|
856 |
+
info="Describe what you want to generate in the video.",
|
857 |
+
lines=5
|
858 |
+
)
|
859 |
+
cfg_scale = gr.Slider(
|
860 |
+
minimum=1.0, maximum=15.0, value=7.5, label="⚙️ CFG Scale",
|
861 |
+
info="Adjust the classifier-free guidance scale for generation."
|
862 |
+
)
|
863 |
+
sequence_cond_residual_scale = gr.Slider(
|
864 |
+
minimum=0.0, maximum=1.2, value=1.0, label="⚙️ Pos-aware Residual Scale",
|
865 |
+
info="Adjust the residual scale for the position-aware sequence condition."
|
866 |
+
)
|
867 |
+
|
868 |
+
with gr.Column(scale=3):
|
869 |
+
with gr.Accordion("Keyframe Image(s)", open=True):
|
870 |
+
num_cond_images = gr.Slider(
|
871 |
+
minimum=1, maximum=4, value=1, step=1, label="🖼️ Number of Keyframe Images",
|
872 |
+
info="Specify how many keyframe color images to use (max 4 images)."
|
873 |
+
)
|
874 |
+
for i in range(4): # Max 4 condition images
|
875 |
+
with gr.Tab(label=f"Image {i+1}", interactive=i==0) as tab:
|
876 |
+
gr.Markdown("At least one image is required. \n Each image or sketch will be used to control the cartoon geneartion at the given frame index.")
|
877 |
+
image_input = gr.Image(
|
878 |
+
label=f"Image {i+1}", type="pil",
|
879 |
+
placeholder=f"Upload a keyframe image {i+1}..."
|
880 |
+
)
|
881 |
+
frame_index_input = gr.Slider(
|
882 |
+
label=f"Frame Index for Image #{i+1}", minimum=0, maximum=max_num_frames - 1,
|
883 |
+
value=i * (max_num_frames-1) // 3, step=1,
|
884 |
+
info=f"Frame position for Image {i+1} (0 to {max_num_frames-1})"
|
885 |
+
)
|
886 |
+
cond_images_inputs.append((image_input, frame_index_input, tab))
|
887 |
+
|
888 |
+
|
889 |
+
with gr.Column(scale=3):
|
890 |
+
with gr.Accordion("Keyframe Sketch(es)", open=True):
|
891 |
+
num_cond_sketches = gr.Slider(
|
892 |
+
minimum=1, maximum=4, value=1, step=1, label="✏️ Number of Keyframe Sketch(es)",
|
893 |
+
info="Specify how many keyframe sketches to use (max 4 sketches)."
|
894 |
+
)
|
895 |
+
for i in range(4): # Max 4 condition sketches
|
896 |
+
with gr.Tab(label=f"Sketch {i + 1}", interactive=i==0) as tab:
|
897 |
+
|
898 |
+
gr.Markdown("At least one sketch is required. \n You can optionally draw black areas using the brush tool to mark regions where motion can be generated freely.")
|
899 |
+
|
900 |
+
# Use ImageMask which allows uploading an image and drawing a mask
|
901 |
+
sketch_input = gr.ImageMask(
|
902 |
+
label=f"Sketch {i + 1} with Motion Mask",
|
903 |
+
type="pil",
|
904 |
+
elem_id=f"sketch_mask_{i + 1}"
|
905 |
+
)
|
906 |
+
|
907 |
+
# All sketches have a frame index input
|
908 |
+
_frame_index_input = gr.Slider(
|
909 |
+
label=f"Frame Index for Sketch #{i + 1}", minimum=0, maximum=max_num_frames - 1,
|
910 |
+
value=max_num_frames-1, step=1,
|
911 |
+
info=f"Frame position for Sketch {i + 1} (0 to {max_num_frames-1})"
|
912 |
+
)
|
913 |
+
|
914 |
+
cond_sketches_inputs.append((sketch_input, _frame_index_input, tab))
|
915 |
+
|
916 |
+
with gr.Row():
|
917 |
+
with gr.Column(scale=1):
|
918 |
+
# Sample Gallery Section
|
919 |
+
with gr.Accordion("🔍 Sample Gallery", open=True):
|
920 |
+
gr.Markdown("Click on any sample image below to load the sample inputs.")
|
921 |
+
sample_gallery = gr.Gallery(
|
922 |
+
value=create_sample_gallery(),
|
923 |
+
label="Sample Examples",
|
924 |
+
show_label=False,
|
925 |
+
elem_id="sample-gallery",
|
926 |
+
columns=3,
|
927 |
+
rows=1,
|
928 |
+
height=200,
|
929 |
+
allow_preview=True,
|
930 |
+
object_fit="contain")
|
931 |
+
|
932 |
+
with gr.Accordion("🛠️ Tools", open=False):
|
933 |
+
tool_input = gr.Image(
|
934 |
+
label=f"Input Image", type="pil",
|
935 |
+
placeholder=f"Upload an image."
|
936 |
+
)
|
937 |
+
invert_btn = gr.Button(f"Invert Colors")
|
938 |
+
invert_btn.click(
|
939 |
+
fn=invert_sketch,
|
940 |
+
inputs=[tool_input],
|
941 |
+
outputs=[tool_input]
|
942 |
+
)
|
943 |
+
|
944 |
+
with gr.Column(scale=1):
|
945 |
+
status_text = gr.Textbox(
|
946 |
+
label="📊 Status",
|
947 |
+
value="Ready to generate. Please check your inputs and click Run.",
|
948 |
+
interactive=False,
|
949 |
+
lines=5
|
950 |
+
)
|
951 |
+
|
952 |
+
with gr.Accordion("🎬 Generated Video", open=True):
|
953 |
+
output_video = gr.Video(
|
954 |
+
label="Video Output",
|
955 |
+
show_label=True
|
956 |
+
)
|
957 |
+
run_button = gr.Button("🚀 Generate Video", variant="primary", size="lg")
|
958 |
+
|
959 |
+
def update_visibility(num_items, num_frames):
|
960 |
+
# Update visibility for columns
|
961 |
+
updates_images = []
|
962 |
+
updates_indices = []
|
963 |
+
for i in range(4):
|
964 |
+
is_visible = i < num_items
|
965 |
+
# is_visible = True
|
966 |
+
updates_images.append(gr.update(interactive=is_visible))
|
967 |
+
updates_indices.append(gr.update(
|
968 |
+
value=((num_frames - 1) // max(num_items, 1)) * (i + 1),
|
969 |
+
minimum=0, maximum=num_frames-1,
|
970 |
+
))
|
971 |
+
return updates_images + updates_indices
|
972 |
+
|
973 |
+
def update_visibility_images(num_items, num_frames):
|
974 |
+
# Update visibility for columns
|
975 |
+
updates_images = []
|
976 |
+
updates_indices = []
|
977 |
+
for i in range(4):
|
978 |
+
is_visible = i < num_items
|
979 |
+
updates_images.append(gr.update(interactive=is_visible))
|
980 |
+
updates_indices.append(gr.update(
|
981 |
+
value=((num_frames - 1) // max(num_items, 1)) * i,
|
982 |
+
minimum=0, maximum=num_frames-1,
|
983 |
+
))
|
984 |
+
return updates_images + updates_indices
|
985 |
+
|
986 |
+
def update_frame_ranges(num_items_images, num_items_sketches, num_frames):
|
987 |
+
"""Update the maximum values for all frame index sliders"""
|
988 |
+
updates = []
|
989 |
+
for i in range(4): # Images
|
990 |
+
updates.append(gr.update(
|
991 |
+
value=((num_frames - 1) // max(num_items_images, 1)) * i,
|
992 |
+
maximum=num_frames-1
|
993 |
+
))
|
994 |
+
for i in range(4): # Sketches
|
995 |
+
updates.append(gr.update(
|
996 |
+
value=((num_frames - 1) // max(num_items_sketches, 1)) * (i + 1),
|
997 |
+
maximum=num_frames-1))
|
998 |
+
return updates
|
999 |
+
|
1000 |
+
num_cond_images.change(
|
1001 |
+
fn=update_visibility_images,
|
1002 |
+
inputs=[num_cond_images, num_frames],
|
1003 |
+
outputs=[tab for _, _, tab in cond_images_inputs] \
|
1004 |
+
+ [frame_index_input for _, frame_index_input, _ in cond_images_inputs],
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
num_cond_sketches.change(
|
1008 |
+
fn=update_visibility,
|
1009 |
+
inputs=[num_cond_sketches, num_frames],
|
1010 |
+
outputs=[tab for _, _, tab in cond_sketches_inputs] \
|
1011 |
+
+ [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs],
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
num_frames.change(
|
1015 |
+
fn=update_frame_ranges,
|
1016 |
+
inputs=[num_cond_images, num_cond_sketches, num_frames],
|
1017 |
+
outputs=[frame_index_input for _, frame_index_input, _ in cond_images_inputs] + \
|
1018 |
+
[frame_index_input for _, frame_index_input, _ in cond_sketches_inputs]
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
def update_resolution(resolution):
|
1022 |
+
model.update_height_width(checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"])
|
1023 |
+
model.load_tooncomposer_checkpoint(checkpoints_by_resolution[resolution]["checkpoint_path"])
|
1024 |
+
return gr.update(), gr.update()
|
1025 |
+
|
1026 |
+
resolution.change(
|
1027 |
+
fn=update_resolution,
|
1028 |
+
inputs=[resolution],
|
1029 |
+
outputs=[output_video, run_button]
|
1030 |
+
)
|
1031 |
+
|
1032 |
+
sample_outputs = [
|
1033 |
+
num_frames, text_prompt, num_cond_sketches,
|
1034 |
+
cond_images_inputs[0][0], cond_images_inputs[0][1], # Image 1
|
1035 |
+
cond_sketches_inputs[0][0], cond_sketches_inputs[0][1], # Sketch 1
|
1036 |
+
cond_sketches_inputs[1][0], cond_sketches_inputs[1][1], # Sketch 2
|
1037 |
+
cond_sketches_inputs[2][0], cond_sketches_inputs[2][1], # Sketch 3
|
1038 |
+
cond_sketches_inputs[3][0], cond_sketches_inputs[3][1], # Sketch 4
|
1039 |
+
output_video, status_text
|
1040 |
+
]
|
1041 |
+
|
1042 |
+
sample_gallery.select(
|
1043 |
+
fn=handle_gallery_select,
|
1044 |
+
outputs=sample_outputs
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
inputs = [num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution]
|
1048 |
+
run_button.click(
|
1049 |
+
fn=tooncomposer_inference,
|
1050 |
+
inputs=inputs,
|
1051 |
+
outputs=[output_video, status_text]
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
# Add condition image inputs
|
1055 |
+
for image_input, frame_index_input, _ in cond_images_inputs:
|
1056 |
+
inputs.append(image_input)
|
1057 |
+
inputs.append(frame_index_input)
|
1058 |
+
|
1059 |
+
# Add sketch inputs (both regular and ImageMask)
|
1060 |
+
for sketch_input, frame_index_input, _ in cond_sketches_inputs:
|
1061 |
+
inputs.append(sketch_input)
|
1062 |
+
inputs.append(frame_index_input)
|
1063 |
+
|
1064 |
+
iface.launch(server_port=7860, server_name="0.0.0.0",
|
1065 |
+
allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
|
1066 |
+
os.path.abspath(os.path.join(os.path.dirname(__file__), "samples"))])
|
model/__init__.py
ADDED
File without changes
|
model/dera.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from einops import rearrange
|
4 |
+
from .dit import flash_attention
|
5 |
+
import torch.amp as amp
|
6 |
+
|
7 |
+
|
8 |
+
class DeRAAttention(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self,
|
11 |
+
dim,
|
12 |
+
num_heads,
|
13 |
+
window_size=(-1, -1),
|
14 |
+
mode="spatial"):
|
15 |
+
assert dim % num_heads == 0
|
16 |
+
super().__init__()
|
17 |
+
self.dim = dim
|
18 |
+
self.num_heads = num_heads
|
19 |
+
self.head_dim = dim // num_heads
|
20 |
+
self.window_size = window_size
|
21 |
+
|
22 |
+
self.q = nn.Linear(dim, dim)
|
23 |
+
self.k = nn.Linear(dim, dim)
|
24 |
+
self.v = nn.Linear(dim, dim)
|
25 |
+
self.o = nn.Linear(dim, dim)
|
26 |
+
self.visualize_attention = False
|
27 |
+
|
28 |
+
if mode == 'spatial':
|
29 |
+
self.rope_apply = self.rope_apply_spatial
|
30 |
+
elif mode == 'temporal':
|
31 |
+
self.rope_apply = self.rope_apply_temporal
|
32 |
+
elif mode == 'spatial_temporal':
|
33 |
+
self.rope_apply = self.rope_apply_spatial_temporal
|
34 |
+
else:
|
35 |
+
raise ValueError("Invalid mode: {}".format(mode))
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
@amp.autocast(enabled=False, device_type="cuda")
|
39 |
+
def rope_apply_spatial(x, grid_size, freqs, sequence_cond_compressed_indices=None):
|
40 |
+
batch, _, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
41 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
42 |
+
assert len(grid_size) == 2, "grid_size mustbe [h, w]"
|
43 |
+
h, w = grid_size[0], grid_size[1]
|
44 |
+
seq_len = h * w
|
45 |
+
x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
|
46 |
+
batch, seq_len, n, -1, 2))
|
47 |
+
freqs_i = torch.cat([
|
48 |
+
freqs[1][:h].view(1, h, 1, -1).expand(1, h, w, -1),
|
49 |
+
freqs[2][:w].view(1, 1, w, -1).expand(1, h, w, -1)
|
50 |
+
], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
|
51 |
+
freqs_i = torch.concat([freqs_i.new_ones(batch, seq_len, 1, c//3), freqs_i], dim=3)
|
52 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
|
53 |
+
return x_i.float()
|
54 |
+
|
55 |
+
@staticmethod
|
56 |
+
@amp.autocast(enabled=False, device_type="cuda")
|
57 |
+
def rope_apply_temporal(x, grid_size, freqs, sequence_cond_compressed_indices=None):
|
58 |
+
batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
59 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
60 |
+
assert len(grid_size) == 1, "grid_size must be [t]"
|
61 |
+
seq_len = grid_size[0]
|
62 |
+
x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(batch, seq_len, n, -1, 2))
|
63 |
+
freqs_i = torch.cat([
|
64 |
+
freqs[0][:seq_len].view(seq_len, 1, 1, -1)
|
65 |
+
], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
|
66 |
+
freqs_i = torch.concat([freqs_i, freqs_i.new_ones(batch, seq_len, 1, 2 * c//3)], dim=3)
|
67 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
|
68 |
+
if seq_len_actual > seq_len:
|
69 |
+
sequence_cond_seq_length = seq_len_actual - seq_len
|
70 |
+
if sequence_cond_seq_length == seq_len:
|
71 |
+
x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
|
72 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
|
73 |
+
else:
|
74 |
+
sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
|
75 |
+
sequence_cond_t_length = len(sequence_cond_compressed_index)
|
76 |
+
assert sequence_cond_t_length == sequence_cond_seq_length, "`sequence_cond_t_length` must be equal to `sequence_cond_seq_length`"
|
77 |
+
x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
|
78 |
+
freqs_i_sequence_cond = torch.cat([
|
79 |
+
freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1),
|
80 |
+
], dim=-1).reshape(sequence_cond_seq_length, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
|
81 |
+
freqs_i_sequence_cond = torch.concat([freqs_i_sequence_cond, freqs_i_sequence_cond.new_ones(batch, sequence_cond_t_length, 1, 2 * c//3)], dim=3)
|
82 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
|
83 |
+
x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
|
84 |
+
|
85 |
+
return x_i.float()
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
@amp.autocast(enabled=False, device_type="cuda")
|
89 |
+
def rope_apply_spatial_temporal(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
|
90 |
+
batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
91 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
92 |
+
assert len(grid_sizes) == 3, "grid_sizes must be ([f, h, w])"
|
93 |
+
f, h, w = grid_sizes[0], grid_sizes[1], grid_sizes[2]
|
94 |
+
seq_len = f * h * w
|
95 |
+
x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
|
96 |
+
batch, seq_len, n, -1, 2))
|
97 |
+
freqs_i = torch.cat([
|
98 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
99 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
100 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
101 |
+
], dim=-1).reshape(seq_len, 1, -1)
|
102 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
|
103 |
+
if seq_len_actual > seq_len:
|
104 |
+
sequence_cond_seq_length = seq_len_actual - seq_len
|
105 |
+
if sequence_cond_seq_length == seq_len:
|
106 |
+
x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
|
107 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
|
108 |
+
else:
|
109 |
+
sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
|
110 |
+
sequence_cond_t_length = len(sequence_cond_compressed_index)
|
111 |
+
assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
|
112 |
+
x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
|
113 |
+
freqs_i_sequence_cond = torch.cat([
|
114 |
+
freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
|
115 |
+
freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
|
116 |
+
freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
|
117 |
+
], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
|
118 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
|
119 |
+
x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
|
120 |
+
return x_i.float()
|
121 |
+
|
122 |
+
|
123 |
+
def forward(self, x, seq_lens, grid_size, freqs, sequence_cond_compressed_indices):
|
124 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
125 |
+
def qkv_fn(x):
|
126 |
+
q = self.q(x).view(b, s, n, d)
|
127 |
+
k = self.k(x).view(b, s, n, d)
|
128 |
+
v = self.v(x).view(b, s, n, d)
|
129 |
+
return q, k, v
|
130 |
+
|
131 |
+
q, k, v = qkv_fn(x)
|
132 |
+
q_rope = self.rope_apply(q, grid_size, freqs, sequence_cond_compressed_indices)
|
133 |
+
k_rope = self.rope_apply(k, grid_size, freqs, sequence_cond_compressed_indices)
|
134 |
+
if self.visualize_attention:
|
135 |
+
with torch.no_grad():
|
136 |
+
self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope) # CPU tesnor of [S, S]
|
137 |
+
self._last_grid_sizes = grid_size
|
138 |
+
self._last_seq_lens = seq_lens
|
139 |
+
x = flash_attention(
|
140 |
+
q=q_rope,
|
141 |
+
k=k_rope,
|
142 |
+
v=v,
|
143 |
+
k_lens=None,
|
144 |
+
window_size=self.window_size)
|
145 |
+
x = x.flatten(2)
|
146 |
+
x = self.o(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class DeRA(nn.Module):
|
151 |
+
def __init__(self, dim, rank, use_spatial=True, use_temporal=True):
|
152 |
+
super(DeRA, self).__init__()
|
153 |
+
self.dim = dim
|
154 |
+
self.rank = rank
|
155 |
+
self.use_spatial = use_spatial
|
156 |
+
self.use_temporal = use_temporal
|
157 |
+
|
158 |
+
if not use_spatial and not use_temporal:
|
159 |
+
self.attention_mode = "none"
|
160 |
+
else:
|
161 |
+
self.attention_mode = "spatial_temporal" if use_spatial and use_temporal else "spatial" if use_spatial else "temporal"
|
162 |
+
|
163 |
+
self.spatial_down_proj = nn.Linear(self.dim, rank, bias=False)
|
164 |
+
self.spatial_up_proj = nn.Linear(rank, self.dim, bias=False)
|
165 |
+
self.spatial_up_proj.weight.data.zero_()
|
166 |
+
if self.attention_mode != "none":
|
167 |
+
self.spatial_attn = DeRAAttention(dim=rank, num_heads=4, window_size=(-1, -1),
|
168 |
+
mode=self.attention_mode)
|
169 |
+
else:
|
170 |
+
self.spatial_attn = None
|
171 |
+
|
172 |
+
def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
|
173 |
+
_, actual_seq, _ = x.shape
|
174 |
+
if isinstance(grid_sizes, torch.Tensor):
|
175 |
+
grid_sizes = tuple(grid_sizes[0].tolist())
|
176 |
+
|
177 |
+
if len(grid_sizes) != 3:
|
178 |
+
raise ValueError("`grid_sizes` should contain time, spatial height, and width dimensions")
|
179 |
+
_, orig_h, orig_w = grid_sizes
|
180 |
+
actual_t = actual_seq // (orig_h * orig_w)
|
181 |
+
|
182 |
+
x_low = self.spatial_down_proj(x)
|
183 |
+
if self.attention_mode == "spatial":
|
184 |
+
x_low_spatial = rearrange(x_low, 'b (t h w) r -> (b t) (h w) r', t=actual_t, h=orig_h, w=orig_w)
|
185 |
+
x_low_spatial = self.spatial_attn(x_low_spatial, seq_lens, grid_sizes[1:], freqs, sequence_cond_compressed_indices)
|
186 |
+
x_low = rearrange(x_low_spatial, '(b t) (h w) r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
|
187 |
+
elif self.attention_mode == "temporal":
|
188 |
+
x_low_temporal = rearrange(x_low, 'b (t h w) r -> (b h w) t r', t=actual_t, h=orig_h, w=orig_w)
|
189 |
+
x_low_temporal = self.spatial_attn(x_low_temporal, seq_lens, grid_sizes[:1], freqs, sequence_cond_compressed_indices)
|
190 |
+
x_low = rearrange(x_low_temporal, '(b h w) t r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
|
191 |
+
elif self.attention_mode == "spatial_temporal":
|
192 |
+
x_low = self.spatial_attn(x_low, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
|
193 |
+
x_out = self.spatial_up_proj(x_low)
|
194 |
+
return x_out
|
195 |
+
|
model/dit.py
ADDED
@@ -0,0 +1,1090 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.amp as amp
|
5 |
+
import torch.nn as nn
|
6 |
+
from util.model_util import hash_state_dict_keys
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
try:
|
10 |
+
import flash_attn_interface
|
11 |
+
FLASH_ATTN_3_AVAILABLE = True
|
12 |
+
except ModuleNotFoundError:
|
13 |
+
FLASH_ATTN_3_AVAILABLE = False
|
14 |
+
|
15 |
+
try:
|
16 |
+
import flash_attn
|
17 |
+
FLASH_ATTN_2_AVAILABLE = True
|
18 |
+
except ModuleNotFoundError:
|
19 |
+
FLASH_ATTN_2_AVAILABLE = False
|
20 |
+
|
21 |
+
try:
|
22 |
+
from sageattention import sageattn
|
23 |
+
SAGE_ATTN_AVAILABLE = True
|
24 |
+
except ModuleNotFoundError:
|
25 |
+
SAGE_ATTN_AVAILABLE = False
|
26 |
+
|
27 |
+
import warnings
|
28 |
+
|
29 |
+
|
30 |
+
__all__ = ['WanModel']
|
31 |
+
|
32 |
+
|
33 |
+
def flash_attention(
|
34 |
+
q,
|
35 |
+
k,
|
36 |
+
v,
|
37 |
+
q_lens=None,
|
38 |
+
k_lens=None,
|
39 |
+
dropout_p=0.,
|
40 |
+
softmax_scale=None,
|
41 |
+
q_scale=None,
|
42 |
+
causal=False,
|
43 |
+
window_size=(-1, -1),
|
44 |
+
deterministic=False,
|
45 |
+
dtype=torch.bfloat16,
|
46 |
+
version=None,
|
47 |
+
):
|
48 |
+
"""
|
49 |
+
q: [B, Lq, Nq, C1].
|
50 |
+
k: [B, Lk, Nk, C1].
|
51 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
52 |
+
q_lens: [B].
|
53 |
+
k_lens: [B].
|
54 |
+
dropout_p: float. Dropout probability.
|
55 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
56 |
+
causal: bool. Whether to apply causal attention mask.
|
57 |
+
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
58 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
59 |
+
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
60 |
+
"""
|
61 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
62 |
+
assert dtype in half_dtypes
|
63 |
+
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
64 |
+
|
65 |
+
# params
|
66 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
67 |
+
|
68 |
+
def half(x):
|
69 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
70 |
+
|
71 |
+
# preprocess query
|
72 |
+
if q_lens is None:
|
73 |
+
q = half(q.flatten(0, 1))
|
74 |
+
q_lens = torch.tensor(
|
75 |
+
[lq] * b, dtype=torch.int32).to(
|
76 |
+
device=q.device, non_blocking=True)
|
77 |
+
else:
|
78 |
+
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
|
79 |
+
|
80 |
+
# preprocess key, value
|
81 |
+
if k_lens is None:
|
82 |
+
k = half(k.flatten(0, 1))
|
83 |
+
v = half(v.flatten(0, 1))
|
84 |
+
k_lens = torch.tensor(
|
85 |
+
[lk] * b, dtype=torch.int32).to(
|
86 |
+
device=k.device, non_blocking=True)
|
87 |
+
else:
|
88 |
+
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
|
89 |
+
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
|
90 |
+
|
91 |
+
q = q.to(v.dtype)
|
92 |
+
k = k.to(v.dtype)
|
93 |
+
|
94 |
+
if q_scale is not None:
|
95 |
+
q = q * q_scale
|
96 |
+
|
97 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
98 |
+
warnings.warn(
|
99 |
+
'Flash attention 3 is not available, use flash attention 2 instead.'
|
100 |
+
)
|
101 |
+
|
102 |
+
# apply attention
|
103 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
104 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
105 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
106 |
+
q=q,
|
107 |
+
k=k,
|
108 |
+
v=v,
|
109 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
110 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
111 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
112 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
113 |
+
seqused_q=None,
|
114 |
+
seqused_k=None,
|
115 |
+
max_seqlen_q=lq,
|
116 |
+
max_seqlen_k=lk,
|
117 |
+
softmax_scale=softmax_scale,
|
118 |
+
causal=causal,
|
119 |
+
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
120 |
+
elif FLASH_ATTN_2_AVAILABLE:
|
121 |
+
x = flash_attn.flash_attn_varlen_func(
|
122 |
+
q=q,
|
123 |
+
k=k,
|
124 |
+
v=v,
|
125 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
126 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
127 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
128 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
129 |
+
max_seqlen_q=lq,
|
130 |
+
max_seqlen_k=lk,
|
131 |
+
dropout_p=dropout_p,
|
132 |
+
softmax_scale=softmax_scale,
|
133 |
+
causal=causal,
|
134 |
+
window_size=window_size,
|
135 |
+
deterministic=deterministic).unflatten(0, (b, lq))
|
136 |
+
elif SAGE_ATTN_AVAILABLE:
|
137 |
+
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
138 |
+
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
139 |
+
v = v.unsqueeze(0).transpose(1, 2).to(dtype)
|
140 |
+
x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
|
141 |
+
x = x.transpose(1, 2).contiguous()
|
142 |
+
else:
|
143 |
+
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
144 |
+
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
145 |
+
v = v.unsqueeze(0).transpose(1, 2).to(dtype)
|
146 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
147 |
+
x = x.transpose(1, 2).contiguous()
|
148 |
+
|
149 |
+
# output
|
150 |
+
return x.type(out_dtype)
|
151 |
+
|
152 |
+
|
153 |
+
def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
|
154 |
+
b, lq, lk = q.size(0), q.size(1), k.size(1)
|
155 |
+
if q_lens is None:
|
156 |
+
q_lens = torch.tensor([lq] * b, dtype=torch.int32)
|
157 |
+
if k_lens is None:
|
158 |
+
k_lens = torch.tensor([lk] * b, dtype=torch.int32)
|
159 |
+
attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
|
160 |
+
for i in range(b):
|
161 |
+
q_len, k_len = q_lens[i], k_lens[i]
|
162 |
+
attn_mask[i, q_len:, :] = True
|
163 |
+
attn_mask[i, :, k_len:] = True
|
164 |
+
|
165 |
+
if causal:
|
166 |
+
causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
|
167 |
+
attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
|
168 |
+
|
169 |
+
attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
|
170 |
+
return attn_mask
|
171 |
+
|
172 |
+
|
173 |
+
def attention(
|
174 |
+
q,
|
175 |
+
k,
|
176 |
+
v,
|
177 |
+
q_lens=None,
|
178 |
+
k_lens=None,
|
179 |
+
dropout_p=0.,
|
180 |
+
softmax_scale=None,
|
181 |
+
q_scale=None,
|
182 |
+
causal=False,
|
183 |
+
window_size=(-1, -1),
|
184 |
+
deterministic=False,
|
185 |
+
dtype=torch.bfloat16,
|
186 |
+
fa_version=None,
|
187 |
+
):
|
188 |
+
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
189 |
+
return flash_attention(
|
190 |
+
q=q,
|
191 |
+
k=k,
|
192 |
+
v=v,
|
193 |
+
q_lens=q_lens,
|
194 |
+
k_lens=k_lens,
|
195 |
+
dropout_p=dropout_p,
|
196 |
+
softmax_scale=softmax_scale,
|
197 |
+
q_scale=q_scale,
|
198 |
+
causal=causal,
|
199 |
+
window_size=window_size,
|
200 |
+
deterministic=deterministic,
|
201 |
+
dtype=dtype,
|
202 |
+
version=fa_version,
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
if q_lens is not None or k_lens is not None:
|
206 |
+
warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
|
207 |
+
attn_mask = None
|
208 |
+
|
209 |
+
q = q.transpose(1, 2).to(dtype)
|
210 |
+
k = k.transpose(1, 2).to(dtype)
|
211 |
+
v = v.transpose(1, 2).to(dtype)
|
212 |
+
|
213 |
+
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
|
214 |
+
|
215 |
+
out = out.transpose(1, 2).contiguous()
|
216 |
+
return out
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
def sinusoidal_embedding_1d(dim, position):
|
221 |
+
# preprocess
|
222 |
+
assert dim % 2 == 0
|
223 |
+
half = dim // 2
|
224 |
+
position = position.type(torch.float64)
|
225 |
+
|
226 |
+
# calculation
|
227 |
+
sinusoid = torch.outer(
|
228 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
229 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
230 |
+
return x
|
231 |
+
|
232 |
+
|
233 |
+
@amp.autocast(enabled=False, device_type="cuda")
|
234 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
235 |
+
assert dim % 2 == 0
|
236 |
+
freqs = torch.outer(
|
237 |
+
torch.arange(max_seq_len),
|
238 |
+
1.0 / torch.pow(theta,
|
239 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
240 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
241 |
+
return freqs
|
242 |
+
|
243 |
+
|
244 |
+
@amp.autocast(enabled=False, device_type="cuda")
|
245 |
+
def rope_apply(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
|
246 |
+
batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
|
247 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
248 |
+
output = []
|
249 |
+
assert len(grid_sizes) == batch, "grid_sizes must have the same length as the batch size ([b, 3=[f, h, w])"
|
250 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
251 |
+
seq_len = f * h * w
|
252 |
+
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
|
253 |
+
seq_len, n, -1, 2))
|
254 |
+
freqs_i = torch.cat([
|
255 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
256 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
257 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
258 |
+
], dim=-1).reshape(seq_len, 1, -1)
|
259 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
260 |
+
|
261 |
+
if seq_len_actual > seq_len:
|
262 |
+
sequence_cond_seq_length = seq_len_actual - seq_len
|
263 |
+
if sequence_cond_seq_length == seq_len:
|
264 |
+
x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(seq_len_actual - seq_len, n, -1, 2))
|
265 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(2)
|
266 |
+
else:
|
267 |
+
sequence_cond_compressed_index = sequence_cond_compressed_indices[i]
|
268 |
+
sequence_cond_t_length = len(sequence_cond_compressed_index)
|
269 |
+
assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
|
270 |
+
x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(sequence_cond_seq_length, n, -1, 2))
|
271 |
+
freqs_i_sequence_cond = torch.cat([
|
272 |
+
freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
|
273 |
+
freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
|
274 |
+
freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
|
275 |
+
], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
|
276 |
+
x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(2)
|
277 |
+
x_i = torch.cat([x_i, x_i_sequence_cond])
|
278 |
+
|
279 |
+
output.append(x_i)
|
280 |
+
return torch.stack(output).float()
|
281 |
+
|
282 |
+
|
283 |
+
class WanRMSNorm(nn.Module):
|
284 |
+
|
285 |
+
def __init__(self, dim, eps=1e-5):
|
286 |
+
super().__init__()
|
287 |
+
self.dim = dim
|
288 |
+
self.eps = eps
|
289 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
293 |
+
|
294 |
+
def _norm(self, x):
|
295 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
296 |
+
|
297 |
+
|
298 |
+
class WanLayerNorm(nn.LayerNorm):
|
299 |
+
|
300 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
301 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
return super().forward(x.float()).type_as(x)
|
305 |
+
|
306 |
+
|
307 |
+
class WanSelfAttention(nn.Module):
|
308 |
+
|
309 |
+
def __init__(self,
|
310 |
+
dim,
|
311 |
+
num_heads,
|
312 |
+
window_size=(-1, -1),
|
313 |
+
qk_norm=True,
|
314 |
+
eps=1e-6):
|
315 |
+
assert dim % num_heads == 0
|
316 |
+
super().__init__()
|
317 |
+
self.dim = dim
|
318 |
+
self.num_heads = num_heads
|
319 |
+
self.head_dim = dim // num_heads
|
320 |
+
self.window_size = window_size
|
321 |
+
self.qk_norm = qk_norm
|
322 |
+
self.eps = eps
|
323 |
+
|
324 |
+
self.q = nn.Linear(dim, dim)
|
325 |
+
self.k = nn.Linear(dim, dim)
|
326 |
+
self.v = nn.Linear(dim, dim)
|
327 |
+
self.o = nn.Linear(dim, dim)
|
328 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
329 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
330 |
+
self.visualize_attention = False
|
331 |
+
|
332 |
+
def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
|
333 |
+
"""
|
334 |
+
Args:
|
335 |
+
x: [B, L, C].
|
336 |
+
seq_lens: [B].
|
337 |
+
grid_sizes: [B, 3=[f, h, w]].
|
338 |
+
freqs: [L, 2].
|
339 |
+
sequence_cond_compressed_indices: [B, T_sequence_condITION].
|
340 |
+
|
341 |
+
`f` in `grid_sizes` can less than the actual seq_lens (L),
|
342 |
+
which indicates full in-context condition (when L=2*f) or
|
343 |
+
sparse in-context condition (when `f` < L < 2*f and `sequence_cond_compressed_indices` is not None) is used.
|
344 |
+
"""
|
345 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
346 |
+
|
347 |
+
def qkv_fn(x):
|
348 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
349 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
350 |
+
v = self.v(x).view(b, s, n, d)
|
351 |
+
return q, k, v
|
352 |
+
|
353 |
+
q, k, v = qkv_fn(x)
|
354 |
+
|
355 |
+
q_rope = rope_apply(q, grid_sizes, freqs, sequence_cond_compressed_indices)
|
356 |
+
k_rope = rope_apply(k, grid_sizes, freqs, sequence_cond_compressed_indices)
|
357 |
+
|
358 |
+
if self.visualize_attention:
|
359 |
+
with torch.no_grad():
|
360 |
+
self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope) # CPU tesnor of [S, S]
|
361 |
+
self._last_grid_sizes = grid_sizes
|
362 |
+
self._last_seq_lens = seq_lens
|
363 |
+
|
364 |
+
x = flash_attention(
|
365 |
+
q=q_rope,
|
366 |
+
k=k_rope,
|
367 |
+
v=v,
|
368 |
+
k_lens=None,
|
369 |
+
window_size=self.window_size)
|
370 |
+
|
371 |
+
# output
|
372 |
+
x = x.flatten(2)
|
373 |
+
x = self.o(x)
|
374 |
+
return x
|
375 |
+
|
376 |
+
def _compute_attention_for_visualization(self, q, k):
|
377 |
+
"""Compute attention maps for visualization purposes"""
|
378 |
+
# b, _, n, d = q.shape
|
379 |
+
print("Computing attention maps for visualization")
|
380 |
+
# Reshape for attention computation
|
381 |
+
q = q.permute(0, 2, 1, 3) # [b, n, s, d]
|
382 |
+
k = k.permute(0, 2, 1, 3) # [b, n, s, d]
|
383 |
+
# query: b, n, s, d
|
384 |
+
print("q.shape=", q.shape)
|
385 |
+
print("k.shape=", k.shape)
|
386 |
+
attention_probs_list = []
|
387 |
+
for i in range(0, q.shape[1], 20):
|
388 |
+
print(f"Computing attention for head {i} to {i+20}")
|
389 |
+
query_attention = q[-1][i : i + 20]
|
390 |
+
key_attention = k[-1][i : i + 20]
|
391 |
+
identity_matrix = torch.eye(
|
392 |
+
query_attention.shape[-2],
|
393 |
+
device=query_attention.device,
|
394 |
+
dtype=query_attention.dtype,
|
395 |
+
) # shape=[s]
|
396 |
+
attention_probs_temp = torch.nn.functional.scaled_dot_product_attention(
|
397 |
+
query_attention,
|
398 |
+
key_attention,
|
399 |
+
identity_matrix,
|
400 |
+
attn_mask=None,
|
401 |
+
dropout_p=0.0,
|
402 |
+
is_causal=False,
|
403 |
+
)
|
404 |
+
attention_probs_list.append(attention_probs_temp.detach().cpu())
|
405 |
+
del (
|
406 |
+
query_attention,
|
407 |
+
key_attention,
|
408 |
+
identity_matrix,
|
409 |
+
attention_probs_temp,
|
410 |
+
)
|
411 |
+
attention_probs = torch.mean(torch.cat(attention_probs_list), dim=0).float().numpy()
|
412 |
+
print("Attention maps computed. Shape=", attention_probs.shape)
|
413 |
+
# Only keep attention maps, don't compute the output
|
414 |
+
return attention_probs # [s, s]
|
415 |
+
|
416 |
+
|
417 |
+
class WanT2VCrossAttention(WanSelfAttention):
|
418 |
+
|
419 |
+
def forward(self, x, context, context_lens):
|
420 |
+
"""
|
421 |
+
x: [B, L1, C].
|
422 |
+
context: [B, L2, C].
|
423 |
+
context_lens: [B].
|
424 |
+
"""
|
425 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
426 |
+
|
427 |
+
# compute query, key, value
|
428 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
429 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
430 |
+
v = self.v(context).view(b, -1, n, d)
|
431 |
+
|
432 |
+
# compute attention
|
433 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
434 |
+
|
435 |
+
# output
|
436 |
+
x = x.flatten(2)
|
437 |
+
x = self.o(x)
|
438 |
+
return x
|
439 |
+
|
440 |
+
|
441 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
442 |
+
|
443 |
+
def __init__(self,
|
444 |
+
dim,
|
445 |
+
num_heads,
|
446 |
+
window_size=(-1, -1),
|
447 |
+
qk_norm=True,
|
448 |
+
eps=1e-6):
|
449 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
450 |
+
|
451 |
+
self.k_img = nn.Linear(dim, dim)
|
452 |
+
self.v_img = nn.Linear(dim, dim)
|
453 |
+
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
454 |
+
self.norm_k_img = WanRMSNorm(
|
455 |
+
dim, eps=eps) if qk_norm else nn.Identity()
|
456 |
+
|
457 |
+
def forward(self, x, context, context_lens):
|
458 |
+
"""
|
459 |
+
x: [B, L1, C].
|
460 |
+
context: [B, L2, C].
|
461 |
+
context_lens: [B].
|
462 |
+
"""
|
463 |
+
context_img = context[:, :257]
|
464 |
+
context = context[:, 257:]
|
465 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
466 |
+
|
467 |
+
# compute query, key, value
|
468 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
469 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
470 |
+
v = self.v(context).view(b, -1, n, d)
|
471 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
472 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
473 |
+
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
474 |
+
# compute attention
|
475 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
476 |
+
|
477 |
+
# output
|
478 |
+
x = x.flatten(2)
|
479 |
+
img_x = img_x.flatten(2)
|
480 |
+
x = x + img_x
|
481 |
+
x = self.o(x)
|
482 |
+
return x
|
483 |
+
|
484 |
+
|
485 |
+
WANX_CROSSATTENTION_CLASSES = {
|
486 |
+
't2v_cross_attn': WanT2VCrossAttention,
|
487 |
+
'i2v_cross_attn': WanI2VCrossAttention,
|
488 |
+
}
|
489 |
+
|
490 |
+
|
491 |
+
class WanAttentionBlock(nn.Module):
|
492 |
+
|
493 |
+
def __init__(self,
|
494 |
+
cross_attn_type,
|
495 |
+
dim,
|
496 |
+
ffn_dim,
|
497 |
+
num_heads,
|
498 |
+
window_size=(-1, -1),
|
499 |
+
qk_norm=True,
|
500 |
+
cross_attn_norm=False,
|
501 |
+
eps=1e-6,
|
502 |
+
use_local_lora=False,
|
503 |
+
use_dera=False,
|
504 |
+
dera_rank=None,
|
505 |
+
use_dera_spatial=True,
|
506 |
+
use_dera_temporal=True):
|
507 |
+
super().__init__()
|
508 |
+
self.dim = dim
|
509 |
+
self.ffn_dim = ffn_dim
|
510 |
+
self.num_heads = num_heads
|
511 |
+
self.window_size = window_size
|
512 |
+
self.qk_norm = qk_norm
|
513 |
+
self.cross_attn_norm = cross_attn_norm
|
514 |
+
self.eps = eps
|
515 |
+
|
516 |
+
# layers
|
517 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
518 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
519 |
+
self.norm3 = WanLayerNorm(
|
520 |
+
dim, eps,
|
521 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
522 |
+
self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
|
523 |
+
dim, num_heads, (-1, -1), qk_norm, eps)
|
524 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
525 |
+
self.ffn = nn.Sequential(
|
526 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
527 |
+
nn.Linear(ffn_dim, dim))
|
528 |
+
|
529 |
+
# modulation
|
530 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
531 |
+
|
532 |
+
self.use_local_lora = use_local_lora
|
533 |
+
if use_local_lora:
|
534 |
+
from .local_lora import LocalLoRA
|
535 |
+
self.local_lora = LocalLoRA(dim=dim, rank=64, kernel_size=(3, 3), stride=(1, 1))
|
536 |
+
|
537 |
+
self.use_dera = use_dera
|
538 |
+
if use_dera:
|
539 |
+
from .dera import DeRA
|
540 |
+
self.dera = DeRA(dim, rank=dera_rank, use_spatial=use_dera_spatial, use_temporal=use_dera_temporal)
|
541 |
+
|
542 |
+
def forward(
|
543 |
+
self,
|
544 |
+
x,
|
545 |
+
e,
|
546 |
+
seq_lens,
|
547 |
+
grid_sizes,
|
548 |
+
freqs,
|
549 |
+
context,
|
550 |
+
context_lens,
|
551 |
+
sequence_cond_compressed_indices,
|
552 |
+
dera_freqs=None
|
553 |
+
):
|
554 |
+
assert e.dtype == torch.float32
|
555 |
+
with amp.autocast(dtype=torch.float32, device_type="cuda"):
|
556 |
+
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
|
557 |
+
assert e[0].dtype == torch.float32
|
558 |
+
|
559 |
+
# self-attention
|
560 |
+
x_self_attn_input = self.norm1(x).float() * (1 + e[1]) + e[0]
|
561 |
+
y = self.self_attn(x_self_attn_input, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
|
562 |
+
if self.use_local_lora:
|
563 |
+
y = y + self.local_lora(x_self_attn_input, grid_sizes)
|
564 |
+
|
565 |
+
if self.use_dera:
|
566 |
+
y = y + self.dera(x_self_attn_input, seq_lens, grid_sizes, dera_freqs, sequence_cond_compressed_indices)
|
567 |
+
|
568 |
+
with amp.autocast(dtype=torch.float32, device_type="cuda"):
|
569 |
+
x = x + y * e[2]
|
570 |
+
|
571 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
572 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
573 |
+
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
574 |
+
with amp.autocast(dtype=torch.float32, device_type="cuda"):
|
575 |
+
x = x + y * e[5]
|
576 |
+
return x
|
577 |
+
|
578 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
579 |
+
return x
|
580 |
+
|
581 |
+
|
582 |
+
class Head(nn.Module):
|
583 |
+
|
584 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
585 |
+
super().__init__()
|
586 |
+
self.dim = dim
|
587 |
+
self.out_dim = out_dim
|
588 |
+
self.patch_size = patch_size
|
589 |
+
self.eps = eps
|
590 |
+
|
591 |
+
# layers
|
592 |
+
out_dim = math.prod(patch_size) * out_dim
|
593 |
+
self.norm = WanLayerNorm(dim, eps)
|
594 |
+
self.head = nn.Linear(dim, out_dim)
|
595 |
+
|
596 |
+
# modulation
|
597 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
598 |
+
|
599 |
+
def forward(self, x, e):
|
600 |
+
assert e.dtype == torch.float32
|
601 |
+
with amp.autocast(dtype=torch.float32, device_type="cuda"):
|
602 |
+
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
603 |
+
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
604 |
+
return x
|
605 |
+
|
606 |
+
|
607 |
+
class MLPProj(torch.nn.Module):
|
608 |
+
|
609 |
+
def __init__(self, in_dim, out_dim):
|
610 |
+
super().__init__()
|
611 |
+
|
612 |
+
self.proj = torch.nn.Sequential(
|
613 |
+
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
614 |
+
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
615 |
+
torch.nn.LayerNorm(out_dim))
|
616 |
+
|
617 |
+
def forward(self, image_embeds):
|
618 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
619 |
+
return clip_extra_context_tokens
|
620 |
+
|
621 |
+
|
622 |
+
class WanModel(nn.Module):
|
623 |
+
|
624 |
+
def __init__(self,
|
625 |
+
model_type='t2v',
|
626 |
+
patch_size=(1, 2, 2),
|
627 |
+
text_len=512,
|
628 |
+
in_dim=16,
|
629 |
+
dim=2048,
|
630 |
+
ffn_dim=8192,
|
631 |
+
freq_dim=256,
|
632 |
+
text_dim=4096,
|
633 |
+
out_dim=16,
|
634 |
+
num_heads=16,
|
635 |
+
num_layers=32,
|
636 |
+
window_size=(-1, -1),
|
637 |
+
qk_norm=True,
|
638 |
+
cross_attn_norm=False,
|
639 |
+
eps=1e-6,
|
640 |
+
use_local_lora=False,
|
641 |
+
use_dera=False,
|
642 |
+
dera_rank=None,
|
643 |
+
use_dera_spatial=True,
|
644 |
+
use_dera_temporal=True,
|
645 |
+
use_sequence_cond=False,
|
646 |
+
sequence_cond_in_dim=None,
|
647 |
+
sequence_cond_mode=None,
|
648 |
+
use_channel_cond=False,
|
649 |
+
channel_cond_in_dim=None,
|
650 |
+
use_sequence_cond_position_aware_residual=False,
|
651 |
+
use_sequence_cond_loss=False
|
652 |
+
):
|
653 |
+
super().__init__()
|
654 |
+
|
655 |
+
assert model_type in ['t2v', 'i2v']
|
656 |
+
self.model_type = model_type
|
657 |
+
|
658 |
+
self.patch_size = patch_size
|
659 |
+
self.text_len = text_len
|
660 |
+
self.in_dim = in_dim
|
661 |
+
self.dim = dim
|
662 |
+
self.ffn_dim = ffn_dim
|
663 |
+
self.freq_dim = freq_dim
|
664 |
+
self.text_dim = text_dim
|
665 |
+
self.out_dim = out_dim
|
666 |
+
self.num_heads = num_heads
|
667 |
+
self.num_layers = num_layers
|
668 |
+
self.window_size = window_size
|
669 |
+
self.qk_norm = qk_norm
|
670 |
+
self.cross_attn_norm = cross_attn_norm
|
671 |
+
self.eps = eps
|
672 |
+
|
673 |
+
self.use_local_lora = use_local_lora
|
674 |
+
self.use_dera = use_dera
|
675 |
+
|
676 |
+
# embeddings
|
677 |
+
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
678 |
+
self.text_embedding = nn.Sequential(
|
679 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
680 |
+
nn.Linear(dim, dim))
|
681 |
+
|
682 |
+
self.time_embedding = nn.Sequential(
|
683 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
684 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
685 |
+
|
686 |
+
# blocks
|
687 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
688 |
+
self.blocks = nn.ModuleList([
|
689 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
690 |
+
window_size, qk_norm, cross_attn_norm, eps, use_local_lora=use_local_lora,
|
691 |
+
use_dera=use_dera, dera_rank=dera_rank, use_dera_spatial=use_dera_spatial, use_dera_temporal=use_dera_temporal)
|
692 |
+
for _ in range(num_layers)
|
693 |
+
])
|
694 |
+
|
695 |
+
# head
|
696 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
697 |
+
|
698 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
699 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
700 |
+
d = dim // num_heads
|
701 |
+
self.freqs = torch.cat([
|
702 |
+
rope_params(1024, d - 4 * (d // 6)),
|
703 |
+
rope_params(1024, 2 * (d // 6)),
|
704 |
+
rope_params(1024, 2 * (d // 6))
|
705 |
+
], dim=1)
|
706 |
+
|
707 |
+
if self.use_dera:
|
708 |
+
dera_d = dera_rank // 4 # (18)
|
709 |
+
self.dera_freqs = torch.cat([
|
710 |
+
rope_params(1024, dera_d - 4 * (dera_d // 6)),
|
711 |
+
rope_params(1024, 2 * (dera_d // 6)),
|
712 |
+
rope_params(1024, 2 * (dera_d // 6))
|
713 |
+
], dim=1)
|
714 |
+
else:
|
715 |
+
self.dera_freqs = None
|
716 |
+
|
717 |
+
if model_type == 'i2v':
|
718 |
+
self.img_emb = MLPProj(1280, dim)
|
719 |
+
|
720 |
+
self.init_weights()
|
721 |
+
|
722 |
+
self.use_sequence_cond = use_sequence_cond
|
723 |
+
self.sequence_cond_in_dim = sequence_cond_in_dim
|
724 |
+
self.sequence_cond_mode = sequence_cond_mode
|
725 |
+
if use_sequence_cond:
|
726 |
+
assert sequence_cond_in_dim is not None, "`sequence_cond_in_dim` must be provided when `use_sequence_cond` is True"
|
727 |
+
self.sequence_cond_patch_embedding = nn.Conv3d(sequence_cond_in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
728 |
+
self.sequence_cond_identifier = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5)
|
729 |
+
|
730 |
+
self.use_channel_cond = use_channel_cond
|
731 |
+
self.channel_cond_in_dim = channel_cond_in_dim
|
732 |
+
if use_channel_cond:
|
733 |
+
assert channel_cond_in_dim is not None, "`channel_cond_in_dim` must be provided when `use_channel_cond` is True"
|
734 |
+
self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
|
735 |
+
if use_sequence_cond_position_aware_residual:
|
736 |
+
self.sequence_cond_residual_proj = nn.Linear(dim, dim, bias=False)
|
737 |
+
self.sequence_cond_residual_proj.weight.data.zero_()
|
738 |
+
|
739 |
+
self.use_sequence_cond_loss = use_sequence_cond_loss
|
740 |
+
if self.use_sequence_cond_loss:
|
741 |
+
self.sequence_latent_to_cond_proj = nn.Linear(dim, dim, bias=False)
|
742 |
+
self.sequence_latent_to_cond_proj.weight.data.zero_()
|
743 |
+
self.head_sequence_cond_out = nn.Linear(dim, math.prod(patch_size) * out_dim)
|
744 |
+
|
745 |
+
def copy_sequence_cond_patch_embedding_weights(self):
|
746 |
+
size_patch_embedding = self.patch_embedding.weight.size(1)
|
747 |
+
size_sequence_cond_patch_embedding = self.sequence_cond_patch_embedding.weight.size(1)
|
748 |
+
self.sequence_cond_patch_embedding.weight.data = self.patch_embedding.weight.data[:, size_patch_embedding - size_sequence_cond_patch_embedding:, :, :, :].clone()
|
749 |
+
if self.patch_embedding.bias is not None:
|
750 |
+
self.sequence_cond_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
|
751 |
+
|
752 |
+
def copy_patch_embedding_weights_for_channel_cond(self):
|
753 |
+
original_patch_in_channels = self.patch_embedding.in_channels
|
754 |
+
new_patch_embedding = nn.Conv3d(in_channels=original_patch_in_channels + self.channel_cond_in_dim,
|
755 |
+
out_channels=self.dim, kernel_size=self.patch_size, stride=self.patch_size)
|
756 |
+
new_patch_embedding.weight.data[:, :original_patch_in_channels, :, :, :] = self.patch_embedding.weight.data.clone()
|
757 |
+
if self.patch_embedding.bias is not None:
|
758 |
+
new_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
|
759 |
+
del self.patch_embedding
|
760 |
+
self.patch_embedding = new_patch_embedding
|
761 |
+
|
762 |
+
def forward(
|
763 |
+
self,
|
764 |
+
x,
|
765 |
+
timestep,
|
766 |
+
context,
|
767 |
+
seq_len,
|
768 |
+
clip_fea=None,
|
769 |
+
y=None,
|
770 |
+
use_gradient_checkpointing=False,
|
771 |
+
sequence_cond=None,
|
772 |
+
sequence_cond_compressed_indices=None,
|
773 |
+
channel_cond=None,
|
774 |
+
sequence_cond_residual_scale=1.0,
|
775 |
+
**kwargs,
|
776 |
+
):
|
777 |
+
"""
|
778 |
+
x: A list of videos each with shape [C, T, H, W].
|
779 |
+
t: [B].
|
780 |
+
context: A list of text embeddings each with shape [L, C].
|
781 |
+
sequence_cond: A list of conditional frames each with shape [C, T_sequence_cond, H, W].
|
782 |
+
sequence_cond_compressed_indices: [B, T_sequence_cond] Indices for any additional conditional information, where T_sequence_cond < T. For sparse mode only.
|
783 |
+
|
784 |
+
|
785 |
+
Note:
|
786 |
+
sequence_cond will be injected into the model as an additional input sequence, i.e., sequence dimension.
|
787 |
+
channel_cond will be injected into the model in the input' channel dimension.
|
788 |
+
|
789 |
+
Examples:
|
790 |
+
1) for extra cond case:
|
791 |
+
# given x: [B, C, T, H, W] ----> [B, L=T*H*W, C] --patch_embedding--> [B, L, D]
|
792 |
+
# sequence_cond: [B, C_sequence_cond, T_sequence_cond, H, W] ----> [B, L_sequence_cond=T_sequence_cond*H*W, C_sequence_cond] --sequence_cond_embedding--> [B, L_sequence_cond, D]
|
793 |
+
x = torch.concat([x, sequence_cond], dim=2) # Concat on sequence dimension after patch/extra cond embedding
|
794 |
+
# after concat, x: [B, L+L_sequence_cond, D]
|
795 |
+
2) for channel cond case:
|
796 |
+
given x: [B, C, T, H, W]
|
797 |
+
channel_cond: [B, C_CHANNEL_COND, T, H, W]
|
798 |
+
x = torch.concat([x, channel_cond], dim=1) # Concat on channel dimension before patch/extra cond embedding
|
799 |
+
# x: [B, C + C_CHANNEL_COND, T, H, W] --patch_embedding(requires param copy and tuning)--> [B, L=T*H*W, D]
|
800 |
+
"""
|
801 |
+
if self.model_type == 'i2v':
|
802 |
+
assert clip_fea is not None and y is not None
|
803 |
+
# params
|
804 |
+
device = x[0].device
|
805 |
+
if self.freqs.device != device:
|
806 |
+
self.freqs = self.freqs.to(device)
|
807 |
+
if self.dera_freqs is not None and self.dera_freqs.device != device:
|
808 |
+
self.dera_freqs = self.dera_freqs.to(device)
|
809 |
+
|
810 |
+
if y is not None:
|
811 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
812 |
+
|
813 |
+
if channel_cond is not None:
|
814 |
+
assert self.use_channel_cond, "forward argument `channel_cond` is provided but model property `self.use_channel_cond` is False"
|
815 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, channel_cond)]
|
816 |
+
|
817 |
+
# embeddings
|
818 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
819 |
+
grid_sizes = torch.stack(
|
820 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
821 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
822 |
+
x = torch.cat(x, dim=0)
|
823 |
+
|
824 |
+
if sequence_cond is not None:
|
825 |
+
assert self.use_sequence_cond, "forward argument `sequence_cond` is provided but model property `self.use_sequence_cond` is False"
|
826 |
+
sequence_cond = [self.sequence_cond_patch_embedding(u.unsqueeze(0)) for u in sequence_cond]
|
827 |
+
sequence_cond = [u.flatten(2).transpose(1, 2) + self.sequence_cond_identifier for u in sequence_cond]
|
828 |
+
sequence_cond = torch.concat(sequence_cond, dim=0)
|
829 |
+
|
830 |
+
x = torch.concat([x, sequence_cond], dim=1)
|
831 |
+
|
832 |
+
actual_seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
833 |
+
|
834 |
+
# time embeddings
|
835 |
+
with amp.autocast(dtype=torch.float32, device_type="cuda"):
|
836 |
+
e = self.time_embedding(
|
837 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep).float())
|
838 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
839 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
840 |
+
|
841 |
+
# context
|
842 |
+
context_lens = None
|
843 |
+
context = self.text_embedding(
|
844 |
+
torch.stack([
|
845 |
+
torch.cat(
|
846 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
847 |
+
for u in context
|
848 |
+
]))
|
849 |
+
|
850 |
+
if clip_fea is not None:
|
851 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
852 |
+
context = torch.concat([context_clip, context], dim=1)
|
853 |
+
|
854 |
+
# arguments
|
855 |
+
kwargs = dict(e=e0, seq_lens=actual_seq_lens, grid_sizes=grid_sizes,
|
856 |
+
freqs=self.freqs, context=context, context_lens=context_lens,
|
857 |
+
sequence_cond_compressed_indices=sequence_cond_compressed_indices, dera_freqs=self.dera_freqs)
|
858 |
+
|
859 |
+
def create_custom_forward(module):
|
860 |
+
def custom_forward(*inputs, **kwargs):
|
861 |
+
return module(*inputs, **kwargs)
|
862 |
+
return custom_forward
|
863 |
+
|
864 |
+
for block_idx, block in enumerate(self.blocks):
|
865 |
+
if self.training and use_gradient_checkpointing:
|
866 |
+
x = torch.utils.checkpoint.checkpoint(
|
867 |
+
create_custom_forward(block),
|
868 |
+
x, **kwargs,
|
869 |
+
use_reentrant=False,
|
870 |
+
)
|
871 |
+
else:
|
872 |
+
x = block(x, **kwargs)
|
873 |
+
|
874 |
+
if self.use_sequence_cond_loss and block_idx == len(self.blocks) - 3:
|
875 |
+
# This this function, the context length will be extended from (N+C) to 2N, where C is the length of the sparse sequence cond.
|
876 |
+
x_ori = x[:, :seq_len, :]
|
877 |
+
x_ori_projected = self.sequence_latent_to_cond_proj(x_ori)
|
878 |
+
x_seq_cond = x[:, seq_len:, :]
|
879 |
+
seq_cond_length = len(sequence_cond_compressed_indices[0])
|
880 |
+
x_ori_projected = rearrange(x_ori_projected, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
|
881 |
+
x_seq_cond = rearrange(x_seq_cond, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
|
882 |
+
x_ori_projected[:, :, sequence_cond_compressed_indices[0], :, :] += x_seq_cond
|
883 |
+
x_ori_projected = rearrange(x_ori_projected, 'b c t h w -> b (t h w) c')
|
884 |
+
x = torch.concat([x_ori, x_ori_projected], dim=1)
|
885 |
+
# Let the later blocks generate sketches at the full seqeuence length
|
886 |
+
|
887 |
+
if self.use_sequence_cond_position_aware_residual and block_idx < len(self.blocks) - 1:
|
888 |
+
# Apply the sequence condition position-aware residual for all blocks except the last one
|
889 |
+
x_ori = x[:, :seq_len, :]
|
890 |
+
x_seq_cond = x[:, seq_len:, :]
|
891 |
+
x_seq_cond_porjected = self.sequence_cond_residual_proj(x_seq_cond)
|
892 |
+
assert x_ori.shape[0] == 1, "Only support batch size 1 for `sequence_cond_position_aware_residual`."
|
893 |
+
seq_cond_length = len(sequence_cond_compressed_indices[0])
|
894 |
+
x_ori = rearrange(x_ori, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
|
895 |
+
x_seq_cond_porjected = rearrange(x_seq_cond_porjected, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
|
896 |
+
|
897 |
+
x_ori[:, :, sequence_cond_compressed_indices[0], :, :] = x_ori[:, :, sequence_cond_compressed_indices[0], :, :] + x_seq_cond_porjected * sequence_cond_residual_scale
|
898 |
+
x_ori = rearrange(x_ori, 'b c t h w -> b (t h w) c')
|
899 |
+
x = torch.concat([x_ori, x_seq_cond], dim=1)
|
900 |
+
|
901 |
+
if sequence_cond is not None:
|
902 |
+
if self.use_sequence_cond_loss:
|
903 |
+
sequence_cond_out = x[:, seq_len:, :]
|
904 |
+
sequence_cond_out = self.unpatchify(sequence_cond_out, grid_sizes) # sequence_cond_grid_sizes
|
905 |
+
sequence_cond_out = torch.stack(sequence_cond_out).float() # b, c, t, h, w
|
906 |
+
else:
|
907 |
+
sequence_cond_out = None
|
908 |
+
x = x[:, :seq_len, :]
|
909 |
+
# head
|
910 |
+
x = self.head(x, e)
|
911 |
+
|
912 |
+
# unpatchify
|
913 |
+
x = self.unpatchify(x, grid_sizes)
|
914 |
+
x = torch.stack(x).float()
|
915 |
+
if sequence_cond is not None and self.use_sequence_cond_loss:
|
916 |
+
return x, sequence_cond_out
|
917 |
+
return x
|
918 |
+
|
919 |
+
def unpatchify(self, x, grid_sizes):
|
920 |
+
c = self.out_dim
|
921 |
+
out = []
|
922 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
923 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
924 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
925 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
926 |
+
out.append(u)
|
927 |
+
return out
|
928 |
+
|
929 |
+
def init_weights(self):
|
930 |
+
for m in self.modules():
|
931 |
+
if isinstance(m, nn.Linear):
|
932 |
+
nn.init.xavier_uniform_(m.weight)
|
933 |
+
if m.bias is not None:
|
934 |
+
nn.init.zeros_(m.bias)
|
935 |
+
|
936 |
+
# init embeddings
|
937 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
938 |
+
for m in self.text_embedding.modules():
|
939 |
+
if isinstance(m, nn.Linear):
|
940 |
+
nn.init.normal_(m.weight, std=.02)
|
941 |
+
for m in self.time_embedding.modules():
|
942 |
+
if isinstance(m, nn.Linear):
|
943 |
+
nn.init.normal_(m.weight, std=.02)
|
944 |
+
|
945 |
+
# init output layer
|
946 |
+
nn.init.zeros_(self.head.head.weight)
|
947 |
+
|
948 |
+
@staticmethod
|
949 |
+
def state_dict_converter():
|
950 |
+
return WanModelStateDictConverter()
|
951 |
+
|
952 |
+
|
953 |
+
class WanModelStateDictConverter:
|
954 |
+
def __init__(self):
|
955 |
+
pass
|
956 |
+
|
957 |
+
def from_diffusers(self, state_dict):
|
958 |
+
rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
959 |
+
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
960 |
+
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
961 |
+
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
962 |
+
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
963 |
+
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
964 |
+
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
965 |
+
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
966 |
+
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
967 |
+
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
968 |
+
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
969 |
+
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
970 |
+
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
971 |
+
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
972 |
+
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
973 |
+
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
974 |
+
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
975 |
+
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
976 |
+
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
977 |
+
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
978 |
+
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
979 |
+
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
980 |
+
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
981 |
+
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
982 |
+
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
983 |
+
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
984 |
+
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
985 |
+
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
986 |
+
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
987 |
+
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
988 |
+
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
989 |
+
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
990 |
+
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
991 |
+
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
992 |
+
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
993 |
+
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
994 |
+
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
995 |
+
"patch_embedding.bias": "patch_embedding.bias",
|
996 |
+
"patch_embedding.weight": "patch_embedding.weight",
|
997 |
+
"scale_shift_table": "head.modulation",
|
998 |
+
"proj_out.bias": "head.head.bias",
|
999 |
+
"proj_out.weight": "head.head.weight",
|
1000 |
+
}
|
1001 |
+
state_dict_ = {}
|
1002 |
+
for name, param in state_dict.items():
|
1003 |
+
if name in rename_dict:
|
1004 |
+
state_dict_[rename_dict[name]] = param
|
1005 |
+
else:
|
1006 |
+
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
1007 |
+
if name_ in rename_dict:
|
1008 |
+
name_ = rename_dict[name_]
|
1009 |
+
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
1010 |
+
state_dict_[name_] = param
|
1011 |
+
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
1012 |
+
config = {
|
1013 |
+
"model_type": "t2v",
|
1014 |
+
"patch_size": (1, 2, 2),
|
1015 |
+
"text_len": 512,
|
1016 |
+
"in_dim": 16,
|
1017 |
+
"dim": 5120,
|
1018 |
+
"ffn_dim": 13824,
|
1019 |
+
"freq_dim": 256,
|
1020 |
+
"text_dim": 4096,
|
1021 |
+
"out_dim": 16,
|
1022 |
+
"num_heads": 40,
|
1023 |
+
"num_layers": 40,
|
1024 |
+
"window_size": (-1, -1),
|
1025 |
+
"qk_norm": True,
|
1026 |
+
"cross_attn_norm": True,
|
1027 |
+
"eps": 1e-6,
|
1028 |
+
}
|
1029 |
+
else:
|
1030 |
+
config = {}
|
1031 |
+
return state_dict_, config
|
1032 |
+
|
1033 |
+
def from_civitai(self, state_dict):
|
1034 |
+
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
1035 |
+
config = {
|
1036 |
+
"model_type": "t2v",
|
1037 |
+
"patch_size": (1, 2, 2),
|
1038 |
+
"text_len": 512,
|
1039 |
+
"in_dim": 16,
|
1040 |
+
"dim": 1536,
|
1041 |
+
"ffn_dim": 8960,
|
1042 |
+
"freq_dim": 256,
|
1043 |
+
"text_dim": 4096,
|
1044 |
+
"out_dim": 16,
|
1045 |
+
"num_heads": 12,
|
1046 |
+
"num_layers": 30,
|
1047 |
+
"window_size": (-1, -1),
|
1048 |
+
"qk_norm": True,
|
1049 |
+
"cross_attn_norm": True,
|
1050 |
+
"eps": 1e-6,
|
1051 |
+
}
|
1052 |
+
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
1053 |
+
config = {
|
1054 |
+
"model_type": "t2v",
|
1055 |
+
"patch_size": (1, 2, 2),
|
1056 |
+
"text_len": 512,
|
1057 |
+
"in_dim": 16,
|
1058 |
+
"dim": 5120,
|
1059 |
+
"ffn_dim": 13824,
|
1060 |
+
"freq_dim": 256,
|
1061 |
+
"text_dim": 4096,
|
1062 |
+
"out_dim": 16,
|
1063 |
+
"num_heads": 40,
|
1064 |
+
"num_layers": 40,
|
1065 |
+
"window_size": (-1, -1),
|
1066 |
+
"qk_norm": True,
|
1067 |
+
"cross_attn_norm": True,
|
1068 |
+
"eps": 1e-6,
|
1069 |
+
}
|
1070 |
+
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
1071 |
+
config = {
|
1072 |
+
"model_type": "i2v",
|
1073 |
+
"patch_size": (1, 2, 2),
|
1074 |
+
"text_len": 512,
|
1075 |
+
"in_dim": 36,
|
1076 |
+
"dim": 5120,
|
1077 |
+
"ffn_dim": 13824,
|
1078 |
+
"freq_dim": 256,
|
1079 |
+
"text_dim": 4096,
|
1080 |
+
"out_dim": 16,
|
1081 |
+
"num_heads": 40,
|
1082 |
+
"num_layers": 40,
|
1083 |
+
"window_size": (-1, -1),
|
1084 |
+
"qk_norm": True,
|
1085 |
+
"cross_attn_norm": True,
|
1086 |
+
"eps": 1e-6,
|
1087 |
+
}
|
1088 |
+
else:
|
1089 |
+
config = {}
|
1090 |
+
return state_dict, config
|
model/image_encoder.py
ADDED
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Concise re-implementation of
|
3 |
+
``https://github.com/openai/CLIP'' and
|
4 |
+
``https://github.com/mlfoundations/open_clip''.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision.transforms as T
|
11 |
+
from .dit import flash_attention
|
12 |
+
|
13 |
+
|
14 |
+
class SelfAttention(nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
17 |
+
assert dim % num_heads == 0
|
18 |
+
super().__init__()
|
19 |
+
self.dim = dim
|
20 |
+
self.num_heads = num_heads
|
21 |
+
self.head_dim = dim // num_heads
|
22 |
+
self.eps = eps
|
23 |
+
|
24 |
+
# layers
|
25 |
+
self.q = nn.Linear(dim, dim)
|
26 |
+
self.k = nn.Linear(dim, dim)
|
27 |
+
self.v = nn.Linear(dim, dim)
|
28 |
+
self.o = nn.Linear(dim, dim)
|
29 |
+
self.dropout = nn.Dropout(dropout)
|
30 |
+
|
31 |
+
def forward(self, x, mask):
|
32 |
+
"""
|
33 |
+
x: [B, L, C].
|
34 |
+
"""
|
35 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
36 |
+
|
37 |
+
# compute query, key, value
|
38 |
+
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
39 |
+
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
40 |
+
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
41 |
+
|
42 |
+
# compute attention
|
43 |
+
p = self.dropout.p if self.training else 0.0
|
44 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
45 |
+
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
46 |
+
|
47 |
+
# output
|
48 |
+
x = self.o(x)
|
49 |
+
x = self.dropout(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class AttentionBlock(nn.Module):
|
54 |
+
|
55 |
+
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
56 |
+
super().__init__()
|
57 |
+
self.dim = dim
|
58 |
+
self.num_heads = num_heads
|
59 |
+
self.post_norm = post_norm
|
60 |
+
self.eps = eps
|
61 |
+
|
62 |
+
# layers
|
63 |
+
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
64 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
65 |
+
self.ffn = nn.Sequential(
|
66 |
+
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
|
67 |
+
nn.Dropout(dropout))
|
68 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
69 |
+
|
70 |
+
def forward(self, x, mask):
|
71 |
+
if self.post_norm:
|
72 |
+
x = self.norm1(x + self.attn(x, mask))
|
73 |
+
x = self.norm2(x + self.ffn(x))
|
74 |
+
else:
|
75 |
+
x = x + self.attn(self.norm1(x), mask)
|
76 |
+
x = x + self.ffn(self.norm2(x))
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class XLMRoberta(nn.Module):
|
81 |
+
"""
|
82 |
+
XLMRobertaModel with no pooler and no LM head.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self,
|
86 |
+
vocab_size=250002,
|
87 |
+
max_seq_len=514,
|
88 |
+
type_size=1,
|
89 |
+
pad_id=1,
|
90 |
+
dim=1024,
|
91 |
+
num_heads=16,
|
92 |
+
num_layers=24,
|
93 |
+
post_norm=True,
|
94 |
+
dropout=0.1,
|
95 |
+
eps=1e-5):
|
96 |
+
super().__init__()
|
97 |
+
self.vocab_size = vocab_size
|
98 |
+
self.max_seq_len = max_seq_len
|
99 |
+
self.type_size = type_size
|
100 |
+
self.pad_id = pad_id
|
101 |
+
self.dim = dim
|
102 |
+
self.num_heads = num_heads
|
103 |
+
self.num_layers = num_layers
|
104 |
+
self.post_norm = post_norm
|
105 |
+
self.eps = eps
|
106 |
+
|
107 |
+
# embeddings
|
108 |
+
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
109 |
+
self.type_embedding = nn.Embedding(type_size, dim)
|
110 |
+
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
111 |
+
self.dropout = nn.Dropout(dropout)
|
112 |
+
|
113 |
+
# blocks
|
114 |
+
self.blocks = nn.ModuleList([
|
115 |
+
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
|
116 |
+
for _ in range(num_layers)
|
117 |
+
])
|
118 |
+
|
119 |
+
# norm layer
|
120 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
121 |
+
|
122 |
+
def forward(self, ids):
|
123 |
+
"""
|
124 |
+
ids: [B, L] of torch.LongTensor.
|
125 |
+
"""
|
126 |
+
b, s = ids.shape
|
127 |
+
mask = ids.ne(self.pad_id).long()
|
128 |
+
|
129 |
+
# embeddings
|
130 |
+
x = self.token_embedding(ids) + \
|
131 |
+
self.type_embedding(torch.zeros_like(ids)) + \
|
132 |
+
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
133 |
+
if self.post_norm:
|
134 |
+
x = self.norm(x)
|
135 |
+
x = self.dropout(x)
|
136 |
+
|
137 |
+
# blocks
|
138 |
+
mask = torch.where(
|
139 |
+
mask.view(b, 1, 1, s).gt(0), 0.0,
|
140 |
+
torch.finfo(x.dtype).min)
|
141 |
+
for block in self.blocks:
|
142 |
+
x = block(x, mask)
|
143 |
+
|
144 |
+
# output
|
145 |
+
if not self.post_norm:
|
146 |
+
x = self.norm(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
def xlm_roberta_large(pretrained=False,
|
151 |
+
return_tokenizer=False,
|
152 |
+
device='cpu',
|
153 |
+
**kwargs):
|
154 |
+
"""
|
155 |
+
XLMRobertaLarge adapted from Huggingface.
|
156 |
+
"""
|
157 |
+
# params
|
158 |
+
cfg = dict(
|
159 |
+
vocab_size=250002,
|
160 |
+
max_seq_len=514,
|
161 |
+
type_size=1,
|
162 |
+
pad_id=1,
|
163 |
+
dim=1024,
|
164 |
+
num_heads=16,
|
165 |
+
num_layers=24,
|
166 |
+
post_norm=True,
|
167 |
+
dropout=0.1,
|
168 |
+
eps=1e-5)
|
169 |
+
cfg.update(**kwargs)
|
170 |
+
|
171 |
+
# init model
|
172 |
+
if pretrained:
|
173 |
+
from sora import DOWNLOAD_TO_CACHE
|
174 |
+
|
175 |
+
# init a meta model
|
176 |
+
with torch.device('meta'):
|
177 |
+
model = XLMRoberta(**cfg)
|
178 |
+
|
179 |
+
# load checkpoint
|
180 |
+
model.load_state_dict(
|
181 |
+
torch.load(
|
182 |
+
DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
|
183 |
+
map_location=device),
|
184 |
+
assign=True)
|
185 |
+
else:
|
186 |
+
# init a model on device
|
187 |
+
with torch.device(device):
|
188 |
+
model = XLMRoberta(**cfg)
|
189 |
+
|
190 |
+
# init tokenizer
|
191 |
+
if return_tokenizer:
|
192 |
+
from sora.data import HuggingfaceTokenizer
|
193 |
+
tokenizer = HuggingfaceTokenizer(
|
194 |
+
name='xlm-roberta-large',
|
195 |
+
seq_len=model.text_len,
|
196 |
+
clean='whitespace')
|
197 |
+
return model, tokenizer
|
198 |
+
else:
|
199 |
+
return model
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
def pos_interpolate(pos, seq_len):
|
204 |
+
if pos.size(1) == seq_len:
|
205 |
+
return pos
|
206 |
+
else:
|
207 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
208 |
+
tar_grid = int(math.sqrt(seq_len))
|
209 |
+
n = pos.size(1) - src_grid * src_grid
|
210 |
+
return torch.cat([
|
211 |
+
pos[:, :n],
|
212 |
+
F.interpolate(
|
213 |
+
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
|
214 |
+
0, 3, 1, 2),
|
215 |
+
size=(tar_grid, tar_grid),
|
216 |
+
mode='bicubic',
|
217 |
+
align_corners=False).flatten(2).transpose(1, 2)
|
218 |
+
],
|
219 |
+
dim=1)
|
220 |
+
|
221 |
+
|
222 |
+
class QuickGELU(nn.Module):
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
return x * torch.sigmoid(1.702 * x)
|
226 |
+
|
227 |
+
|
228 |
+
class LayerNorm(nn.LayerNorm):
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
return super().forward(x.float()).type_as(x)
|
232 |
+
|
233 |
+
|
234 |
+
class SelfAttention(nn.Module):
|
235 |
+
|
236 |
+
def __init__(self,
|
237 |
+
dim,
|
238 |
+
num_heads,
|
239 |
+
causal=False,
|
240 |
+
attn_dropout=0.0,
|
241 |
+
proj_dropout=0.0):
|
242 |
+
assert dim % num_heads == 0
|
243 |
+
super().__init__()
|
244 |
+
self.dim = dim
|
245 |
+
self.num_heads = num_heads
|
246 |
+
self.head_dim = dim // num_heads
|
247 |
+
self.causal = causal
|
248 |
+
self.attn_dropout = attn_dropout
|
249 |
+
self.proj_dropout = proj_dropout
|
250 |
+
|
251 |
+
# layers
|
252 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
253 |
+
self.proj = nn.Linear(dim, dim)
|
254 |
+
|
255 |
+
def forward(self, x):
|
256 |
+
"""
|
257 |
+
x: [B, L, C].
|
258 |
+
"""
|
259 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
260 |
+
|
261 |
+
# compute query, key, value
|
262 |
+
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
263 |
+
|
264 |
+
# compute attention
|
265 |
+
p = self.attn_dropout if self.training else 0.0
|
266 |
+
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
267 |
+
x = x.reshape(b, s, c)
|
268 |
+
|
269 |
+
# output
|
270 |
+
x = self.proj(x)
|
271 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
272 |
+
return x
|
273 |
+
|
274 |
+
|
275 |
+
class SwiGLU(nn.Module):
|
276 |
+
|
277 |
+
def __init__(self, dim, mid_dim):
|
278 |
+
super().__init__()
|
279 |
+
self.dim = dim
|
280 |
+
self.mid_dim = mid_dim
|
281 |
+
|
282 |
+
# layers
|
283 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
284 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
285 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
286 |
+
|
287 |
+
def forward(self, x):
|
288 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
289 |
+
x = self.fc3(x)
|
290 |
+
return x
|
291 |
+
|
292 |
+
|
293 |
+
class AttentionBlock(nn.Module):
|
294 |
+
|
295 |
+
def __init__(self,
|
296 |
+
dim,
|
297 |
+
mlp_ratio,
|
298 |
+
num_heads,
|
299 |
+
post_norm=False,
|
300 |
+
causal=False,
|
301 |
+
activation='quick_gelu',
|
302 |
+
attn_dropout=0.0,
|
303 |
+
proj_dropout=0.0,
|
304 |
+
norm_eps=1e-5):
|
305 |
+
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
|
306 |
+
super().__init__()
|
307 |
+
self.dim = dim
|
308 |
+
self.mlp_ratio = mlp_ratio
|
309 |
+
self.num_heads = num_heads
|
310 |
+
self.post_norm = post_norm
|
311 |
+
self.causal = causal
|
312 |
+
self.norm_eps = norm_eps
|
313 |
+
|
314 |
+
# layers
|
315 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
316 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
|
317 |
+
proj_dropout)
|
318 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
319 |
+
if activation == 'swi_glu':
|
320 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
321 |
+
else:
|
322 |
+
self.mlp = nn.Sequential(
|
323 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
324 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
325 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
326 |
+
|
327 |
+
def forward(self, x):
|
328 |
+
if self.post_norm:
|
329 |
+
x = x + self.norm1(self.attn(x))
|
330 |
+
x = x + self.norm2(self.mlp(x))
|
331 |
+
else:
|
332 |
+
x = x + self.attn(self.norm1(x))
|
333 |
+
x = x + self.mlp(self.norm2(x))
|
334 |
+
return x
|
335 |
+
|
336 |
+
|
337 |
+
class AttentionPool(nn.Module):
|
338 |
+
|
339 |
+
def __init__(self,
|
340 |
+
dim,
|
341 |
+
mlp_ratio,
|
342 |
+
num_heads,
|
343 |
+
activation='gelu',
|
344 |
+
proj_dropout=0.0,
|
345 |
+
norm_eps=1e-5):
|
346 |
+
assert dim % num_heads == 0
|
347 |
+
super().__init__()
|
348 |
+
self.dim = dim
|
349 |
+
self.mlp_ratio = mlp_ratio
|
350 |
+
self.num_heads = num_heads
|
351 |
+
self.head_dim = dim // num_heads
|
352 |
+
self.proj_dropout = proj_dropout
|
353 |
+
self.norm_eps = norm_eps
|
354 |
+
|
355 |
+
# layers
|
356 |
+
gain = 1.0 / math.sqrt(dim)
|
357 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
358 |
+
self.to_q = nn.Linear(dim, dim)
|
359 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
360 |
+
self.proj = nn.Linear(dim, dim)
|
361 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
362 |
+
self.mlp = nn.Sequential(
|
363 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
364 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
365 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
366 |
+
|
367 |
+
def forward(self, x):
|
368 |
+
"""
|
369 |
+
x: [B, L, C].
|
370 |
+
"""
|
371 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
372 |
+
|
373 |
+
# compute query, key, value
|
374 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
375 |
+
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
376 |
+
|
377 |
+
# compute attention
|
378 |
+
x = flash_attention(q, k, v, version=2)
|
379 |
+
x = x.reshape(b, 1, c)
|
380 |
+
|
381 |
+
# output
|
382 |
+
x = self.proj(x)
|
383 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
384 |
+
|
385 |
+
# mlp
|
386 |
+
x = x + self.mlp(self.norm(x))
|
387 |
+
return x[:, 0]
|
388 |
+
|
389 |
+
|
390 |
+
class VisionTransformer(nn.Module):
|
391 |
+
|
392 |
+
def __init__(self,
|
393 |
+
image_size=224,
|
394 |
+
patch_size=16,
|
395 |
+
dim=768,
|
396 |
+
mlp_ratio=4,
|
397 |
+
out_dim=512,
|
398 |
+
num_heads=12,
|
399 |
+
num_layers=12,
|
400 |
+
pool_type='token',
|
401 |
+
pre_norm=True,
|
402 |
+
post_norm=False,
|
403 |
+
activation='quick_gelu',
|
404 |
+
attn_dropout=0.0,
|
405 |
+
proj_dropout=0.0,
|
406 |
+
embedding_dropout=0.0,
|
407 |
+
norm_eps=1e-5):
|
408 |
+
if image_size % patch_size != 0:
|
409 |
+
print(
|
410 |
+
'[WARNING] image_size is not divisible by patch_size',
|
411 |
+
flush=True)
|
412 |
+
assert pool_type in ('token', 'token_fc', 'attn_pool')
|
413 |
+
out_dim = out_dim or dim
|
414 |
+
super().__init__()
|
415 |
+
self.image_size = image_size
|
416 |
+
self.patch_size = patch_size
|
417 |
+
self.num_patches = (image_size // patch_size)**2
|
418 |
+
self.dim = dim
|
419 |
+
self.mlp_ratio = mlp_ratio
|
420 |
+
self.out_dim = out_dim
|
421 |
+
self.num_heads = num_heads
|
422 |
+
self.num_layers = num_layers
|
423 |
+
self.pool_type = pool_type
|
424 |
+
self.post_norm = post_norm
|
425 |
+
self.norm_eps = norm_eps
|
426 |
+
|
427 |
+
# embeddings
|
428 |
+
gain = 1.0 / math.sqrt(dim)
|
429 |
+
self.patch_embedding = nn.Conv2d(
|
430 |
+
3,
|
431 |
+
dim,
|
432 |
+
kernel_size=patch_size,
|
433 |
+
stride=patch_size,
|
434 |
+
bias=not pre_norm)
|
435 |
+
if pool_type in ('token', 'token_fc'):
|
436 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
437 |
+
self.pos_embedding = nn.Parameter(gain * torch.randn(
|
438 |
+
1, self.num_patches +
|
439 |
+
(1 if pool_type in ('token', 'token_fc') else 0), dim))
|
440 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
441 |
+
|
442 |
+
# transformer
|
443 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
444 |
+
self.transformer = nn.Sequential(*[
|
445 |
+
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
|
446 |
+
activation, attn_dropout, proj_dropout, norm_eps)
|
447 |
+
for _ in range(num_layers)
|
448 |
+
])
|
449 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
450 |
+
|
451 |
+
# head
|
452 |
+
if pool_type == 'token':
|
453 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
454 |
+
elif pool_type == 'token_fc':
|
455 |
+
self.head = nn.Linear(dim, out_dim)
|
456 |
+
elif pool_type == 'attn_pool':
|
457 |
+
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
|
458 |
+
proj_dropout, norm_eps)
|
459 |
+
|
460 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
461 |
+
b = x.size(0)
|
462 |
+
|
463 |
+
# embeddings
|
464 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
465 |
+
if self.pool_type in ('token', 'token_fc'):
|
466 |
+
x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
|
467 |
+
if interpolation:
|
468 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
469 |
+
else:
|
470 |
+
e = self.pos_embedding
|
471 |
+
e = e.to(dtype=x.dtype, device=x.device)
|
472 |
+
x = self.dropout(x + e)
|
473 |
+
if self.pre_norm is not None:
|
474 |
+
x = self.pre_norm(x)
|
475 |
+
|
476 |
+
# transformer
|
477 |
+
if use_31_block:
|
478 |
+
x = self.transformer[:-1](x)
|
479 |
+
return x
|
480 |
+
else:
|
481 |
+
x = self.transformer(x)
|
482 |
+
return x
|
483 |
+
|
484 |
+
|
485 |
+
class CLIP(nn.Module):
|
486 |
+
|
487 |
+
def __init__(self,
|
488 |
+
embed_dim=512,
|
489 |
+
image_size=224,
|
490 |
+
patch_size=16,
|
491 |
+
vision_dim=768,
|
492 |
+
vision_mlp_ratio=4,
|
493 |
+
vision_heads=12,
|
494 |
+
vision_layers=12,
|
495 |
+
vision_pool='token',
|
496 |
+
vision_pre_norm=True,
|
497 |
+
vision_post_norm=False,
|
498 |
+
vocab_size=49408,
|
499 |
+
text_len=77,
|
500 |
+
text_dim=512,
|
501 |
+
text_mlp_ratio=4,
|
502 |
+
text_heads=8,
|
503 |
+
text_layers=12,
|
504 |
+
text_causal=True,
|
505 |
+
text_pool='argmax',
|
506 |
+
text_head_bias=False,
|
507 |
+
logit_bias=None,
|
508 |
+
activation='quick_gelu',
|
509 |
+
attn_dropout=0.0,
|
510 |
+
proj_dropout=0.0,
|
511 |
+
embedding_dropout=0.0,
|
512 |
+
norm_eps=1e-5):
|
513 |
+
super().__init__()
|
514 |
+
self.embed_dim = embed_dim
|
515 |
+
self.image_size = image_size
|
516 |
+
self.patch_size = patch_size
|
517 |
+
self.vision_dim = vision_dim
|
518 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
519 |
+
self.vision_heads = vision_heads
|
520 |
+
self.vision_layers = vision_layers
|
521 |
+
self.vision_pool = vision_pool
|
522 |
+
self.vision_pre_norm = vision_pre_norm
|
523 |
+
self.vision_post_norm = vision_post_norm
|
524 |
+
self.vocab_size = vocab_size
|
525 |
+
self.text_len = text_len
|
526 |
+
self.text_dim = text_dim
|
527 |
+
self.text_mlp_ratio = text_mlp_ratio
|
528 |
+
self.text_heads = text_heads
|
529 |
+
self.text_layers = text_layers
|
530 |
+
self.text_causal = text_causal
|
531 |
+
self.text_pool = text_pool
|
532 |
+
self.text_head_bias = text_head_bias
|
533 |
+
self.norm_eps = norm_eps
|
534 |
+
|
535 |
+
# models
|
536 |
+
self.visual = VisionTransformer(
|
537 |
+
image_size=image_size,
|
538 |
+
patch_size=patch_size,
|
539 |
+
dim=vision_dim,
|
540 |
+
mlp_ratio=vision_mlp_ratio,
|
541 |
+
out_dim=embed_dim,
|
542 |
+
num_heads=vision_heads,
|
543 |
+
num_layers=vision_layers,
|
544 |
+
pool_type=vision_pool,
|
545 |
+
pre_norm=vision_pre_norm,
|
546 |
+
post_norm=vision_post_norm,
|
547 |
+
activation=activation,
|
548 |
+
attn_dropout=attn_dropout,
|
549 |
+
proj_dropout=proj_dropout,
|
550 |
+
embedding_dropout=embedding_dropout,
|
551 |
+
norm_eps=norm_eps)
|
552 |
+
self.textual = TextTransformer(
|
553 |
+
vocab_size=vocab_size,
|
554 |
+
text_len=text_len,
|
555 |
+
dim=text_dim,
|
556 |
+
mlp_ratio=text_mlp_ratio,
|
557 |
+
out_dim=embed_dim,
|
558 |
+
num_heads=text_heads,
|
559 |
+
num_layers=text_layers,
|
560 |
+
causal=text_causal,
|
561 |
+
pool_type=text_pool,
|
562 |
+
head_bias=text_head_bias,
|
563 |
+
activation=activation,
|
564 |
+
attn_dropout=attn_dropout,
|
565 |
+
proj_dropout=proj_dropout,
|
566 |
+
embedding_dropout=embedding_dropout,
|
567 |
+
norm_eps=norm_eps)
|
568 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
569 |
+
if logit_bias is not None:
|
570 |
+
self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
|
571 |
+
|
572 |
+
# initialize weights
|
573 |
+
self.init_weights()
|
574 |
+
|
575 |
+
def forward(self, imgs, txt_ids):
|
576 |
+
"""
|
577 |
+
imgs: [B, 3, H, W] of torch.float32.
|
578 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
579 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
580 |
+
txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
|
581 |
+
"""
|
582 |
+
xi = self.visual(imgs)
|
583 |
+
xt = self.textual(txt_ids)
|
584 |
+
return xi, xt
|
585 |
+
|
586 |
+
def init_weights(self):
|
587 |
+
# embeddings
|
588 |
+
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
589 |
+
nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
|
590 |
+
|
591 |
+
# attentions
|
592 |
+
for modality in ['visual', 'textual']:
|
593 |
+
dim = self.vision_dim if modality == 'visual' else self.text_dim
|
594 |
+
transformer = getattr(self, modality).transformer
|
595 |
+
proj_gain = (1.0 / math.sqrt(dim)) * (
|
596 |
+
1.0 / math.sqrt(2 * len(transformer)))
|
597 |
+
attn_gain = 1.0 / math.sqrt(dim)
|
598 |
+
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
599 |
+
for block in transformer:
|
600 |
+
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
601 |
+
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
602 |
+
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
603 |
+
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
604 |
+
|
605 |
+
def param_groups(self):
|
606 |
+
groups = [{
|
607 |
+
'params': [
|
608 |
+
p for n, p in self.named_parameters()
|
609 |
+
if 'norm' in n or n.endswith('bias')
|
610 |
+
],
|
611 |
+
'weight_decay': 0.0
|
612 |
+
}, {
|
613 |
+
'params': [
|
614 |
+
p for n, p in self.named_parameters()
|
615 |
+
if not ('norm' in n or n.endswith('bias'))
|
616 |
+
]
|
617 |
+
}]
|
618 |
+
return groups
|
619 |
+
|
620 |
+
|
621 |
+
class XLMRobertaWithHead(XLMRoberta):
|
622 |
+
|
623 |
+
def __init__(self, **kwargs):
|
624 |
+
self.out_dim = kwargs.pop('out_dim')
|
625 |
+
super().__init__(**kwargs)
|
626 |
+
|
627 |
+
# head
|
628 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
629 |
+
self.head = nn.Sequential(
|
630 |
+
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
|
631 |
+
nn.Linear(mid_dim, self.out_dim, bias=False))
|
632 |
+
|
633 |
+
def forward(self, ids):
|
634 |
+
# xlm-roberta
|
635 |
+
x = super().forward(ids)
|
636 |
+
|
637 |
+
# average pooling
|
638 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
639 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
640 |
+
|
641 |
+
# head
|
642 |
+
x = self.head(x)
|
643 |
+
return x
|
644 |
+
|
645 |
+
|
646 |
+
class XLMRobertaCLIP(nn.Module):
|
647 |
+
|
648 |
+
def __init__(self,
|
649 |
+
embed_dim=1024,
|
650 |
+
image_size=224,
|
651 |
+
patch_size=14,
|
652 |
+
vision_dim=1280,
|
653 |
+
vision_mlp_ratio=4,
|
654 |
+
vision_heads=16,
|
655 |
+
vision_layers=32,
|
656 |
+
vision_pool='token',
|
657 |
+
vision_pre_norm=True,
|
658 |
+
vision_post_norm=False,
|
659 |
+
activation='gelu',
|
660 |
+
vocab_size=250002,
|
661 |
+
max_text_len=514,
|
662 |
+
type_size=1,
|
663 |
+
pad_id=1,
|
664 |
+
text_dim=1024,
|
665 |
+
text_heads=16,
|
666 |
+
text_layers=24,
|
667 |
+
text_post_norm=True,
|
668 |
+
text_dropout=0.1,
|
669 |
+
attn_dropout=0.0,
|
670 |
+
proj_dropout=0.0,
|
671 |
+
embedding_dropout=0.0,
|
672 |
+
norm_eps=1e-5):
|
673 |
+
super().__init__()
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.image_size = image_size
|
676 |
+
self.patch_size = patch_size
|
677 |
+
self.vision_dim = vision_dim
|
678 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
679 |
+
self.vision_heads = vision_heads
|
680 |
+
self.vision_layers = vision_layers
|
681 |
+
self.vision_pre_norm = vision_pre_norm
|
682 |
+
self.vision_post_norm = vision_post_norm
|
683 |
+
self.activation = activation
|
684 |
+
self.vocab_size = vocab_size
|
685 |
+
self.max_text_len = max_text_len
|
686 |
+
self.type_size = type_size
|
687 |
+
self.pad_id = pad_id
|
688 |
+
self.text_dim = text_dim
|
689 |
+
self.text_heads = text_heads
|
690 |
+
self.text_layers = text_layers
|
691 |
+
self.text_post_norm = text_post_norm
|
692 |
+
self.norm_eps = norm_eps
|
693 |
+
|
694 |
+
# models
|
695 |
+
self.visual = VisionTransformer(
|
696 |
+
image_size=image_size,
|
697 |
+
patch_size=patch_size,
|
698 |
+
dim=vision_dim,
|
699 |
+
mlp_ratio=vision_mlp_ratio,
|
700 |
+
out_dim=embed_dim,
|
701 |
+
num_heads=vision_heads,
|
702 |
+
num_layers=vision_layers,
|
703 |
+
pool_type=vision_pool,
|
704 |
+
pre_norm=vision_pre_norm,
|
705 |
+
post_norm=vision_post_norm,
|
706 |
+
activation=activation,
|
707 |
+
attn_dropout=attn_dropout,
|
708 |
+
proj_dropout=proj_dropout,
|
709 |
+
embedding_dropout=embedding_dropout,
|
710 |
+
norm_eps=norm_eps)
|
711 |
+
self.textual = None
|
712 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
713 |
+
|
714 |
+
def forward(self, imgs, txt_ids):
|
715 |
+
"""
|
716 |
+
imgs: [B, 3, H, W] of torch.float32.
|
717 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
718 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
719 |
+
txt_ids: [B, L] of torch.long.
|
720 |
+
Encoded by data.CLIPTokenizer.
|
721 |
+
"""
|
722 |
+
xi = self.visual(imgs)
|
723 |
+
xt = self.textual(txt_ids)
|
724 |
+
return xi, xt
|
725 |
+
|
726 |
+
def param_groups(self):
|
727 |
+
groups = [{
|
728 |
+
'params': [
|
729 |
+
p for n, p in self.named_parameters()
|
730 |
+
if 'norm' in n or n.endswith('bias')
|
731 |
+
],
|
732 |
+
'weight_decay': 0.0
|
733 |
+
}, {
|
734 |
+
'params': [
|
735 |
+
p for n, p in self.named_parameters()
|
736 |
+
if not ('norm' in n or n.endswith('bias'))
|
737 |
+
]
|
738 |
+
}]
|
739 |
+
return groups
|
740 |
+
|
741 |
+
|
742 |
+
def _clip(pretrained=False,
|
743 |
+
pretrained_name=None,
|
744 |
+
model_cls=CLIP,
|
745 |
+
return_transforms=False,
|
746 |
+
return_tokenizer=False,
|
747 |
+
tokenizer_padding='eos',
|
748 |
+
dtype=torch.float32,
|
749 |
+
device='cpu',
|
750 |
+
**kwargs):
|
751 |
+
# init model
|
752 |
+
if pretrained and pretrained_name:
|
753 |
+
from sora import BUCKET, DOWNLOAD_TO_CACHE
|
754 |
+
|
755 |
+
# init a meta model
|
756 |
+
with torch.device('meta'):
|
757 |
+
model = model_cls(**kwargs)
|
758 |
+
|
759 |
+
# checkpoint path
|
760 |
+
checkpoint = f'models/clip/{pretrained_name}'
|
761 |
+
if dtype in (torch.float16, torch.bfloat16):
|
762 |
+
suffix = '-' + {
|
763 |
+
torch.float16: 'fp16',
|
764 |
+
torch.bfloat16: 'bf16'
|
765 |
+
}[dtype]
|
766 |
+
if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
|
767 |
+
checkpoint = f'{checkpoint}{suffix}'
|
768 |
+
checkpoint += '.pth'
|
769 |
+
|
770 |
+
# load
|
771 |
+
model.load_state_dict(
|
772 |
+
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
|
773 |
+
assign=True,
|
774 |
+
strict=False)
|
775 |
+
else:
|
776 |
+
# init a model on device
|
777 |
+
with torch.device(device):
|
778 |
+
model = model_cls(**kwargs)
|
779 |
+
|
780 |
+
# set device
|
781 |
+
output = (model,)
|
782 |
+
|
783 |
+
# init transforms
|
784 |
+
if return_transforms:
|
785 |
+
# mean and std
|
786 |
+
if 'siglip' in pretrained_name.lower():
|
787 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
788 |
+
else:
|
789 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
790 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
791 |
+
|
792 |
+
# transforms
|
793 |
+
transforms = T.Compose([
|
794 |
+
T.Resize((model.image_size, model.image_size),
|
795 |
+
interpolation=T.InterpolationMode.BICUBIC),
|
796 |
+
T.ToTensor(),
|
797 |
+
T.Normalize(mean=mean, std=std)
|
798 |
+
])
|
799 |
+
output += (transforms,)
|
800 |
+
|
801 |
+
# init tokenizer
|
802 |
+
if return_tokenizer:
|
803 |
+
from sora import data
|
804 |
+
if 'siglip' in pretrained_name.lower():
|
805 |
+
tokenizer = data.HuggingfaceTokenizer(
|
806 |
+
name=f'timm/{pretrained_name}',
|
807 |
+
seq_len=model.text_len,
|
808 |
+
clean='canonicalize')
|
809 |
+
elif 'xlm' in pretrained_name.lower():
|
810 |
+
tokenizer = data.HuggingfaceTokenizer(
|
811 |
+
name='xlm-roberta-large',
|
812 |
+
seq_len=model.max_text_len - 2,
|
813 |
+
clean='whitespace')
|
814 |
+
elif 'mba' in pretrained_name.lower():
|
815 |
+
tokenizer = data.HuggingfaceTokenizer(
|
816 |
+
name='facebook/xlm-roberta-xl',
|
817 |
+
seq_len=model.max_text_len - 2,
|
818 |
+
clean='whitespace')
|
819 |
+
else:
|
820 |
+
tokenizer = data.CLIPTokenizer(
|
821 |
+
seq_len=model.text_len, padding=tokenizer_padding)
|
822 |
+
output += (tokenizer,)
|
823 |
+
return output[0] if len(output) == 1 else output
|
824 |
+
|
825 |
+
|
826 |
+
def clip_xlm_roberta_vit_h_14(
|
827 |
+
pretrained=False,
|
828 |
+
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
|
829 |
+
**kwargs):
|
830 |
+
cfg = dict(
|
831 |
+
embed_dim=1024,
|
832 |
+
image_size=224,
|
833 |
+
patch_size=14,
|
834 |
+
vision_dim=1280,
|
835 |
+
vision_mlp_ratio=4,
|
836 |
+
vision_heads=16,
|
837 |
+
vision_layers=32,
|
838 |
+
vision_pool='token',
|
839 |
+
activation='gelu',
|
840 |
+
vocab_size=250002,
|
841 |
+
max_text_len=514,
|
842 |
+
type_size=1,
|
843 |
+
pad_id=1,
|
844 |
+
text_dim=1024,
|
845 |
+
text_heads=16,
|
846 |
+
text_layers=24,
|
847 |
+
text_post_norm=True,
|
848 |
+
text_dropout=0.1,
|
849 |
+
attn_dropout=0.0,
|
850 |
+
proj_dropout=0.0,
|
851 |
+
embedding_dropout=0.0)
|
852 |
+
cfg.update(**kwargs)
|
853 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
854 |
+
|
855 |
+
|
856 |
+
class WanImageEncoder(torch.nn.Module):
|
857 |
+
|
858 |
+
def __init__(self):
|
859 |
+
super().__init__()
|
860 |
+
# init model
|
861 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
862 |
+
pretrained=False,
|
863 |
+
return_transforms=True,
|
864 |
+
return_tokenizer=False,
|
865 |
+
dtype=torch.float32,
|
866 |
+
device="cpu")
|
867 |
+
|
868 |
+
def encode_image(self, videos):
|
869 |
+
# preprocess
|
870 |
+
size = (self.model.image_size,) * 2
|
871 |
+
videos = torch.cat([
|
872 |
+
F.interpolate(
|
873 |
+
u,
|
874 |
+
size=size,
|
875 |
+
mode='bicubic',
|
876 |
+
align_corners=False) for u in videos
|
877 |
+
])
|
878 |
+
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
879 |
+
|
880 |
+
# forward
|
881 |
+
out = self.model.visual(videos, use_31_block=True)
|
882 |
+
return out
|
883 |
+
|
884 |
+
@staticmethod
|
885 |
+
def state_dict_converter():
|
886 |
+
return WanImageEncoderStateDictConverter()
|
887 |
+
|
888 |
+
|
889 |
+
class WanImageEncoderStateDictConverter:
|
890 |
+
def __init__(self):
|
891 |
+
pass
|
892 |
+
|
893 |
+
def from_diffusers(self, state_dict):
|
894 |
+
return state_dict
|
895 |
+
|
896 |
+
def from_civitai(self, state_dict):
|
897 |
+
state_dict_ = {}
|
898 |
+
for name, param in state_dict.items():
|
899 |
+
if name.startswith("textual."):
|
900 |
+
continue
|
901 |
+
name = "model." + name
|
902 |
+
state_dict_[name] = param
|
903 |
+
return state_dict_
|
model/prompter.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffsynth.prompters.base_prompter import BasePrompter
|
2 |
+
from model.text_encoder import WanTextEncoder
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
import ftfy
|
5 |
+
import html
|
6 |
+
import string
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
def basic_clean(text):
|
11 |
+
text = ftfy.fix_text(text)
|
12 |
+
text = html.unescape(html.unescape(text))
|
13 |
+
return text.strip()
|
14 |
+
|
15 |
+
|
16 |
+
def whitespace_clean(text):
|
17 |
+
text = re.sub(r'\s+', ' ', text)
|
18 |
+
text = text.strip()
|
19 |
+
return text
|
20 |
+
|
21 |
+
|
22 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
|
23 |
+
text = text.replace('_', ' ')
|
24 |
+
if keep_punctuation_exact_string:
|
25 |
+
text = keep_punctuation_exact_string.join(
|
26 |
+
part.translate(str.maketrans('', '', string.punctuation))
|
27 |
+
for part in text.split(keep_punctuation_exact_string))
|
28 |
+
else:
|
29 |
+
text = text.translate(str.maketrans('', '', string.punctuation))
|
30 |
+
text = text.lower()
|
31 |
+
text = re.sub(r'\s+', ' ', text)
|
32 |
+
return text.strip()
|
33 |
+
|
34 |
+
|
35 |
+
class HuggingfaceTokenizer:
|
36 |
+
|
37 |
+
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
38 |
+
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
39 |
+
self.name = name
|
40 |
+
self.seq_len = seq_len
|
41 |
+
self.clean = clean
|
42 |
+
|
43 |
+
# init tokenizer
|
44 |
+
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
45 |
+
self.vocab_size = self.tokenizer.vocab_size
|
46 |
+
|
47 |
+
def __call__(self, sequence, **kwargs):
|
48 |
+
return_mask = kwargs.pop('return_mask', False)
|
49 |
+
|
50 |
+
# arguments
|
51 |
+
_kwargs = {'return_tensors': 'pt'}
|
52 |
+
if self.seq_len is not None:
|
53 |
+
_kwargs.update({
|
54 |
+
'padding': 'max_length',
|
55 |
+
'truncation': True,
|
56 |
+
'max_length': self.seq_len
|
57 |
+
})
|
58 |
+
_kwargs.update(**kwargs)
|
59 |
+
|
60 |
+
# tokenization
|
61 |
+
if isinstance(sequence, str):
|
62 |
+
sequence = [sequence]
|
63 |
+
if self.clean:
|
64 |
+
sequence = [self._clean(u) for u in sequence]
|
65 |
+
ids = self.tokenizer(sequence, **_kwargs)
|
66 |
+
|
67 |
+
# output
|
68 |
+
if return_mask:
|
69 |
+
return ids.input_ids, ids.attention_mask
|
70 |
+
else:
|
71 |
+
return ids.input_ids
|
72 |
+
|
73 |
+
def _clean(self, text):
|
74 |
+
if self.clean == 'whitespace':
|
75 |
+
text = whitespace_clean(basic_clean(text))
|
76 |
+
elif self.clean == 'lower':
|
77 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
78 |
+
elif self.clean == 'canonicalize':
|
79 |
+
text = canonicalize(basic_clean(text))
|
80 |
+
return text
|
81 |
+
|
82 |
+
|
83 |
+
class WanPrompter(BasePrompter):
|
84 |
+
|
85 |
+
def __init__(self, tokenizer_path=None, text_len=512):
|
86 |
+
super().__init__()
|
87 |
+
self.text_len = text_len
|
88 |
+
self.text_encoder = None
|
89 |
+
self.fetch_tokenizer(tokenizer_path)
|
90 |
+
|
91 |
+
def fetch_tokenizer(self, tokenizer_path=None):
|
92 |
+
if tokenizer_path is not None:
|
93 |
+
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
|
94 |
+
|
95 |
+
def fetch_models(self, text_encoder: WanTextEncoder = None):
|
96 |
+
self.text_encoder = text_encoder
|
97 |
+
|
98 |
+
def encode_prompt(self, prompt, positive=True, device="cuda"):
|
99 |
+
prompt = self.process_prompt(prompt, positive=positive)
|
100 |
+
|
101 |
+
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
102 |
+
ids = ids.to(device)
|
103 |
+
mask = mask.to(device)
|
104 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
105 |
+
prompt_emb = self.text_encoder(ids, mask)
|
106 |
+
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
107 |
+
return prompt_emb
|
model/text_encoder.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def fp16_clamp(x):
|
9 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
10 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
11 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
12 |
+
return x
|
13 |
+
|
14 |
+
|
15 |
+
class GELU(nn.Module):
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return 0.5 * x * (1.0 + torch.tanh(
|
19 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
20 |
+
|
21 |
+
|
22 |
+
class T5LayerNorm(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self, dim, eps=1e-6):
|
25 |
+
super(T5LayerNorm, self).__init__()
|
26 |
+
self.dim = dim
|
27 |
+
self.eps = eps
|
28 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
32 |
+
self.eps)
|
33 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
34 |
+
x = x.type_as(self.weight)
|
35 |
+
return self.weight * x
|
36 |
+
|
37 |
+
|
38 |
+
class T5Attention(nn.Module):
|
39 |
+
|
40 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
41 |
+
assert dim_attn % num_heads == 0
|
42 |
+
super(T5Attention, self).__init__()
|
43 |
+
self.dim = dim
|
44 |
+
self.dim_attn = dim_attn
|
45 |
+
self.num_heads = num_heads
|
46 |
+
self.head_dim = dim_attn // num_heads
|
47 |
+
|
48 |
+
# layers
|
49 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
50 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
51 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
52 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
53 |
+
self.dropout = nn.Dropout(dropout)
|
54 |
+
|
55 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
56 |
+
"""
|
57 |
+
x: [B, L1, C].
|
58 |
+
context: [B, L2, C] or None.
|
59 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
60 |
+
"""
|
61 |
+
# check inputs
|
62 |
+
context = x if context is None else context
|
63 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
64 |
+
|
65 |
+
# compute query, key, value
|
66 |
+
q = self.q(x).view(b, -1, n, c)
|
67 |
+
k = self.k(context).view(b, -1, n, c)
|
68 |
+
v = self.v(context).view(b, -1, n, c)
|
69 |
+
|
70 |
+
# attention bias
|
71 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
72 |
+
if pos_bias is not None:
|
73 |
+
attn_bias += pos_bias
|
74 |
+
if mask is not None:
|
75 |
+
assert mask.ndim in [2, 3]
|
76 |
+
mask = mask.view(b, 1, 1,
|
77 |
+
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
78 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
79 |
+
|
80 |
+
# compute attention (T5 does not use scaling)
|
81 |
+
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
82 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
83 |
+
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
84 |
+
|
85 |
+
# output
|
86 |
+
x = x.reshape(b, -1, n * c)
|
87 |
+
x = self.o(x)
|
88 |
+
x = self.dropout(x)
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
class T5FeedForward(nn.Module):
|
93 |
+
|
94 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
95 |
+
super(T5FeedForward, self).__init__()
|
96 |
+
self.dim = dim
|
97 |
+
self.dim_ffn = dim_ffn
|
98 |
+
|
99 |
+
# layers
|
100 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
101 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
102 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
103 |
+
self.dropout = nn.Dropout(dropout)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
x = self.fc1(x) * self.gate(x)
|
107 |
+
x = self.dropout(x)
|
108 |
+
x = self.fc2(x)
|
109 |
+
x = self.dropout(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class T5SelfAttention(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self,
|
116 |
+
dim,
|
117 |
+
dim_attn,
|
118 |
+
dim_ffn,
|
119 |
+
num_heads,
|
120 |
+
num_buckets,
|
121 |
+
shared_pos=True,
|
122 |
+
dropout=0.1):
|
123 |
+
super(T5SelfAttention, self).__init__()
|
124 |
+
self.dim = dim
|
125 |
+
self.dim_attn = dim_attn
|
126 |
+
self.dim_ffn = dim_ffn
|
127 |
+
self.num_heads = num_heads
|
128 |
+
self.num_buckets = num_buckets
|
129 |
+
self.shared_pos = shared_pos
|
130 |
+
|
131 |
+
# layers
|
132 |
+
self.norm1 = T5LayerNorm(dim)
|
133 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
134 |
+
self.norm2 = T5LayerNorm(dim)
|
135 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
136 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
137 |
+
num_buckets, num_heads, bidirectional=True)
|
138 |
+
|
139 |
+
def forward(self, x, mask=None, pos_bias=None):
|
140 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
141 |
+
x.size(1), x.size(1))
|
142 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
143 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class T5RelativeEmbedding(nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
150 |
+
super(T5RelativeEmbedding, self).__init__()
|
151 |
+
self.num_buckets = num_buckets
|
152 |
+
self.num_heads = num_heads
|
153 |
+
self.bidirectional = bidirectional
|
154 |
+
self.max_dist = max_dist
|
155 |
+
|
156 |
+
# layers
|
157 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
158 |
+
|
159 |
+
def forward(self, lq, lk):
|
160 |
+
device = self.embedding.weight.device
|
161 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
162 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
163 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
164 |
+
torch.arange(lq, device=device).unsqueeze(1)
|
165 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
166 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
167 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
168 |
+
0) # [1, N, Lq, Lk]
|
169 |
+
return rel_pos_embeds.contiguous()
|
170 |
+
|
171 |
+
def _relative_position_bucket(self, rel_pos):
|
172 |
+
# preprocess
|
173 |
+
if self.bidirectional:
|
174 |
+
num_buckets = self.num_buckets // 2
|
175 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
176 |
+
rel_pos = torch.abs(rel_pos)
|
177 |
+
else:
|
178 |
+
num_buckets = self.num_buckets
|
179 |
+
rel_buckets = 0
|
180 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
181 |
+
|
182 |
+
# embeddings for small and large positions
|
183 |
+
max_exact = num_buckets // 2
|
184 |
+
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
185 |
+
math.log(self.max_dist / max_exact) *
|
186 |
+
(num_buckets - max_exact)).long()
|
187 |
+
rel_pos_large = torch.min(
|
188 |
+
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
189 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
190 |
+
return rel_buckets
|
191 |
+
|
192 |
+
def init_weights(m):
|
193 |
+
if isinstance(m, T5LayerNorm):
|
194 |
+
nn.init.ones_(m.weight)
|
195 |
+
elif isinstance(m, T5FeedForward):
|
196 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
197 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
198 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
199 |
+
elif isinstance(m, T5Attention):
|
200 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
201 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
202 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
203 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
204 |
+
elif isinstance(m, T5RelativeEmbedding):
|
205 |
+
nn.init.normal_(
|
206 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
207 |
+
|
208 |
+
|
209 |
+
class WanTextEncoder(torch.nn.Module):
|
210 |
+
|
211 |
+
def __init__(self,
|
212 |
+
vocab=256384,
|
213 |
+
dim=4096,
|
214 |
+
dim_attn=4096,
|
215 |
+
dim_ffn=10240,
|
216 |
+
num_heads=64,
|
217 |
+
num_layers=24,
|
218 |
+
num_buckets=32,
|
219 |
+
shared_pos=False,
|
220 |
+
dropout=0.1):
|
221 |
+
super(WanTextEncoder, self).__init__()
|
222 |
+
self.dim = dim
|
223 |
+
self.dim_attn = dim_attn
|
224 |
+
self.dim_ffn = dim_ffn
|
225 |
+
self.num_heads = num_heads
|
226 |
+
self.num_layers = num_layers
|
227 |
+
self.num_buckets = num_buckets
|
228 |
+
self.shared_pos = shared_pos
|
229 |
+
|
230 |
+
# layers
|
231 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
232 |
+
else nn.Embedding(vocab, dim)
|
233 |
+
self.pos_embedding = T5RelativeEmbedding(
|
234 |
+
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
235 |
+
self.dropout = nn.Dropout(dropout)
|
236 |
+
self.blocks = nn.ModuleList([
|
237 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
238 |
+
shared_pos, dropout) for _ in range(num_layers)
|
239 |
+
])
|
240 |
+
self.norm = T5LayerNorm(dim)
|
241 |
+
|
242 |
+
# initialize weights
|
243 |
+
self.apply(init_weights)
|
244 |
+
|
245 |
+
def forward(self, ids, mask=None):
|
246 |
+
x = self.token_embedding(ids)
|
247 |
+
x = self.dropout(x)
|
248 |
+
e = self.pos_embedding(x.size(1),
|
249 |
+
x.size(1)) if self.shared_pos else None
|
250 |
+
for block in self.blocks:
|
251 |
+
x = block(x, mask, pos_bias=e)
|
252 |
+
x = self.norm(x)
|
253 |
+
x = self.dropout(x)
|
254 |
+
return x
|
255 |
+
|
256 |
+
@staticmethod
|
257 |
+
def state_dict_converter():
|
258 |
+
return WanTextEncoderStateDictConverter()
|
259 |
+
|
260 |
+
|
261 |
+
class WanTextEncoderStateDictConverter:
|
262 |
+
def __init__(self):
|
263 |
+
pass
|
264 |
+
|
265 |
+
def from_diffusers(self, state_dict):
|
266 |
+
return state_dict
|
267 |
+
|
268 |
+
def from_civitai(self, state_dict):
|
269 |
+
return state_dict
|
model/vae.py
ADDED
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from einops import rearrange, repeat
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
CACHE_T = 2
|
9 |
+
|
10 |
+
|
11 |
+
def check_is_instance(model, module_class):
|
12 |
+
if isinstance(model, module_class):
|
13 |
+
return True
|
14 |
+
if hasattr(model, "module") and isinstance(model.module, module_class):
|
15 |
+
return True
|
16 |
+
return False
|
17 |
+
|
18 |
+
|
19 |
+
def block_causal_mask(x, block_size):
|
20 |
+
# params
|
21 |
+
b, n, s, _, device = *x.size(), x.device
|
22 |
+
assert s % block_size == 0
|
23 |
+
num_blocks = s // block_size
|
24 |
+
|
25 |
+
# build mask
|
26 |
+
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
27 |
+
for i in range(num_blocks):
|
28 |
+
mask[:, :,
|
29 |
+
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
30 |
+
return mask
|
31 |
+
|
32 |
+
|
33 |
+
class CausalConv3d(nn.Conv3d):
|
34 |
+
"""
|
35 |
+
Causal 3d convolusion.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, *args, **kwargs):
|
39 |
+
super().__init__(*args, **kwargs)
|
40 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
41 |
+
self.padding[1], 2 * self.padding[0], 0)
|
42 |
+
self.padding = (0, 0, 0)
|
43 |
+
|
44 |
+
def forward(self, x, cache_x=None):
|
45 |
+
padding = list(self._padding)
|
46 |
+
if cache_x is not None and self._padding[4] > 0:
|
47 |
+
cache_x = cache_x.to(x.device)
|
48 |
+
x = torch.cat([cache_x, x], dim=2)
|
49 |
+
padding[4] -= cache_x.shape[2]
|
50 |
+
x = F.pad(x, padding)
|
51 |
+
|
52 |
+
return super().forward(x)
|
53 |
+
|
54 |
+
|
55 |
+
class RMS_norm(nn.Module):
|
56 |
+
|
57 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
58 |
+
super().__init__()
|
59 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
60 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
61 |
+
|
62 |
+
self.channel_first = channel_first
|
63 |
+
self.scale = dim**0.5
|
64 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
65 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return F.normalize(
|
69 |
+
x, dim=(1 if self.channel_first else
|
70 |
+
-1)) * self.scale * self.gamma + self.bias
|
71 |
+
|
72 |
+
|
73 |
+
class Upsample(nn.Upsample):
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
"""
|
77 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
78 |
+
"""
|
79 |
+
return super().forward(x.float()).type_as(x)
|
80 |
+
|
81 |
+
|
82 |
+
class Resample(nn.Module):
|
83 |
+
|
84 |
+
def __init__(self, dim, mode):
|
85 |
+
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
86 |
+
'downsample3d')
|
87 |
+
super().__init__()
|
88 |
+
self.dim = dim
|
89 |
+
self.mode = mode
|
90 |
+
|
91 |
+
# layers
|
92 |
+
if mode == 'upsample2d':
|
93 |
+
self.resample = nn.Sequential(
|
94 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
95 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
96 |
+
elif mode == 'upsample3d':
|
97 |
+
self.resample = nn.Sequential(
|
98 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
99 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
100 |
+
self.time_conv = CausalConv3d(dim,
|
101 |
+
dim * 2, (3, 1, 1),
|
102 |
+
padding=(1, 0, 0))
|
103 |
+
|
104 |
+
elif mode == 'downsample2d':
|
105 |
+
self.resample = nn.Sequential(
|
106 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
107 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
108 |
+
elif mode == 'downsample3d':
|
109 |
+
self.resample = nn.Sequential(
|
110 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
111 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
112 |
+
self.time_conv = CausalConv3d(dim,
|
113 |
+
dim, (3, 1, 1),
|
114 |
+
stride=(2, 1, 1),
|
115 |
+
padding=(0, 0, 0))
|
116 |
+
|
117 |
+
else:
|
118 |
+
self.resample = nn.Identity()
|
119 |
+
|
120 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
121 |
+
b, c, t, h, w = x.size()
|
122 |
+
if self.mode == 'upsample3d':
|
123 |
+
if feat_cache is not None:
|
124 |
+
idx = feat_idx[0]
|
125 |
+
if feat_cache[idx] is None:
|
126 |
+
feat_cache[idx] = 'Rep'
|
127 |
+
feat_idx[0] += 1
|
128 |
+
else:
|
129 |
+
|
130 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
131 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
132 |
+
idx] is not None and feat_cache[idx] != 'Rep':
|
133 |
+
# cache last frame of last two chunk
|
134 |
+
cache_x = torch.cat([
|
135 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
136 |
+
cache_x.device), cache_x
|
137 |
+
],
|
138 |
+
dim=2)
|
139 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
140 |
+
idx] is not None and feat_cache[idx] == 'Rep':
|
141 |
+
cache_x = torch.cat([
|
142 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
143 |
+
cache_x
|
144 |
+
],
|
145 |
+
dim=2)
|
146 |
+
if feat_cache[idx] == 'Rep':
|
147 |
+
x = self.time_conv(x)
|
148 |
+
else:
|
149 |
+
x = self.time_conv(x, feat_cache[idx])
|
150 |
+
feat_cache[idx] = cache_x
|
151 |
+
feat_idx[0] += 1
|
152 |
+
|
153 |
+
x = x.reshape(b, 2, c, t, h, w)
|
154 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
155 |
+
3)
|
156 |
+
x = x.reshape(b, c, t * 2, h, w)
|
157 |
+
t = x.shape[2]
|
158 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
159 |
+
x = self.resample(x)
|
160 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
161 |
+
|
162 |
+
if self.mode == 'downsample3d':
|
163 |
+
if feat_cache is not None:
|
164 |
+
idx = feat_idx[0]
|
165 |
+
if feat_cache[idx] is None:
|
166 |
+
feat_cache[idx] = x.clone()
|
167 |
+
feat_idx[0] += 1
|
168 |
+
else:
|
169 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
170 |
+
x = self.time_conv(
|
171 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
172 |
+
feat_cache[idx] = cache_x
|
173 |
+
feat_idx[0] += 1
|
174 |
+
return x
|
175 |
+
|
176 |
+
def init_weight(self, conv):
|
177 |
+
conv_weight = conv.weight
|
178 |
+
nn.init.zeros_(conv_weight)
|
179 |
+
c1, c2, t, h, w = conv_weight.size()
|
180 |
+
one_matrix = torch.eye(c1, c2)
|
181 |
+
init_matrix = one_matrix
|
182 |
+
nn.init.zeros_(conv_weight)
|
183 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
184 |
+
conv.weight.data.copy_(conv_weight)
|
185 |
+
nn.init.zeros_(conv.bias.data)
|
186 |
+
|
187 |
+
def init_weight2(self, conv):
|
188 |
+
conv_weight = conv.weight.data
|
189 |
+
nn.init.zeros_(conv_weight)
|
190 |
+
c1, c2, t, h, w = conv_weight.size()
|
191 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
192 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
193 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
194 |
+
conv.weight.data.copy_(conv_weight)
|
195 |
+
nn.init.zeros_(conv.bias.data)
|
196 |
+
|
197 |
+
|
198 |
+
class ResidualBlock(nn.Module):
|
199 |
+
|
200 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
201 |
+
super().__init__()
|
202 |
+
self.in_dim = in_dim
|
203 |
+
self.out_dim = out_dim
|
204 |
+
|
205 |
+
# layers
|
206 |
+
self.residual = nn.Sequential(
|
207 |
+
RMS_norm(in_dim, images=False), nn.SiLU(),
|
208 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
209 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
210 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
211 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
212 |
+
if in_dim != out_dim else nn.Identity()
|
213 |
+
|
214 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
215 |
+
h = self.shortcut(x)
|
216 |
+
for layer in self.residual:
|
217 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
218 |
+
idx = feat_idx[0]
|
219 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
220 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
221 |
+
# cache last frame of last two chunk
|
222 |
+
cache_x = torch.cat([
|
223 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
224 |
+
cache_x.device), cache_x
|
225 |
+
],
|
226 |
+
dim=2)
|
227 |
+
x = layer(x, feat_cache[idx])
|
228 |
+
feat_cache[idx] = cache_x
|
229 |
+
feat_idx[0] += 1
|
230 |
+
else:
|
231 |
+
x = layer(x)
|
232 |
+
return x + h
|
233 |
+
|
234 |
+
|
235 |
+
class AttentionBlock(nn.Module):
|
236 |
+
"""
|
237 |
+
Causal self-attention with a single head.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(self, dim):
|
241 |
+
super().__init__()
|
242 |
+
self.dim = dim
|
243 |
+
|
244 |
+
# layers
|
245 |
+
self.norm = RMS_norm(dim)
|
246 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
247 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
248 |
+
|
249 |
+
# zero out the last layer params
|
250 |
+
nn.init.zeros_(self.proj.weight)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
identity = x
|
254 |
+
b, c, t, h, w = x.size()
|
255 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
256 |
+
x = self.norm(x)
|
257 |
+
# compute query, key, value
|
258 |
+
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
259 |
+
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
260 |
+
|
261 |
+
# apply attention
|
262 |
+
x = F.scaled_dot_product_attention(
|
263 |
+
q,
|
264 |
+
k,
|
265 |
+
v,
|
266 |
+
#attn_mask=block_causal_mask(q, block_size=h * w)
|
267 |
+
)
|
268 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
269 |
+
|
270 |
+
# output
|
271 |
+
x = self.proj(x)
|
272 |
+
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
273 |
+
return x + identity
|
274 |
+
|
275 |
+
|
276 |
+
class Encoder3d(nn.Module):
|
277 |
+
|
278 |
+
def __init__(self,
|
279 |
+
dim=128,
|
280 |
+
z_dim=4,
|
281 |
+
dim_mult=[1, 2, 4, 4],
|
282 |
+
num_res_blocks=2,
|
283 |
+
attn_scales=[],
|
284 |
+
temperal_downsample=[True, True, False],
|
285 |
+
dropout=0.0):
|
286 |
+
super().__init__()
|
287 |
+
self.dim = dim
|
288 |
+
self.z_dim = z_dim
|
289 |
+
self.dim_mult = dim_mult
|
290 |
+
self.num_res_blocks = num_res_blocks
|
291 |
+
self.attn_scales = attn_scales
|
292 |
+
self.temperal_downsample = temperal_downsample
|
293 |
+
|
294 |
+
# dimensions
|
295 |
+
dims = [dim * u for u in [1] + dim_mult]
|
296 |
+
scale = 1.0
|
297 |
+
|
298 |
+
# init block
|
299 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
300 |
+
|
301 |
+
# downsample blocks
|
302 |
+
downsamples = []
|
303 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
304 |
+
# residual (+attention) blocks
|
305 |
+
for _ in range(num_res_blocks):
|
306 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
307 |
+
if scale in attn_scales:
|
308 |
+
downsamples.append(AttentionBlock(out_dim))
|
309 |
+
in_dim = out_dim
|
310 |
+
|
311 |
+
# downsample block
|
312 |
+
if i != len(dim_mult) - 1:
|
313 |
+
mode = 'downsample3d' if temperal_downsample[
|
314 |
+
i] else 'downsample2d'
|
315 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
316 |
+
scale /= 2.0
|
317 |
+
self.downsamples = nn.Sequential(*downsamples)
|
318 |
+
|
319 |
+
# middle blocks
|
320 |
+
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
321 |
+
AttentionBlock(out_dim),
|
322 |
+
ResidualBlock(out_dim, out_dim, dropout))
|
323 |
+
|
324 |
+
# output blocks
|
325 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
326 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
327 |
+
|
328 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
329 |
+
if feat_cache is not None:
|
330 |
+
idx = feat_idx[0]
|
331 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
332 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
333 |
+
# cache last frame of last two chunk
|
334 |
+
cache_x = torch.cat([
|
335 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
336 |
+
cache_x.device), cache_x
|
337 |
+
],
|
338 |
+
dim=2)
|
339 |
+
x = self.conv1(x, feat_cache[idx])
|
340 |
+
feat_cache[idx] = cache_x
|
341 |
+
feat_idx[0] += 1
|
342 |
+
else:
|
343 |
+
x = self.conv1(x)
|
344 |
+
|
345 |
+
## downsamples
|
346 |
+
for layer in self.downsamples:
|
347 |
+
if feat_cache is not None:
|
348 |
+
x = layer(x, feat_cache, feat_idx)
|
349 |
+
else:
|
350 |
+
x = layer(x)
|
351 |
+
|
352 |
+
## middle
|
353 |
+
for layer in self.middle:
|
354 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
355 |
+
x = layer(x, feat_cache, feat_idx)
|
356 |
+
else:
|
357 |
+
x = layer(x)
|
358 |
+
|
359 |
+
## head
|
360 |
+
for layer in self.head:
|
361 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
362 |
+
idx = feat_idx[0]
|
363 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
364 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
365 |
+
# cache last frame of last two chunk
|
366 |
+
cache_x = torch.cat([
|
367 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
368 |
+
cache_x.device), cache_x
|
369 |
+
],
|
370 |
+
dim=2)
|
371 |
+
x = layer(x, feat_cache[idx])
|
372 |
+
feat_cache[idx] = cache_x
|
373 |
+
feat_idx[0] += 1
|
374 |
+
else:
|
375 |
+
x = layer(x)
|
376 |
+
return x
|
377 |
+
|
378 |
+
|
379 |
+
class Decoder3d(nn.Module):
|
380 |
+
|
381 |
+
def __init__(self,
|
382 |
+
dim=128,
|
383 |
+
z_dim=4,
|
384 |
+
dim_mult=[1, 2, 4, 4],
|
385 |
+
num_res_blocks=2,
|
386 |
+
attn_scales=[],
|
387 |
+
temperal_upsample=[False, True, True],
|
388 |
+
dropout=0.0):
|
389 |
+
super().__init__()
|
390 |
+
self.dim = dim
|
391 |
+
self.z_dim = z_dim
|
392 |
+
self.dim_mult = dim_mult
|
393 |
+
self.num_res_blocks = num_res_blocks
|
394 |
+
self.attn_scales = attn_scales
|
395 |
+
self.temperal_upsample = temperal_upsample
|
396 |
+
|
397 |
+
# dimensions
|
398 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
399 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
400 |
+
|
401 |
+
# init block
|
402 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
403 |
+
|
404 |
+
# middle blocks
|
405 |
+
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
406 |
+
AttentionBlock(dims[0]),
|
407 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
408 |
+
|
409 |
+
# upsample blocks
|
410 |
+
upsamples = []
|
411 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
412 |
+
# residual (+attention) blocks
|
413 |
+
if i == 1 or i == 2 or i == 3:
|
414 |
+
in_dim = in_dim // 2
|
415 |
+
for _ in range(num_res_blocks + 1):
|
416 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
417 |
+
if scale in attn_scales:
|
418 |
+
upsamples.append(AttentionBlock(out_dim))
|
419 |
+
in_dim = out_dim
|
420 |
+
|
421 |
+
# upsample block
|
422 |
+
if i != len(dim_mult) - 1:
|
423 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
424 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
425 |
+
scale *= 2.0
|
426 |
+
self.upsamples = nn.Sequential(*upsamples)
|
427 |
+
|
428 |
+
# output blocks
|
429 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
430 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
431 |
+
|
432 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
433 |
+
## conv1
|
434 |
+
if feat_cache is not None:
|
435 |
+
idx = feat_idx[0]
|
436 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
437 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
438 |
+
# cache last frame of last two chunk
|
439 |
+
cache_x = torch.cat([
|
440 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
441 |
+
cache_x.device), cache_x
|
442 |
+
],
|
443 |
+
dim=2)
|
444 |
+
x = self.conv1(x, feat_cache[idx])
|
445 |
+
feat_cache[idx] = cache_x
|
446 |
+
feat_idx[0] += 1
|
447 |
+
else:
|
448 |
+
x = self.conv1(x)
|
449 |
+
|
450 |
+
## middle
|
451 |
+
for layer in self.middle:
|
452 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
453 |
+
x = layer(x, feat_cache, feat_idx)
|
454 |
+
else:
|
455 |
+
x = layer(x)
|
456 |
+
|
457 |
+
## upsamples
|
458 |
+
for layer in self.upsamples:
|
459 |
+
if feat_cache is not None:
|
460 |
+
x = layer(x, feat_cache, feat_idx)
|
461 |
+
else:
|
462 |
+
x = layer(x)
|
463 |
+
|
464 |
+
## head
|
465 |
+
for layer in self.head:
|
466 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
467 |
+
idx = feat_idx[0]
|
468 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
469 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
470 |
+
# cache last frame of last two chunk
|
471 |
+
cache_x = torch.cat([
|
472 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
473 |
+
cache_x.device), cache_x
|
474 |
+
],
|
475 |
+
dim=2)
|
476 |
+
x = layer(x, feat_cache[idx])
|
477 |
+
feat_cache[idx] = cache_x
|
478 |
+
feat_idx[0] += 1
|
479 |
+
else:
|
480 |
+
x = layer(x)
|
481 |
+
return x
|
482 |
+
|
483 |
+
|
484 |
+
def count_conv3d(model):
|
485 |
+
count = 0
|
486 |
+
for m in model.modules():
|
487 |
+
if check_is_instance(m, CausalConv3d):
|
488 |
+
count += 1
|
489 |
+
return count
|
490 |
+
|
491 |
+
|
492 |
+
class VideoVAE_(nn.Module):
|
493 |
+
|
494 |
+
def __init__(self,
|
495 |
+
dim=96,
|
496 |
+
z_dim=16,
|
497 |
+
dim_mult=[1, 2, 4, 4],
|
498 |
+
num_res_blocks=2,
|
499 |
+
attn_scales=[],
|
500 |
+
temperal_downsample=[False, True, True],
|
501 |
+
dropout=0.0):
|
502 |
+
super().__init__()
|
503 |
+
self.dim = dim
|
504 |
+
self.z_dim = z_dim
|
505 |
+
self.dim_mult = dim_mult
|
506 |
+
self.num_res_blocks = num_res_blocks
|
507 |
+
self.attn_scales = attn_scales
|
508 |
+
self.temperal_downsample = temperal_downsample
|
509 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
510 |
+
|
511 |
+
# modules
|
512 |
+
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
513 |
+
attn_scales, self.temperal_downsample, dropout)
|
514 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
515 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
516 |
+
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
517 |
+
attn_scales, self.temperal_upsample, dropout)
|
518 |
+
|
519 |
+
def forward(self, x):
|
520 |
+
mu, log_var = self.encode(x)
|
521 |
+
z = self.reparameterize(mu, log_var)
|
522 |
+
x_recon = self.decode(z)
|
523 |
+
return x_recon, mu, log_var
|
524 |
+
|
525 |
+
def encode(self, x, scale):
|
526 |
+
self.clear_cache()
|
527 |
+
## cache
|
528 |
+
t = x.shape[2]
|
529 |
+
iter_ = 1 + (t - 1) // 4
|
530 |
+
|
531 |
+
for i in range(iter_):
|
532 |
+
self._enc_conv_idx = [0]
|
533 |
+
if i == 0:
|
534 |
+
out = self.encoder(x[:, :, :1, :, :],
|
535 |
+
feat_cache=self._enc_feat_map,
|
536 |
+
feat_idx=self._enc_conv_idx)
|
537 |
+
else:
|
538 |
+
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
539 |
+
feat_cache=self._enc_feat_map,
|
540 |
+
feat_idx=self._enc_conv_idx)
|
541 |
+
out = torch.cat([out, out_], 2)
|
542 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
543 |
+
if isinstance(scale[0], torch.Tensor):
|
544 |
+
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
545 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
546 |
+
1, self.z_dim, 1, 1, 1)
|
547 |
+
else:
|
548 |
+
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
549 |
+
mu = (mu - scale[0]) * scale[1]
|
550 |
+
return mu
|
551 |
+
|
552 |
+
def decode(self, z, scale):
|
553 |
+
self.clear_cache()
|
554 |
+
# z: [b,c,t,h,w]
|
555 |
+
if isinstance(scale[0], torch.Tensor):
|
556 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
557 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
558 |
+
1, self.z_dim, 1, 1, 1)
|
559 |
+
else:
|
560 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
561 |
+
z = z / scale[1] + scale[0]
|
562 |
+
iter_ = z.shape[2]
|
563 |
+
x = self.conv2(z)
|
564 |
+
for i in range(iter_):
|
565 |
+
self._conv_idx = [0]
|
566 |
+
if i == 0:
|
567 |
+
out = self.decoder(x[:, :, i:i + 1, :, :],
|
568 |
+
feat_cache=self._feat_map,
|
569 |
+
feat_idx=self._conv_idx)
|
570 |
+
else:
|
571 |
+
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
572 |
+
feat_cache=self._feat_map,
|
573 |
+
feat_idx=self._conv_idx)
|
574 |
+
out = torch.cat([out, out_], 2) # may add tensor offload
|
575 |
+
return out
|
576 |
+
|
577 |
+
def reparameterize(self, mu, log_var):
|
578 |
+
std = torch.exp(0.5 * log_var)
|
579 |
+
eps = torch.randn_like(std)
|
580 |
+
return eps * std + mu
|
581 |
+
|
582 |
+
def sample(self, imgs, deterministic=False):
|
583 |
+
mu, log_var = self.encode(imgs)
|
584 |
+
if deterministic:
|
585 |
+
return mu
|
586 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
587 |
+
return mu + std * torch.randn_like(std)
|
588 |
+
|
589 |
+
def clear_cache(self):
|
590 |
+
self._conv_num = count_conv3d(self.decoder)
|
591 |
+
self._conv_idx = [0]
|
592 |
+
self._feat_map = [None] * self._conv_num
|
593 |
+
# cache encode
|
594 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
595 |
+
self._enc_conv_idx = [0]
|
596 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
597 |
+
|
598 |
+
|
599 |
+
class WanVideoVAE(nn.Module):
|
600 |
+
|
601 |
+
def __init__(self, z_dim=16):
|
602 |
+
super().__init__()
|
603 |
+
|
604 |
+
mean = [
|
605 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
606 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
607 |
+
]
|
608 |
+
std = [
|
609 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
610 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
611 |
+
]
|
612 |
+
self.mean = torch.tensor(mean)
|
613 |
+
self.std = torch.tensor(std)
|
614 |
+
self.scale = [self.mean, 1.0 / self.std]
|
615 |
+
|
616 |
+
# init model
|
617 |
+
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
618 |
+
self.upsampling_factor = 8
|
619 |
+
|
620 |
+
|
621 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
622 |
+
x = torch.ones((length,))
|
623 |
+
if not left_bound:
|
624 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
625 |
+
if not right_bound:
|
626 |
+
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
627 |
+
return x
|
628 |
+
|
629 |
+
|
630 |
+
def build_mask(self, data, is_bound, border_width):
|
631 |
+
_, _, _, H, W = data.shape
|
632 |
+
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
633 |
+
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
634 |
+
|
635 |
+
h = repeat(h, "H -> H W", H=H, W=W)
|
636 |
+
w = repeat(w, "W -> H W", H=H, W=W)
|
637 |
+
|
638 |
+
mask = torch.stack([h, w]).min(dim=0).values
|
639 |
+
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
640 |
+
return mask
|
641 |
+
|
642 |
+
|
643 |
+
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
644 |
+
_, _, T, H, W = hidden_states.shape
|
645 |
+
size_h, size_w = tile_size
|
646 |
+
stride_h, stride_w = tile_stride
|
647 |
+
|
648 |
+
# Split tasks
|
649 |
+
tasks = []
|
650 |
+
for h in range(0, H, stride_h):
|
651 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
652 |
+
for w in range(0, W, stride_w):
|
653 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
654 |
+
h_, w_ = h + size_h, w + size_w
|
655 |
+
tasks.append((h, h_, w, w_))
|
656 |
+
|
657 |
+
data_device = "cpu"
|
658 |
+
computation_device = device
|
659 |
+
|
660 |
+
out_T = T * 4 - 3
|
661 |
+
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
662 |
+
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
663 |
+
|
664 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
665 |
+
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
666 |
+
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
667 |
+
|
668 |
+
mask = self.build_mask(
|
669 |
+
hidden_states_batch,
|
670 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
671 |
+
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
672 |
+
).to(dtype=hidden_states.dtype, device=data_device)
|
673 |
+
|
674 |
+
target_h = h * self.upsampling_factor
|
675 |
+
target_w = w * self.upsampling_factor
|
676 |
+
values[
|
677 |
+
:,
|
678 |
+
:,
|
679 |
+
:,
|
680 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
681 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
682 |
+
] += hidden_states_batch * mask
|
683 |
+
weight[
|
684 |
+
:,
|
685 |
+
:,
|
686 |
+
:,
|
687 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
688 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
689 |
+
] += mask
|
690 |
+
values = values / weight
|
691 |
+
values = values.float().clamp_(-1, 1)
|
692 |
+
return values
|
693 |
+
|
694 |
+
|
695 |
+
def tiled_encode(self, video, device, tile_size, tile_stride):
|
696 |
+
_, _, T, H, W = video.shape
|
697 |
+
size_h, size_w = tile_size
|
698 |
+
stride_h, stride_w = tile_stride
|
699 |
+
|
700 |
+
# Split tasks
|
701 |
+
tasks = []
|
702 |
+
for h in range(0, H, stride_h):
|
703 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
704 |
+
for w in range(0, W, stride_w):
|
705 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
706 |
+
h_, w_ = h + size_h, w + size_w
|
707 |
+
tasks.append((h, h_, w, w_))
|
708 |
+
|
709 |
+
data_device = "cpu"
|
710 |
+
computation_device = device
|
711 |
+
|
712 |
+
out_T = (T + 3) // 4
|
713 |
+
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
714 |
+
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
715 |
+
|
716 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
717 |
+
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
718 |
+
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
719 |
+
|
720 |
+
mask = self.build_mask(
|
721 |
+
hidden_states_batch,
|
722 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
723 |
+
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
724 |
+
).to(dtype=video.dtype, device=data_device)
|
725 |
+
|
726 |
+
target_h = h // self.upsampling_factor
|
727 |
+
target_w = w // self.upsampling_factor
|
728 |
+
values[
|
729 |
+
:,
|
730 |
+
:,
|
731 |
+
:,
|
732 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
733 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
734 |
+
] += hidden_states_batch * mask
|
735 |
+
weight[
|
736 |
+
:,
|
737 |
+
:,
|
738 |
+
:,
|
739 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
740 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
741 |
+
] += mask
|
742 |
+
values = values / weight
|
743 |
+
values = values.float()
|
744 |
+
return values
|
745 |
+
|
746 |
+
|
747 |
+
def single_encode(self, video, device):
|
748 |
+
video = video.to(device)
|
749 |
+
x = self.model.encode(video, self.scale)
|
750 |
+
return x.float()
|
751 |
+
|
752 |
+
|
753 |
+
def single_decode(self, hidden_state, device):
|
754 |
+
hidden_state = hidden_state.to(device)
|
755 |
+
video = self.model.decode(hidden_state, self.scale)
|
756 |
+
return video.float().clamp_(-1, 1)
|
757 |
+
|
758 |
+
|
759 |
+
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
760 |
+
|
761 |
+
videos = [video.to("cpu") for video in videos]
|
762 |
+
hidden_states = []
|
763 |
+
for video in videos:
|
764 |
+
video = video.unsqueeze(0)
|
765 |
+
if tiled:
|
766 |
+
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
767 |
+
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
768 |
+
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
769 |
+
else:
|
770 |
+
hidden_state = self.single_encode(video, device)
|
771 |
+
hidden_state = hidden_state.squeeze(0)
|
772 |
+
hidden_states.append(hidden_state)
|
773 |
+
hidden_states = torch.stack(hidden_states)
|
774 |
+
return hidden_states
|
775 |
+
|
776 |
+
|
777 |
+
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
778 |
+
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
779 |
+
videos = []
|
780 |
+
for hidden_state in hidden_states:
|
781 |
+
hidden_state = hidden_state.unsqueeze(0)
|
782 |
+
if tiled:
|
783 |
+
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
784 |
+
else:
|
785 |
+
video = self.single_decode(hidden_state, device)
|
786 |
+
video = video.squeeze(0)
|
787 |
+
videos.append(video)
|
788 |
+
videos = torch.stack(videos)
|
789 |
+
return videos
|
790 |
+
|
791 |
+
|
792 |
+
@staticmethod
|
793 |
+
def state_dict_converter():
|
794 |
+
return WanVideoVAEStateDictConverter()
|
795 |
+
|
796 |
+
|
797 |
+
class WanVideoVAEStateDictConverter:
|
798 |
+
|
799 |
+
def __init__(self):
|
800 |
+
pass
|
801 |
+
|
802 |
+
def from_civitai(self, state_dict):
|
803 |
+
state_dict_ = {}
|
804 |
+
if 'model_state' in state_dict:
|
805 |
+
state_dict = state_dict['model_state']
|
806 |
+
for name in state_dict:
|
807 |
+
state_dict_['model.' + name] = state_dict[name]
|
808 |
+
return state_dict_
|
809 |
+
|
pipeline/__init__.py
ADDED
File without changes
|
pipeline/i2v_pipeline.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffsynth import ModelManager
|
2 |
+
from diffsynth.pipelines.base import BasePipeline
|
3 |
+
from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
4 |
+
|
5 |
+
from model.dit import WanModel
|
6 |
+
from model.text_encoder import WanTextEncoder
|
7 |
+
from model.vae import WanVideoVAE
|
8 |
+
from model.image_encoder import WanImageEncoder
|
9 |
+
from model.prompter import WanPrompter
|
10 |
+
from scheduler.flow_match import FlowMatchScheduler
|
11 |
+
|
12 |
+
import torch, os
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
from tqdm import tqdm
|
17 |
+
from safetensors import safe_open
|
18 |
+
|
19 |
+
from model.text_encoder import T5RelativeEmbedding, T5LayerNorm
|
20 |
+
from model.dit import WanLayerNorm, WanRMSNorm, WanSelfAttention
|
21 |
+
from model.vae import RMS_norm, CausalConv3d, Upsample
|
22 |
+
|
23 |
+
|
24 |
+
def binary_tensor_to_indices(tensor):
|
25 |
+
assert tensor.dim() == 2, "Input tensor must be in [b, t]"
|
26 |
+
indices = [(row == 1).nonzero(as_tuple=True)[0] for row in tensor]
|
27 |
+
return indices
|
28 |
+
|
29 |
+
def propagate_visualize_attention_arg(model, visualize_attention=False):
|
30 |
+
"""
|
31 |
+
Recursively set the visualize_attention parameter to True for all WanSelfAttention modules
|
32 |
+
Only for inference/test mode
|
33 |
+
"""
|
34 |
+
for name, module in model.named_modules():
|
35 |
+
if isinstance(module, WanSelfAttention):
|
36 |
+
if "blocks.0.self_attn" in name or "blocks.19.self_attn" in name or "blocks.39.self_attn" in name:
|
37 |
+
print(f"Set `visualize_attention` to {visualize_attention} for {name}")
|
38 |
+
module.visualize_attention = visualize_attention
|
39 |
+
|
40 |
+
class WanVideoPipeline(BasePipeline):
|
41 |
+
|
42 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
|
43 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
44 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
45 |
+
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
|
46 |
+
self.text_encoder: WanTextEncoder = None
|
47 |
+
self.image_encoder: WanImageEncoder = None
|
48 |
+
self.dit: WanModel = None
|
49 |
+
self.vae: WanVideoVAE = None
|
50 |
+
self.model_names = ['text_encoder', 'dit', 'vae']
|
51 |
+
self.height_division_factor = 16
|
52 |
+
self.width_division_factor = 16
|
53 |
+
|
54 |
+
|
55 |
+
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
56 |
+
dtype = next(iter(self.text_encoder.parameters())).dtype
|
57 |
+
enable_vram_management(
|
58 |
+
self.text_encoder,
|
59 |
+
module_map = {
|
60 |
+
torch.nn.Linear: AutoWrappedLinear,
|
61 |
+
torch.nn.Embedding: AutoWrappedModule,
|
62 |
+
T5RelativeEmbedding: AutoWrappedModule,
|
63 |
+
T5LayerNorm: AutoWrappedModule,
|
64 |
+
},
|
65 |
+
module_config = dict(
|
66 |
+
offload_dtype=dtype,
|
67 |
+
offload_device="cpu",
|
68 |
+
onload_dtype=dtype,
|
69 |
+
onload_device="cpu",
|
70 |
+
computation_dtype=self.torch_dtype,
|
71 |
+
computation_device=self.device,
|
72 |
+
),
|
73 |
+
)
|
74 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
75 |
+
enable_vram_management(
|
76 |
+
self.dit,
|
77 |
+
module_map = {
|
78 |
+
torch.nn.Linear: AutoWrappedLinear,
|
79 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
80 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
81 |
+
WanLayerNorm: AutoWrappedModule,
|
82 |
+
WanRMSNorm: AutoWrappedModule,
|
83 |
+
},
|
84 |
+
module_config = dict(
|
85 |
+
offload_dtype=dtype,
|
86 |
+
offload_device="cpu",
|
87 |
+
onload_dtype=dtype,
|
88 |
+
onload_device=self.device,
|
89 |
+
computation_dtype=self.torch_dtype,
|
90 |
+
computation_device=self.device,
|
91 |
+
),
|
92 |
+
max_num_param=num_persistent_param_in_dit,
|
93 |
+
overflow_module_config = dict(
|
94 |
+
offload_dtype=dtype,
|
95 |
+
offload_device="cpu",
|
96 |
+
onload_dtype=dtype,
|
97 |
+
onload_device="cpu",
|
98 |
+
computation_dtype=self.torch_dtype,
|
99 |
+
computation_device=self.device,
|
100 |
+
),
|
101 |
+
)
|
102 |
+
dtype = next(iter(self.vae.parameters())).dtype
|
103 |
+
enable_vram_management(
|
104 |
+
self.vae,
|
105 |
+
module_map = {
|
106 |
+
torch.nn.Linear: AutoWrappedLinear,
|
107 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
108 |
+
RMS_norm: AutoWrappedModule,
|
109 |
+
CausalConv3d: AutoWrappedModule,
|
110 |
+
Upsample: AutoWrappedModule,
|
111 |
+
torch.nn.SiLU: AutoWrappedModule,
|
112 |
+
torch.nn.Dropout: AutoWrappedModule,
|
113 |
+
},
|
114 |
+
module_config = dict(
|
115 |
+
offload_dtype=dtype,
|
116 |
+
offload_device="cpu",
|
117 |
+
onload_dtype=dtype,
|
118 |
+
onload_device=self.device,
|
119 |
+
computation_dtype=self.torch_dtype,
|
120 |
+
computation_device=self.device,
|
121 |
+
),
|
122 |
+
)
|
123 |
+
if self.image_encoder is not None:
|
124 |
+
dtype = next(iter(self.image_encoder.parameters())).dtype
|
125 |
+
enable_vram_management(
|
126 |
+
self.image_encoder,
|
127 |
+
module_map = {
|
128 |
+
torch.nn.Linear: AutoWrappedLinear,
|
129 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
130 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
131 |
+
},
|
132 |
+
module_config = dict(
|
133 |
+
offload_dtype=dtype,
|
134 |
+
offload_device="cpu",
|
135 |
+
onload_dtype=dtype,
|
136 |
+
onload_device="cpu",
|
137 |
+
computation_dtype=self.torch_dtype,
|
138 |
+
computation_device=self.device,
|
139 |
+
),
|
140 |
+
)
|
141 |
+
self.enable_cpu_offload()
|
142 |
+
|
143 |
+
def fetch_models_from_model_manager(self, model_manager: ModelManager):
|
144 |
+
text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
|
145 |
+
if text_encoder_model_and_path is not None:
|
146 |
+
self.text_encoder, tokenizer_path = text_encoder_model_and_path
|
147 |
+
self.prompter.fetch_models(self.text_encoder)
|
148 |
+
self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
|
149 |
+
self.dit = model_manager.fetch_model("wan_video_dit")
|
150 |
+
self.vae = model_manager.fetch_model("wan_video_vae")
|
151 |
+
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
152 |
+
|
153 |
+
def _init_component_from_checkpoint_path(self, model_cls, state_dict_path, strict=True, config_dict=None):
|
154 |
+
config = {}
|
155 |
+
state_dict = self._load_state_dict(state_dict_path)
|
156 |
+
if hasattr(model_cls, "state_dict_converter"):
|
157 |
+
state_dict_converter = model_cls.state_dict_converter()
|
158 |
+
state_dict = state_dict_converter.from_civitai(state_dict)
|
159 |
+
if isinstance(state_dict, tuple):
|
160 |
+
state_dict, config = state_dict
|
161 |
+
config.update(config_dict or {})
|
162 |
+
model = model_cls(**config)
|
163 |
+
if "use_local_lora" in config_dict or "use_dera" in config_dict:
|
164 |
+
strict = False
|
165 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|
166 |
+
print(f"Missing keys: {missing_keys}")
|
167 |
+
print(f"Unexpected keys: {unexpected_keys}")
|
168 |
+
return model
|
169 |
+
|
170 |
+
def _load_state_dict(self, state_dict_paths):
|
171 |
+
if isinstance(state_dict_paths, str):
|
172 |
+
state_dict_paths = [state_dict_paths]
|
173 |
+
state_dict = {}
|
174 |
+
for state_dict_path in tqdm(state_dict_paths, desc="Reading file(s) from disk"):
|
175 |
+
state_dict.update(self._load_single_file(state_dict_path))
|
176 |
+
return state_dict
|
177 |
+
|
178 |
+
def _load_single_file(self, file_path):
|
179 |
+
if file_path.endswith(".safetensors"):
|
180 |
+
return self._load_state_dict_from_safetensors(file_path)
|
181 |
+
else:
|
182 |
+
return torch.load(file_path, map_location='cpu')
|
183 |
+
|
184 |
+
def _load_state_dict_from_safetensors(self, file_path, torch_dtype=None):
|
185 |
+
state_dict = {}
|
186 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
187 |
+
for k in f.keys():
|
188 |
+
state_dict[k] = f.get_tensor(k)
|
189 |
+
if torch_dtype is not None:
|
190 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
191 |
+
return state_dict
|
192 |
+
|
193 |
+
def initialize_dummy_dit(self, config):
|
194 |
+
print("Initializing a dummy DIT model.")
|
195 |
+
self.dit = WanModel(**config)
|
196 |
+
print("Dummy DIT model is initialized.")
|
197 |
+
|
198 |
+
def fetch_models_from_checkpoints(self, path_dict, config_dict=None):
|
199 |
+
default_config = {"text_encoder": {}, "dit": {}, "vae": {}, "image_encoder": {}}
|
200 |
+
config_dict = {**default_config, **(config_dict or {})}
|
201 |
+
components = {
|
202 |
+
"text_encoder": WanTextEncoder,
|
203 |
+
"dit": WanModel,
|
204 |
+
"vae": WanVideoVAE,
|
205 |
+
"image_encoder": WanImageEncoder
|
206 |
+
}
|
207 |
+
for name, model_cls in components.items():
|
208 |
+
if name not in path_dict:
|
209 |
+
print(f"Component {name} is not found in the checkpoint path dict. Skipping.")
|
210 |
+
continue
|
211 |
+
path = path_dict[name]
|
212 |
+
config = config_dict.get(name, {})
|
213 |
+
print(f"Loading {name} from {path} with config {config}.")
|
214 |
+
setattr(self, name, self._init_component_from_checkpoint_path(model_cls, path, config_dict=config))
|
215 |
+
print(f"Initialized {name} from checkpoint.")
|
216 |
+
if "text_encoder" in path_dict:
|
217 |
+
self.prompter.fetch_models(self.text_encoder)
|
218 |
+
self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(path_dict["text_encoder"]), "google/umt5-xxl"))
|
219 |
+
print("Initialized prompter from checkpoint.")
|
220 |
+
print("All components are initialized from checkpoints.")
|
221 |
+
|
222 |
+
@staticmethod
|
223 |
+
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
224 |
+
if device is None: device = model_manager.device
|
225 |
+
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
226 |
+
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
227 |
+
pipe.fetch_models_from_model_manager(model_manager)
|
228 |
+
return pipe
|
229 |
+
|
230 |
+
def denoising_model(self):
|
231 |
+
return self.dit
|
232 |
+
|
233 |
+
def encode_prompt(self, prompt, positive=True):
|
234 |
+
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
|
235 |
+
return {"context": prompt_emb}
|
236 |
+
|
237 |
+
def encode_image(self, image, num_frames, height, width):
|
238 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
239 |
+
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
240 |
+
clip_context = self.image_encoder.encode_image([image])
|
241 |
+
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
242 |
+
msk[:, 1:] = 0
|
243 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
244 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
245 |
+
msk = msk.transpose(1, 2)[0]
|
246 |
+
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
247 |
+
y = torch.concat([msk, y])
|
248 |
+
return {"clip_fea": clip_context, "y": [y]}
|
249 |
+
|
250 |
+
def check_and_fix_image_or_video_tensor_input(self, _tensor):
|
251 |
+
assert isinstance(_tensor, torch.Tensor), "Input must be a tensor."
|
252 |
+
if _tensor.max() <= 255 and _tensor.max() > 1.0:
|
253 |
+
_tensor = _tensor.to(self.device) / 127.5 - 1
|
254 |
+
print("Input tensor is converted from [0, 255] to [-1, 1].")
|
255 |
+
elif _tensor.min() >= 0 and _tensor.max() <= 1:
|
256 |
+
_tensor = _tensor.to(self.device) * 2 - 1
|
257 |
+
print("Input tensor is converted from [0, 1] to [-1, 1].")
|
258 |
+
return _tensor
|
259 |
+
|
260 |
+
def encode_video_with_mask(self, video, num_frames, height, width, condition_preserved_mask):
|
261 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
262 |
+
video = video.to(self.device)
|
263 |
+
y = self.vae.encode(video, device=self.device)
|
264 |
+
msk = condition_preserved_mask
|
265 |
+
assert msk is not None, "The mask must be provided for the masked video input."
|
266 |
+
assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
|
267 |
+
assert msk.shape[0] == video.shape[0], "The batch size of the mask must be the same as the input video."
|
268 |
+
assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
|
269 |
+
msk = msk.to(self.device)
|
270 |
+
msk = msk.unsqueeze(-1).unsqueeze(-1)
|
271 |
+
msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
|
272 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
273 |
+
msk = msk.view(video.shape[0], msk.shape[1] // 4, 4, height//8, width//8) # b, t, c, h, w
|
274 |
+
msk = msk.transpose(1, 2) # b, c, t, h, w
|
275 |
+
y = torch.concat([msk, y], dim=1)
|
276 |
+
return y
|
277 |
+
|
278 |
+
def encode_video_with_mask_sparse(self, video, height, width, condition_preserved_mask, sketch_local_mask=None):
|
279 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
280 |
+
batch_size = video.shape[0]
|
281 |
+
cond_indices = binary_tensor_to_indices(condition_preserved_mask)
|
282 |
+
sequence_cond_compressed_indices = [(cond_index + 3) // 4 for cond_index in cond_indices]
|
283 |
+
video = video.to(self.device)
|
284 |
+
video_latent = self.vae.encode(video, device=self.device)
|
285 |
+
video_latent = video_latent[:, :, sequence_cond_compressed_indices[0], :, :]
|
286 |
+
msk = condition_preserved_mask.to(self.device)
|
287 |
+
msk = msk.unsqueeze(-1).unsqueeze(-1) # b, t, 1, 1
|
288 |
+
msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
|
289 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
290 |
+
msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8) # b, t, 4, h//8, w//8
|
291 |
+
msk = msk.transpose(1, 2) # b, 4, t, h//8, w//8
|
292 |
+
msk = msk[:, :, sequence_cond_compressed_indices[0], :, :]
|
293 |
+
|
294 |
+
if sketch_local_mask is not None:
|
295 |
+
sketch_local_mask = sketch_local_mask.to(self.device)
|
296 |
+
if sketch_local_mask.shape[-2:] != (height//8, width//8):
|
297 |
+
sk_batch_t = sketch_local_mask.shape[0] * sketch_local_mask.shape[2]
|
298 |
+
sketch_local_mask_reshaped = sketch_local_mask.reshape(sk_batch_t, 1, sketch_local_mask.shape[3], sketch_local_mask.shape[4])
|
299 |
+
sketch_local_mask_resized = torch.nn.functional.interpolate(
|
300 |
+
sketch_local_mask_reshaped,
|
301 |
+
size=(height//8, width//8),
|
302 |
+
mode='nearest'
|
303 |
+
)
|
304 |
+
sketch_local_mask_resized = sketch_local_mask_resized.reshape(
|
305 |
+
sketch_local_mask.shape[0],
|
306 |
+
sketch_local_mask.shape[1],
|
307 |
+
sketch_local_mask.shape[2],
|
308 |
+
height//8, width//8
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
sketch_local_mask_resized = sketch_local_mask
|
312 |
+
|
313 |
+
sketch_mask = sketch_local_mask_resized
|
314 |
+
sketch_mask = torch.concat([torch.repeat_interleave(sketch_mask[:, :, 0:1], repeats=4, dim=2), sketch_mask[:, :, 1:]], dim=2)
|
315 |
+
sketch_mask = sketch_mask.view(batch_size, sketch_mask.shape[1], sketch_mask.shape[2] // 4, 4, height//8, width//8)
|
316 |
+
sketch_mask = sketch_mask.permute(0, 1, 3, 2, 4, 5) # [b, 1, 4, t//4, h//8, w//8]
|
317 |
+
sketch_mask = sketch_mask.view(batch_size, 4, sketch_mask.shape[3], height//8, width//8) # [b, 4, t//4, h//8, w//8]
|
318 |
+
sketch_mask = sketch_mask[:, :, sequence_cond_compressed_indices[0], :, :] # [b, 4, len(indices), h//8, w//8]
|
319 |
+
|
320 |
+
combined_latent = torch.cat([msk, video_latent, sketch_mask], dim=1)
|
321 |
+
else:
|
322 |
+
combined_latent = torch.concat([msk, video_latent], dim=1)
|
323 |
+
|
324 |
+
return combined_latent, sequence_cond_compressed_indices # b, c=(4+16+4=24), t, h, w when sketch_local_mask is provided
|
325 |
+
|
326 |
+
def encode_image_or_masked_video(self, image_or_masked_video, num_frames, height, width, condition_preserved_mask=None):
|
327 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
328 |
+
batch_size = image_or_masked_video.shape[0]
|
329 |
+
if isinstance(image_or_masked_video, PIL.Image.Image) or (isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() <= 4):
|
330 |
+
if isinstance(image_or_masked_video, PIL.Image.Image):
|
331 |
+
image_or_masked_video = self.preprocess_image(image_or_masked_video.resize((width, height))).to(self.device)
|
332 |
+
else:
|
333 |
+
if image_or_masked_video.dim() == 3:
|
334 |
+
image_or_masked_video = image_or_masked_video.unsqueeze(0) # b=1, c, h, w
|
335 |
+
image_or_masked_video = image_or_masked_video.to(self.device)
|
336 |
+
y = self.vae.encode([torch.concat([image_or_masked_video.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_or_masked_video.device)], dim=1)], device=self.device)
|
337 |
+
msk_idx_to_be_zero = range(1, num_frames)
|
338 |
+
clip_context = self.image_encoder.encode_image(image_or_masked_video.unsqueeze(1)) # need to be [b, 1, c, h, w]
|
339 |
+
msk = torch.ones(batch_size, num_frames, height//8, width//8, device=self.device)
|
340 |
+
msk[:, msk_idx_to_be_zero] = 0
|
341 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
342 |
+
msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)
|
343 |
+
msk = msk.transpose(1, 2)
|
344 |
+
elif isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() == 5:
|
345 |
+
image_or_masked_video = image_or_masked_video.to(self.device)
|
346 |
+
first_image = image_or_masked_video[:, :, 0, :, :].unsqueeze(1)
|
347 |
+
clip_context = self.image_encoder.encode_image(first_image)
|
348 |
+
y = self.vae.encode(image_or_masked_video, device=self.device)
|
349 |
+
msk = condition_preserved_mask # b, t
|
350 |
+
assert msk is not None, "The mask must be provided for the masked video input."
|
351 |
+
assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
|
352 |
+
assert msk.shape[0] == batch_size, "The batch size of the mask must be the same as the input video."
|
353 |
+
assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
|
354 |
+
msk = msk.to(self.device)
|
355 |
+
msk = msk.unsqueeze(-1).unsqueeze(-1) # b, t, 1, 1
|
356 |
+
msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
|
357 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
358 |
+
msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8) # b, t, 4, h//8, w//8
|
359 |
+
msk = msk.transpose(1, 2) # b, 4, t, h//8, w//8
|
360 |
+
else:
|
361 |
+
raise ValueError("Input must be an image (PIL/Tensor in [b, c, h, w]) or a masked video (Tensor in [b, c, t, h, w]).")
|
362 |
+
|
363 |
+
y = torch.concat([msk, y], dim=1)
|
364 |
+
return {"clip_fea": clip_context, "y": y}
|
365 |
+
|
366 |
+
def tensor2video(self, frames):
|
367 |
+
frames = rearrange(frames, "C T H W -> T H W C")
|
368 |
+
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
369 |
+
frames = [PIL.Image.fromarray(frame) for frame in frames]
|
370 |
+
return frames
|
371 |
+
|
372 |
+
def prepare_extra_input(self, latents=None):
|
373 |
+
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
|
374 |
+
|
375 |
+
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
376 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
377 |
+
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
378 |
+
return latents
|
379 |
+
|
380 |
+
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
381 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
382 |
+
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
383 |
+
return frames
|
384 |
+
|
385 |
+
@torch.no_grad()
|
386 |
+
def __call__(
|
387 |
+
self,
|
388 |
+
prompt,
|
389 |
+
negative_prompt="",
|
390 |
+
input_image=None,
|
391 |
+
input_video=None,
|
392 |
+
denoising_strength=1.0,
|
393 |
+
seed=None,
|
394 |
+
rand_device="cpu",
|
395 |
+
height=480,
|
396 |
+
width=832,
|
397 |
+
num_frames=81,
|
398 |
+
cfg_scale=5.0,
|
399 |
+
num_inference_steps=50,
|
400 |
+
sigma_shift=5.0,
|
401 |
+
tiled=True,
|
402 |
+
tile_size=(30, 52),
|
403 |
+
tile_stride=(15, 26),
|
404 |
+
progress_bar_cmd=tqdm,
|
405 |
+
# progress_bar_st=None,
|
406 |
+
input_condition_video=None,
|
407 |
+
input_condition_preserved_mask=None,
|
408 |
+
input_condition_video_sketch=None,
|
409 |
+
input_condition_preserved_mask_sketch=None,
|
410 |
+
sketch_local_mask=None,
|
411 |
+
visualize_attention=False,
|
412 |
+
output_path=None,
|
413 |
+
batch_idx=None,
|
414 |
+
sequence_cond_residual_scale=1.0,
|
415 |
+
):
|
416 |
+
height, width = self.check_resize_height_width(height, width)
|
417 |
+
if num_frames % 4 != 1:
|
418 |
+
num_frames = (num_frames + 2) // 4 * 4 + 1
|
419 |
+
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
420 |
+
|
421 |
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
422 |
+
|
423 |
+
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
424 |
+
|
425 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
426 |
+
if input_video is not None:
|
427 |
+
self.load_models_to_device(['vae'])
|
428 |
+
input_video = self.preprocess_images(input_video)
|
429 |
+
input_video = torch.stack(input_video, dim=2)
|
430 |
+
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
|
431 |
+
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
432 |
+
else:
|
433 |
+
latents = noise
|
434 |
+
|
435 |
+
self.load_models_to_device(["text_encoder"])
|
436 |
+
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
437 |
+
if cfg_scale != 1.0:
|
438 |
+
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
439 |
+
|
440 |
+
self.load_models_to_device(["image_encoder", "vae"])
|
441 |
+
if input_image is not None and self.image_encoder is not None:
|
442 |
+
image_emb = self.encode_image(input_image, num_frames, height, width)
|
443 |
+
elif input_condition_video is not None and self.image_encoder is not None:
|
444 |
+
assert input_condition_preserved_mask is not None, "`input_condition_preserved_mask` must not be None when `input_condition_video` is given."
|
445 |
+
image_emb = self.encode_image_or_masked_video(input_condition_video, num_frames, height, width, input_condition_preserved_mask)
|
446 |
+
else:
|
447 |
+
image_emb = {}
|
448 |
+
|
449 |
+
# Extra input
|
450 |
+
extra_input = self.prepare_extra_input(latents)
|
451 |
+
if self.dit.use_sequence_cond:
|
452 |
+
assert input_condition_video_sketch is not None, "`input_condition_video_sketch` must not be None when `use_sequence_cond` is True."
|
453 |
+
assert input_condition_preserved_mask_sketch is not None, "`input_condition_preserved_mask_sketch` must not be None when `input_condition_video_sketch` is given."
|
454 |
+
|
455 |
+
if self.dit.sequence_cond_mode == "sparse":
|
456 |
+
sequence_cond, sequence_cond_compressed_indices = self.encode_video_with_mask_sparse(input_condition_video_sketch, height, width, input_condition_preserved_mask_sketch, sketch_local_mask)
|
457 |
+
extra_input.update({"sequence_cond": sequence_cond,
|
458 |
+
"sequence_cond_compressed_indices": sequence_cond_compressed_indices})
|
459 |
+
elif self.dit.sequence_cond_mode == "full":
|
460 |
+
sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
|
461 |
+
extra_input.update({"sequence_cond": sequence_cond})
|
462 |
+
else:
|
463 |
+
raise ValueError(f"Invalid `sequence_cond_model`={self.dit.sequence_cond_mode} in the DIT model.")
|
464 |
+
|
465 |
+
elif self.dit.use_channel_cond:
|
466 |
+
sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
|
467 |
+
extra_input.update({"channel_cond": sequence_cond})
|
468 |
+
|
469 |
+
self.load_models_to_device([])
|
470 |
+
|
471 |
+
if sequence_cond_residual_scale != 1.0:
|
472 |
+
extra_input.update({"sequence_cond_residual_scale": sequence_cond_residual_scale})
|
473 |
+
|
474 |
+
# Denoise
|
475 |
+
self.load_models_to_device(["dit"])
|
476 |
+
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
477 |
+
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
478 |
+
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
|
479 |
+
_should_visualize_attention = visualize_attention and (progress_id == len(self.scheduler.timesteps) - 1)
|
480 |
+
if _should_visualize_attention:
|
481 |
+
print(f"Visualizing attention maps (Step {progress_id + 1}/{len(self.scheduler.timesteps)}).")
|
482 |
+
propagate_visualize_attention_arg(self.dit, True)
|
483 |
+
|
484 |
+
# Inference
|
485 |
+
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
486 |
+
if isinstance(noise_pred_posi, tuple):
|
487 |
+
noise_pred_posi = noise_pred_posi[0]
|
488 |
+
if cfg_scale != 1.0:
|
489 |
+
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
490 |
+
if isinstance(noise_pred_nega, tuple):
|
491 |
+
noise_pred_nega = noise_pred_nega[0]
|
492 |
+
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
493 |
+
else:
|
494 |
+
noise_pred = noise_pred_posi
|
495 |
+
|
496 |
+
# Scheduler
|
497 |
+
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
498 |
+
|
499 |
+
# If visualization is enabled, save the attention maps
|
500 |
+
if _should_visualize_attention:
|
501 |
+
print("Saving attention maps...")
|
502 |
+
from util.model_util import save_attention_maps
|
503 |
+
save_attention_maps(self.dit, output_path, batch_idx, timestep.squeeze().cpu().numpy().item())
|
504 |
+
propagate_visualize_attention_arg(self.dit, False)
|
505 |
+
|
506 |
+
# Decode
|
507 |
+
self.load_models_to_device(['vae'])
|
508 |
+
frames = self.decode_video(latents, **tiler_kwargs)
|
509 |
+
self.load_models_to_device([])
|
510 |
+
|
511 |
+
return frames
|
requirements.txt
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.2.2
|
2 |
+
accelerate==1.6.0
|
3 |
+
beartype==0.20.2
|
4 |
+
beautifulsoup4==4.13.4
|
5 |
+
braceexpand==0.1.7
|
6 |
+
cached-property==2.0.1
|
7 |
+
certifi==2025.1.31
|
8 |
+
charset-normalizer==3.4.1
|
9 |
+
click==8.1.8
|
10 |
+
clip==0.2.0
|
11 |
+
comm==0.2.3
|
12 |
+
contourpy==1.3.2
|
13 |
+
controlnet_aux==0.0.7
|
14 |
+
crcmod==1.7
|
15 |
+
cycler==0.12.1
|
16 |
+
datasets==3.5.0
|
17 |
+
debugpy==1.8.15
|
18 |
+
decorator==5.2.1
|
19 |
+
decord==0.6.0
|
20 |
+
deepspeed==0.16.7
|
21 |
+
diffsynth==1.1.7
|
22 |
+
diffusers==0.33.1
|
23 |
+
dill==0.3.8
|
24 |
+
docker-pycreds==0.4.0
|
25 |
+
dulwich==0.22.8
|
26 |
+
easydict==1.13
|
27 |
+
einops==0.8.1
|
28 |
+
exceptiongroup==1.2.2
|
29 |
+
executing==2.2.0
|
30 |
+
fairscale==0.4.13
|
31 |
+
fastapi==0.115.12
|
32 |
+
fastrlock==0.8.3
|
33 |
+
ffmpy==0.5.0
|
34 |
+
filelock==3.13.1
|
35 |
+
flash_attn==2.8.0.post2 --global-option="--no-build-isolation"
|
36 |
+
fonttools==4.57.0
|
37 |
+
frozenlist==1.6.0
|
38 |
+
fsspec==2024.12.0
|
39 |
+
ftfy==6.3.1
|
40 |
+
func_timeout==4.3.5
|
41 |
+
fuzzywuzzy==0.18.0
|
42 |
+
gitdb==4.0.12
|
43 |
+
GitPython==3.1.44
|
44 |
+
gradio==5.25.2
|
45 |
+
gradio_client==1.8.0
|
46 |
+
groovy==0.1.2
|
47 |
+
grpcio==1.71.0
|
48 |
+
h11==0.14.0
|
49 |
+
hjson==3.1.0
|
50 |
+
httpcore==1.0.8
|
51 |
+
httpx==0.28.1
|
52 |
+
huggingface-hub==0.30.2
|
53 |
+
idna==3.10
|
54 |
+
imageio==2.37.0
|
55 |
+
imageio-ffmpeg==0.6.0
|
56 |
+
importlib_metadata==8.6.1
|
57 |
+
ipykernel==6.30.0
|
58 |
+
ipython==8.37.0
|
59 |
+
jedi==0.19.2
|
60 |
+
Jinja2==3.1.4
|
61 |
+
joblib==1.4.2
|
62 |
+
kiwisolver==1.4.8
|
63 |
+
kornia==0.8.0
|
64 |
+
kornia_rs==0.1.8
|
65 |
+
lazy_loader==0.4
|
66 |
+
lightning==2.5.1
|
67 |
+
lightning-utilities==0.14.3
|
68 |
+
lpips==0.1.4
|
69 |
+
matplotlib==3.10.1
|
70 |
+
matplotlib-inline==0.1.7
|
71 |
+
mdurl==0.1.2
|
72 |
+
modelscope==1.25.0
|
73 |
+
moviepy==2.1.2
|
74 |
+
mpmath==1.3.0
|
75 |
+
msgpack==1.1.0
|
76 |
+
multidict==6.4.3
|
77 |
+
multiprocess==0.70.16
|
78 |
+
ninja==1.11.1.4
|
79 |
+
numpy==2.2.5
|
80 |
+
omegaconf==2.3.0
|
81 |
+
opencv-python==4.11.0.86
|
82 |
+
orjson==3.10.16
|
83 |
+
packaging==24.2
|
84 |
+
pandas==2.2.3
|
85 |
+
parso==0.8.4
|
86 |
+
peft==0.15.2
|
87 |
+
pexpect==4.9.0
|
88 |
+
pillow==10.4.0
|
89 |
+
platformdirs==4.3.7
|
90 |
+
proglog==0.1.11
|
91 |
+
prompt_toolkit==3.0.51
|
92 |
+
propcache==0.3.1
|
93 |
+
protobuf==5.29.4
|
94 |
+
psutil==7.0.0
|
95 |
+
ptyprocess==0.7.0
|
96 |
+
pure_eval==0.2.3
|
97 |
+
py-cpuinfo==9.0.0
|
98 |
+
pyarrow==19.0.1
|
99 |
+
pycryptodome==3.22.0
|
100 |
+
pydantic==2.11.3
|
101 |
+
pydantic_core==2.33.1
|
102 |
+
pydub==0.25.1
|
103 |
+
Pygments==2.19.1
|
104 |
+
pynvml==12.0.0
|
105 |
+
pyparsing==3.2.3
|
106 |
+
python-dateutil==2.9.0.post0
|
107 |
+
python-dotenv==1.1.0
|
108 |
+
python-multipart==0.0.20
|
109 |
+
pytorch-fid==0.3.0
|
110 |
+
pytorch-lightning==2.5.1
|
111 |
+
pytz==2025.2
|
112 |
+
PyYAML==6.0.2
|
113 |
+
pyzmq==27.0.0
|
114 |
+
regex==2024.11.6
|
115 |
+
requests==2.32.3
|
116 |
+
rich==14.0.0
|
117 |
+
ruff==0.11.6
|
118 |
+
safehttpx==0.1.6
|
119 |
+
safetensors==0.5.3
|
120 |
+
scikit-image==0.25.2
|
121 |
+
scikit-learn==1.6.1
|
122 |
+
scipy==1.15.2
|
123 |
+
semantic-version==2.10.0
|
124 |
+
sentencepiece==0.2.0
|
125 |
+
sentry-sdk==2.26.1
|
126 |
+
setproctitle==1.3.5
|
127 |
+
shellingham==1.5.4
|
128 |
+
simplejson==3.20.1
|
129 |
+
six==1.17.0
|
130 |
+
smmap==5.0.2
|
131 |
+
sniffio==1.3.1
|
132 |
+
soupsieve==2.7
|
133 |
+
stack-data==0.6.3
|
134 |
+
starlette==0.46.2
|
135 |
+
sympy==1.13.1
|
136 |
+
taming-transformers==0.0.1
|
137 |
+
tensorboard==2.19.0
|
138 |
+
tokenizers==0.20.3
|
139 |
+
torch==2.6.0
|
140 |
+
torchaudio==2.6.0
|
141 |
+
torchdiffeq==0.2.5
|
142 |
+
torchmetrics==1.7.1
|
143 |
+
torchsde==0.2.6
|
144 |
+
torchvision==0.21.0
|
145 |
+
tqdm==4.67.1
|
146 |
+
transformers==4.46.2
|
147 |
+
triton==3.2.0
|
148 |
+
xformers==0.0.29.post2
|
samples/1_image1.png
ADDED
![]() |
Git LFS Details
|
samples/1_out.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa51ac0653a18dc20b9a6946aaa1a7923d58fe291e926908703c300a4d13c4a2
|
3 |
+
size 356550
|
samples/1_prompt.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
['在海底,一个上身赤裸的的男子和一个螺旋游动的蓝鱼嬉戏。鲸鱼跟着男人手里拿的袋子绕圈,男子拿着袋子引诱着蓝鱼向前游动。Anime. High quality.']
|
samples/1_sketch1.jpg
ADDED
![]() |
Git LFS Details
|
samples/1_sketch2.jpg
ADDED
![]() |
Git LFS Details
|
samples/1_sketch3.jpg
ADDED
![]() |
Git LFS Details
|
samples/2_image1.jpg
ADDED
![]() |
Git LFS Details
|
samples/2_out.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9f28ab63b4fc5b07c0ed01f715ec671f8d839b8783dc8a432c7764bd35605f5
|
3 |
+
size 151565
|
samples/2_prompt.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
['一个女孩和一个银发男孩种下了一颗巨大的花,随着镜头缓慢向上移动,这个巨大的花不断生长变大并开放。Anime. High quality.']
|
samples/2_sketch1.jpg
ADDED
![]() |
Git LFS Details
|
samples/2_sketch2.jpg
ADDED
![]() |
Git LFS Details
|
samples/3_image1.png
ADDED
![]() |
Git LFS Details
|
samples/3_out.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fdb131043289d831f7c4e0d3dd4a21ecc3c4eecca1bf3ae539bb14414c439cde
|
3 |
+
size 87909
|
samples/3_prompt.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
['一个古代中国男孩拿着苹果,笑眯眯地送给旁边的老人。Anime. High quality.']
|
samples/3_sketch1.jpg
ADDED
![]() |
Git LFS Details
|
samples/ToonComposer-Icon.png
ADDED
![]() |
Git LFS Details
|
samples/ToonComposer-Method.jpg
ADDED
![]() |
Git LFS Details
|
samples/ToonComposer-TLDR.jpg
ADDED
![]() |
Git LFS Details
|
scheduler/__init__.py
ADDED
File without changes
|
scheduler/flow_match.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class FlowMatchScheduler():
|
5 |
+
|
6 |
+
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
7 |
+
self.num_train_timesteps = num_train_timesteps
|
8 |
+
self.shift = shift
|
9 |
+
self.sigma_max = sigma_max
|
10 |
+
self.sigma_min = sigma_min
|
11 |
+
self.inverse_timesteps = inverse_timesteps
|
12 |
+
self.extra_one_step = extra_one_step
|
13 |
+
self.reverse_sigmas = reverse_sigmas
|
14 |
+
self.set_timesteps(num_inference_steps)
|
15 |
+
|
16 |
+
|
17 |
+
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
18 |
+
if shift is not None:
|
19 |
+
self.shift = shift
|
20 |
+
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
21 |
+
if self.extra_one_step:
|
22 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
23 |
+
else:
|
24 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
25 |
+
if self.inverse_timesteps:
|
26 |
+
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
27 |
+
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
28 |
+
if self.reverse_sigmas:
|
29 |
+
self.sigmas = 1 - self.sigmas
|
30 |
+
self.timesteps = self.sigmas * self.num_train_timesteps
|
31 |
+
if training:
|
32 |
+
x = self.timesteps
|
33 |
+
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
34 |
+
y_shifted = y - y.min()
|
35 |
+
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
36 |
+
self.linear_timesteps_weights = bsmntw_weighing
|
37 |
+
|
38 |
+
|
39 |
+
def step(self, model_output, timestep, sample, to_final=False):
|
40 |
+
if isinstance(timestep, torch.Tensor):
|
41 |
+
timestep = timestep.cpu()
|
42 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
43 |
+
sigma = self.sigmas[timestep_id]
|
44 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
45 |
+
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
46 |
+
else:
|
47 |
+
sigma_ = self.sigmas[timestep_id + 1]
|
48 |
+
prev_sample = sample + model_output * (sigma_ - sigma)
|
49 |
+
return prev_sample
|
50 |
+
|
51 |
+
|
52 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
53 |
+
if isinstance(timestep, torch.Tensor):
|
54 |
+
timestep = timestep.cpu()
|
55 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
56 |
+
sigma = self.sigmas[timestep_id]
|
57 |
+
model_output = (sample - sample_stablized) / sigma
|
58 |
+
return model_output
|
59 |
+
|
60 |
+
|
61 |
+
def add_noise(self, original_samples, noise, timestep):
|
62 |
+
if isinstance(timestep, torch.Tensor):
|
63 |
+
timestep = timestep.cpu()
|
64 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
65 |
+
sigma = self.sigmas[timestep_id]
|
66 |
+
sample = (1 - sigma) * original_samples + sigma * noise
|
67 |
+
return sample
|
68 |
+
|
69 |
+
|
70 |
+
def training_target(self, sample, noise, timestep):
|
71 |
+
target = noise - sample
|
72 |
+
return target
|
73 |
+
|
74 |
+
|
75 |
+
def training_weight(self, timestep):
|
76 |
+
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
77 |
+
weights = self.linear_timesteps_weights[timestep_id]
|
78 |
+
return weights
|
tooncomposer.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch, lightning, imageio
|
2 |
+
from peft import LoraConfig, inject_adapter_in_model
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from pipeline.i2v_pipeline import WanVideoPipeline
|
6 |
+
|
7 |
+
|
8 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
9 |
+
torch.set_float32_matmul_precision('medium')
|
10 |
+
|
11 |
+
|
12 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
13 |
+
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
14 |
+
for frame in frames:
|
15 |
+
frame = np.array(frame)
|
16 |
+
writer.append_data(frame)
|
17 |
+
writer.close()
|
18 |
+
|
19 |
+
|
20 |
+
def get_base_model_paths(base_model_name, format='dict', model_root="./weights"):
|
21 |
+
if base_model_name == "Wan2.1-I2V-14B-480P":
|
22 |
+
if format == 'list':
|
23 |
+
return [
|
24 |
+
[os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
|
25 |
+
os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
26 |
+
os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
|
27 |
+
os.path.join(model_root, "Wan2.1_VAE.pth")
|
28 |
+
]
|
29 |
+
elif format == 'dict':
|
30 |
+
return {
|
31 |
+
"dit": [os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
|
32 |
+
"image_encoder": os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
33 |
+
"text_encoder": os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
|
34 |
+
"vae": os.path.join(model_root, "Wan2.1_VAE.pth")
|
35 |
+
}
|
36 |
+
else:
|
37 |
+
raise ValueError(f"Unsupported format: {format}")
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Unsupported base model name: {base_model_name}")
|
40 |
+
|
41 |
+
|
42 |
+
class ToonComposer(lightning.LightningModule):
|
43 |
+
def __init__(self, base_model_name="Wan2.1-I2V-14B-480P", model_root=None, learning_rate=1e-5, lora_rank=4, lora_alpha=4,
|
44 |
+
train_architecture=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2",
|
45 |
+
init_lora_weights="kaiming", use_gradient_checkpointing=True,
|
46 |
+
checkpoint_path=None, video_condition_preservation_mode="first_and_last",
|
47 |
+
tiled=False, tile_size=(34, 34), tile_stride=(18, 16), output_path=None,
|
48 |
+
use_local_lora=False, use_dera=False, dera_rank=None, use_dera_spatial=True, use_dera_temporal=True, use_sequence_cond=False, sequence_cond_mode="sparse",
|
49 |
+
use_channel_cond=False,
|
50 |
+
use_sequence_cond_position_aware_residual=False,
|
51 |
+
use_sequence_cond_loss=False, fast_dev=False,
|
52 |
+
max_num_cond_images=1, max_num_cond_sketches=2, visualize_attention=False,
|
53 |
+
random_spaced_cond_frames=False, use_sketch_mask=False, sketch_mask_ratio=0.2, no_first_sketch=False,
|
54 |
+
test_sampling_steps=15, test_sequence_cond_residual_scale=0.5, height=480, width=832):
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.pipe = WanVideoPipeline(device="cpu", torch_dtype=torch.bfloat16)
|
58 |
+
self.use_local_lora = use_local_lora
|
59 |
+
self.use_dera = use_dera
|
60 |
+
self.use_dera_spatial = use_dera_spatial
|
61 |
+
self.use_dera_temporal = use_dera_temporal
|
62 |
+
self.use_sequence_cond = use_sequence_cond
|
63 |
+
self.sequence_cond_mode = sequence_cond_mode
|
64 |
+
self.use_channel_cond = use_channel_cond
|
65 |
+
self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
|
66 |
+
assert not (use_sequence_cond and use_channel_cond), "Cannot use both sequence condition and channel condition."
|
67 |
+
self.use_sequence_cond_loss = use_sequence_cond_loss
|
68 |
+
|
69 |
+
self.max_num_cond_images = max_num_cond_images
|
70 |
+
self.max_num_cond_sketches = max_num_cond_sketches
|
71 |
+
|
72 |
+
self.visualize_attention = visualize_attention
|
73 |
+
self.random_spaced_cond_frames = random_spaced_cond_frames
|
74 |
+
self.use_sketch_mask = use_sketch_mask
|
75 |
+
self.sketch_mask_ratio = sketch_mask_ratio
|
76 |
+
self.no_first_sketch = no_first_sketch
|
77 |
+
self.test_sampling_steps = test_sampling_steps
|
78 |
+
self.test_sequence_cond_residual_scale = test_sequence_cond_residual_scale
|
79 |
+
|
80 |
+
self.height = height
|
81 |
+
self.width = width
|
82 |
+
|
83 |
+
self.current_checkpoint_path = None
|
84 |
+
|
85 |
+
paths = get_base_model_paths(base_model_name, format='dict', model_root=model_root)
|
86 |
+
if use_sequence_cond:
|
87 |
+
assert sequence_cond_mode in ["sparse", "full"], f"Unsupported sequence condition model: {sequence_cond_mode}"
|
88 |
+
if sequence_cond_mode == "sparse":
|
89 |
+
if use_sketch_mask:
|
90 |
+
sequence_cond_in_dim = 24
|
91 |
+
else:
|
92 |
+
sequence_cond_in_dim = 20
|
93 |
+
else:
|
94 |
+
sequence_cond_in_dim = 20
|
95 |
+
use_channel_cond = False
|
96 |
+
channel_cond_in_dim = None
|
97 |
+
elif use_channel_cond:
|
98 |
+
channel_cond_in_dim = 20
|
99 |
+
sequence_cond_in_dim = None
|
100 |
+
use_sequence_cond = False
|
101 |
+
|
102 |
+
dit_config = {
|
103 |
+
"use_local_lora": use_local_lora,
|
104 |
+
"use_dera": use_dera,
|
105 |
+
"dera_rank": dera_rank,
|
106 |
+
"use_dera_spatial": use_dera_spatial,
|
107 |
+
"use_dera_temporal": use_dera_temporal,
|
108 |
+
"use_sequence_cond": use_sequence_cond,
|
109 |
+
"sequence_cond_mode": sequence_cond_mode,
|
110 |
+
"sequence_cond_in_dim": sequence_cond_in_dim,
|
111 |
+
"use_channel_cond": use_channel_cond,
|
112 |
+
"channel_cond_in_dim": channel_cond_in_dim,
|
113 |
+
"use_sequence_cond_position_aware_residual": use_sequence_cond_position_aware_residual,
|
114 |
+
"use_sequence_cond_loss": use_sequence_cond_loss
|
115 |
+
}
|
116 |
+
if fast_dev:
|
117 |
+
del paths["dit"]
|
118 |
+
dit_config.update({
|
119 |
+
"model_type": "i2v",
|
120 |
+
"patch_size": (1, 2, 2),
|
121 |
+
"text_len": 512,
|
122 |
+
"in_dim": 36,
|
123 |
+
"dim": 512,
|
124 |
+
"ffn_dim": 512,
|
125 |
+
"freq_dim": 256,
|
126 |
+
"text_dim": 4096,
|
127 |
+
"out_dim": 16,
|
128 |
+
"num_heads": 2, # 40
|
129 |
+
"num_layers": 40,
|
130 |
+
"window_size": (-1, -1),
|
131 |
+
"qk_norm": True,
|
132 |
+
"cross_attn_norm": True,
|
133 |
+
"eps": 1e-6,
|
134 |
+
})
|
135 |
+
self.pipe.initialize_dummy_dit(dit_config)
|
136 |
+
|
137 |
+
self.pipe.fetch_models_from_checkpoints(
|
138 |
+
paths,
|
139 |
+
config_dict={
|
140 |
+
"dit": dit_config
|
141 |
+
})
|
142 |
+
|
143 |
+
if use_sequence_cond:
|
144 |
+
self.pipe.denoising_model().copy_sequence_cond_patch_embedding_weights()
|
145 |
+
elif use_channel_cond:
|
146 |
+
self.pipe.denoising_model().copy_patch_embedding_weights_for_channel_cond()
|
147 |
+
|
148 |
+
self.freeze_parameters()
|
149 |
+
if train_architecture == "lora":
|
150 |
+
self.add_lora_to_model(
|
151 |
+
self.pipe.denoising_model(),
|
152 |
+
lora_rank=lora_rank,
|
153 |
+
lora_alpha=lora_alpha,
|
154 |
+
lora_target_modules=lora_target_modules,
|
155 |
+
init_lora_weights=init_lora_weights
|
156 |
+
)
|
157 |
+
elif train_architecture == "full":
|
158 |
+
self.pipe.denoising_model().requires_grad_(True)
|
159 |
+
|
160 |
+
if checkpoint_path is not None:
|
161 |
+
self.load_tooncomposer_checkpoint(checkpoint_path)
|
162 |
+
|
163 |
+
self.learning_rate = learning_rate
|
164 |
+
self.use_gradient_checkpointing = use_gradient_checkpointing
|
165 |
+
|
166 |
+
self.pipe.scheduler.set_timesteps(1000, training=True)
|
167 |
+
self.vae_tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
168 |
+
self.video_condition_preservation_mode = video_condition_preservation_mode
|
169 |
+
self.negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
170 |
+
|
171 |
+
if output_path is None:
|
172 |
+
output_path = "./"
|
173 |
+
self.output_path = output_path
|
174 |
+
|
175 |
+
def load_tooncomposer_checkpoint(self, checkpoint_path):
|
176 |
+
if checkpoint_path == self.current_checkpoint_path:
|
177 |
+
print(f"Skipping loading checkpoint {checkpoint_path} because it is the same as the current checkpoint.")
|
178 |
+
return
|
179 |
+
self.current_checkpoint_path = checkpoint_path
|
180 |
+
self.load_patch_to_model(
|
181 |
+
self.pipe.denoising_model(),
|
182 |
+
checkpoint_path
|
183 |
+
)
|
184 |
+
|
185 |
+
def update_height_width(self, height, width):
|
186 |
+
self.height = height
|
187 |
+
self.width = width
|
188 |
+
|
189 |
+
def freeze_parameters(self):
|
190 |
+
self.pipe.requires_grad_(False)
|
191 |
+
self.pipe.eval()
|
192 |
+
self.pipe.denoising_model().train()
|
193 |
+
|
194 |
+
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
|
195 |
+
self.lora_alpha = lora_alpha
|
196 |
+
if init_lora_weights == "kaiming":
|
197 |
+
init_lora_weights = True
|
198 |
+
|
199 |
+
lora_config = LoraConfig(
|
200 |
+
r=lora_rank,
|
201 |
+
lora_alpha=lora_alpha,
|
202 |
+
init_lora_weights=init_lora_weights,
|
203 |
+
target_modules=lora_target_modules.split(","),
|
204 |
+
)
|
205 |
+
model = inject_adapter_in_model(lora_config, model)
|
206 |
+
for param in model.parameters():
|
207 |
+
if param.requires_grad:
|
208 |
+
param.data = param.to(torch.float32)
|
209 |
+
|
210 |
+
def load_patch_to_model(self, model, pretrained_path, state_dict_converter=None):
|
211 |
+
if pretrained_path is not None:
|
212 |
+
state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
|
213 |
+
self.loaded_global_step = 0
|
214 |
+
self.loaded_current_epoch = 0
|
215 |
+
if self.use_sketch_mask:
|
216 |
+
seq_cond_embed_weight = state_dict['sequence_cond_patch_embedding.weight']
|
217 |
+
current_in_channels = self.pipe.denoising_model().sequence_cond_patch_embedding.in_channels
|
218 |
+
if current_in_channels == 24 and seq_cond_embed_weight.shape[1] == 20:
|
219 |
+
new_weight = torch.zeros(
|
220 |
+
seq_cond_embed_weight.shape[0],
|
221 |
+
4,
|
222 |
+
*seq_cond_embed_weight.shape[2:],
|
223 |
+
dtype=seq_cond_embed_weight.dtype
|
224 |
+
)
|
225 |
+
state_dict['sequence_cond_patch_embedding.weight'] = torch.cat([
|
226 |
+
seq_cond_embed_weight, new_weight], dim=1)
|
227 |
+
|
228 |
+
if state_dict_converter is not None:
|
229 |
+
state_dict = state_dict_converter(state_dict)
|
230 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
231 |
+
all_keys = [i for i, _ in model.named_parameters()]
|
232 |
+
num_updated_keys = len(all_keys) - len(missing_keys)
|
233 |
+
num_unexpected_keys = len(unexpected_keys)
|
234 |
+
print(f"[Checkpoint] {num_updated_keys} parameters are loaded from {pretrained_path}. {num_unexpected_keys} parameters are unexpected.")
|
util/model_util.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, os
|
2 |
+
from safetensors import safe_open
|
3 |
+
from contextlib import contextmanager
|
4 |
+
import hashlib
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from matplotlib.colors import LinearSegmentedColormap
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
@contextmanager
|
10 |
+
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
11 |
+
|
12 |
+
old_register_parameter = torch.nn.Module.register_parameter
|
13 |
+
if include_buffers:
|
14 |
+
old_register_buffer = torch.nn.Module.register_buffer
|
15 |
+
|
16 |
+
def register_empty_parameter(module, name, param):
|
17 |
+
old_register_parameter(module, name, param)
|
18 |
+
if param is not None:
|
19 |
+
param_cls = type(module._parameters[name])
|
20 |
+
kwargs = module._parameters[name].__dict__
|
21 |
+
kwargs["requires_grad"] = param.requires_grad
|
22 |
+
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
23 |
+
|
24 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
25 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
26 |
+
if buffer is not None:
|
27 |
+
module._buffers[name] = module._buffers[name].to(device)
|
28 |
+
|
29 |
+
def patch_tensor_constructor(fn):
|
30 |
+
def wrapper(*args, **kwargs):
|
31 |
+
kwargs["device"] = device
|
32 |
+
return fn(*args, **kwargs)
|
33 |
+
|
34 |
+
return wrapper
|
35 |
+
|
36 |
+
if include_buffers:
|
37 |
+
tensor_constructors_to_patch = {
|
38 |
+
torch_function_name: getattr(torch, torch_function_name)
|
39 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
40 |
+
}
|
41 |
+
else:
|
42 |
+
tensor_constructors_to_patch = {}
|
43 |
+
|
44 |
+
try:
|
45 |
+
torch.nn.Module.register_parameter = register_empty_parameter
|
46 |
+
if include_buffers:
|
47 |
+
torch.nn.Module.register_buffer = register_empty_buffer
|
48 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
49 |
+
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
50 |
+
yield
|
51 |
+
finally:
|
52 |
+
torch.nn.Module.register_parameter = old_register_parameter
|
53 |
+
if include_buffers:
|
54 |
+
torch.nn.Module.register_buffer = old_register_buffer
|
55 |
+
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
56 |
+
setattr(torch, torch_function_name, old_torch_function)
|
57 |
+
|
58 |
+
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
59 |
+
state_dict = {}
|
60 |
+
for file_name in os.listdir(file_path):
|
61 |
+
if "." in file_name and file_name.split(".")[-1] in [
|
62 |
+
"safetensors", "bin", "ckpt", "pth", "pt"
|
63 |
+
]:
|
64 |
+
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
65 |
+
return state_dict
|
66 |
+
|
67 |
+
|
68 |
+
def load_state_dict(file_path, torch_dtype=None):
|
69 |
+
if file_path.endswith(".safetensors"):
|
70 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
71 |
+
else:
|
72 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
73 |
+
|
74 |
+
|
75 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
76 |
+
state_dict = {}
|
77 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
78 |
+
for k in f.keys():
|
79 |
+
state_dict[k] = f.get_tensor(k)
|
80 |
+
if torch_dtype is not None:
|
81 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
82 |
+
return state_dict
|
83 |
+
|
84 |
+
|
85 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
86 |
+
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
87 |
+
if torch_dtype is not None:
|
88 |
+
for i in state_dict:
|
89 |
+
if isinstance(state_dict[i], torch.Tensor):
|
90 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
91 |
+
return state_dict
|
92 |
+
|
93 |
+
|
94 |
+
def search_for_embeddings(state_dict):
|
95 |
+
embeddings = []
|
96 |
+
for k in state_dict:
|
97 |
+
if isinstance(state_dict[k], torch.Tensor):
|
98 |
+
embeddings.append(state_dict[k])
|
99 |
+
elif isinstance(state_dict[k], dict):
|
100 |
+
embeddings += search_for_embeddings(state_dict[k])
|
101 |
+
return embeddings
|
102 |
+
|
103 |
+
|
104 |
+
def search_parameter(param, state_dict):
|
105 |
+
for name, param_ in state_dict.items():
|
106 |
+
if param.numel() == param_.numel():
|
107 |
+
if param.shape == param_.shape:
|
108 |
+
if torch.dist(param, param_) < 1e-3:
|
109 |
+
return name
|
110 |
+
else:
|
111 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
112 |
+
return name
|
113 |
+
return None
|
114 |
+
|
115 |
+
|
116 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
117 |
+
matched_keys = set()
|
118 |
+
with torch.no_grad():
|
119 |
+
for name in source_state_dict:
|
120 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
121 |
+
if rename is not None:
|
122 |
+
print(f'"{name}": "{rename}",')
|
123 |
+
matched_keys.add(rename)
|
124 |
+
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
125 |
+
length = source_state_dict[name].shape[0] // 3
|
126 |
+
rename = []
|
127 |
+
for i in range(3):
|
128 |
+
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
129 |
+
if None not in rename:
|
130 |
+
print(f'"{name}": {rename},')
|
131 |
+
for rename_ in rename:
|
132 |
+
matched_keys.add(rename_)
|
133 |
+
for name in target_state_dict:
|
134 |
+
if name not in matched_keys:
|
135 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
136 |
+
|
137 |
+
|
138 |
+
def search_for_files(folder, extensions):
|
139 |
+
files = []
|
140 |
+
if os.path.isdir(folder):
|
141 |
+
for file in sorted(os.listdir(folder)):
|
142 |
+
files += search_for_files(os.path.join(folder, file), extensions)
|
143 |
+
elif os.path.isfile(folder):
|
144 |
+
for extension in extensions:
|
145 |
+
if folder.endswith(extension):
|
146 |
+
files.append(folder)
|
147 |
+
break
|
148 |
+
return files
|
149 |
+
|
150 |
+
|
151 |
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
152 |
+
keys = []
|
153 |
+
for key, value in state_dict.items():
|
154 |
+
if isinstance(key, str):
|
155 |
+
if isinstance(value, torch.Tensor):
|
156 |
+
if with_shape:
|
157 |
+
shape = "_".join(map(str, list(value.shape)))
|
158 |
+
keys.append(key + ":" + shape)
|
159 |
+
keys.append(key)
|
160 |
+
elif isinstance(value, dict):
|
161 |
+
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
162 |
+
keys.sort()
|
163 |
+
keys_str = ",".join(keys)
|
164 |
+
return keys_str
|
165 |
+
|
166 |
+
|
167 |
+
def split_state_dict_with_prefix(state_dict):
|
168 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
169 |
+
prefix_dict = {}
|
170 |
+
for key in keys:
|
171 |
+
prefix = key if "." not in key else key.split(".")[0]
|
172 |
+
if prefix not in prefix_dict:
|
173 |
+
prefix_dict[prefix] = []
|
174 |
+
prefix_dict[prefix].append(key)
|
175 |
+
state_dicts = []
|
176 |
+
for prefix, keys in prefix_dict.items():
|
177 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
178 |
+
state_dicts.append(sub_state_dict)
|
179 |
+
return state_dicts
|
180 |
+
|
181 |
+
|
182 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
183 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
184 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
185 |
+
return hashlib.md5(keys_str).hexdigest()
|
186 |
+
|
187 |
+
|
188 |
+
def save_attention_maps(model, output_path, batch_idx, timestep, layer_indices=None):
|
189 |
+
"""
|
190 |
+
Visualize and save the attention maps from selected layers of the model
|
191 |
+
|
192 |
+
Args:
|
193 |
+
model: The DiT model with attention maps stored
|
194 |
+
output_path: Directory to save visualizations
|
195 |
+
batch_idx: Current batch index for file naming
|
196 |
+
layer_indices: List of layer indices to visualize (if None, visualize all)
|
197 |
+
"""
|
198 |
+
timestep = int(float(str(timestep)))
|
199 |
+
os.makedirs(os.path.join(output_path, "attention_maps"), exist_ok=True)
|
200 |
+
|
201 |
+
# If layer indices not specified, visualize all layers
|
202 |
+
if layer_indices is None:
|
203 |
+
layer_indices = range(len(model.blocks))
|
204 |
+
|
205 |
+
# Create a custom colormap (similar to the ones used in attention visualization papers)
|
206 |
+
colors = [(0, 0, 0.5), (0, 0, 1), (0, 0.5, 1), (0, 1, 1),
|
207 |
+
(0.5, 1, 0.5), (1, 1, 0), (1, 0.5, 0), (1, 0, 0), (0.5, 0, 0)]
|
208 |
+
attention_cmap = LinearSegmentedColormap.from_list('attention_cmap', colors)
|
209 |
+
|
210 |
+
for i in layer_indices:
|
211 |
+
if not hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
|
212 |
+
continue
|
213 |
+
|
214 |
+
attn_map = model.blocks[i].self_attn._last_attn_maps
|
215 |
+
grid_size = model.blocks[i].self_attn._last_grid_sizes
|
216 |
+
seq_len = model.blocks[i].self_attn._last_seq_lens
|
217 |
+
# attn_maps.shape=[s, s]
|
218 |
+
np.savez_compressed(os.path.join(output_path,
|
219 |
+
"attention_maps",
|
220 |
+
f"attn_maps_layer{i}_batch{batch_idx}_t{timestep}.npz"),
|
221 |
+
attn_map=attn_map, grid_size=grid_size, seq_len=seq_len)
|
222 |
+
|
223 |
+
print(f"Saving Layer {i}, Batch {batch_idx} attention maps")
|
224 |
+
attn_map -= attn_map.min()
|
225 |
+
attn_map /= attn_map.max()
|
226 |
+
plt.figure(figsize=(10, 8))
|
227 |
+
plt.imshow(attn_map ** 0.25, cmap=attention_cmap)
|
228 |
+
plt.colorbar(label='Attention Weight')
|
229 |
+
plt.title(f'Layer {i}, Batch {batch_idx} (Average)')
|
230 |
+
save_path = os.path.join(
|
231 |
+
output_path,
|
232 |
+
"attention_maps",
|
233 |
+
f"attn_map_layer{i}_average_batch{batch_idx}_t{timestep}.png"
|
234 |
+
)
|
235 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
236 |
+
plt.close()
|
237 |
+
|
238 |
+
# Clean up the stored attention maps to free memory
|
239 |
+
for i in layer_indices:
|
240 |
+
if hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
|
241 |
+
del model.blocks[i].self_attn._last_attn_maps
|
util/optical_flow.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
|
6 |
+
from typing import List, Tuple, Dict
|
7 |
+
import argparse
|
8 |
+
from pathlib import Path
|
9 |
+
from sklearn.cluster import KMeans
|
10 |
+
from tqdm import tqdm
|
11 |
+
import os
|
12 |
+
|
13 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '64'
|
14 |
+
|
15 |
+
class OpticalFlowAnalyzer:
|
16 |
+
def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
17 |
+
self.device = device
|
18 |
+
self.model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
|
19 |
+
self.model.eval()
|
20 |
+
|
21 |
+
def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
|
22 |
+
"""Preprocess a frame for RAFT model."""
|
23 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
24 |
+
frame = torch.from_numpy(frame).permute(2, 0, 1).float()
|
25 |
+
frame = frame.unsqueeze(0) / 255.0
|
26 |
+
return frame.to(self.device)
|
27 |
+
|
28 |
+
def compute_optical_flow(self, frame1: np.ndarray, frame2: np.ndarray) -> np.ndarray:
|
29 |
+
"""Compute optical flow between two consecutive frames."""
|
30 |
+
with torch.no_grad():
|
31 |
+
frame1_tensor = self.preprocess_frame(frame1)
|
32 |
+
frame2_tensor = self.preprocess_frame(frame2)
|
33 |
+
|
34 |
+
flow = self.model(frame1_tensor, frame2_tensor)[-1]
|
35 |
+
flow = flow[0].permute(1, 2, 0).cpu().numpy()
|
36 |
+
|
37 |
+
return flow
|
38 |
+
|
39 |
+
def analyze_motion_regions(self, flow: np.ndarray, num_clusters: int = 3) -> Tuple[np.ndarray, Dict]:
|
40 |
+
"""Cluster motion regions based on optical flow magnitude and direction."""
|
41 |
+
h, w = flow.shape[:2]
|
42 |
+
magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)
|
43 |
+
direction = np.arctan2(flow[..., 1], flow[..., 0])
|
44 |
+
|
45 |
+
# Create feature matrix for clustering
|
46 |
+
features = np.zeros((h * w, 3))
|
47 |
+
features[:, 0] = magnitude.ravel()
|
48 |
+
features[:, 1] = np.cos(direction).ravel()
|
49 |
+
features[:, 2] = np.sin(direction).ravel()
|
50 |
+
|
51 |
+
# Normalize features
|
52 |
+
features = (features - features.mean(axis=0)) / features.std(axis=0)
|
53 |
+
|
54 |
+
# Perform clustering
|
55 |
+
kmeans = KMeans(n_clusters=num_clusters, random_state=42,)
|
56 |
+
labels = kmeans.fit_predict(features)
|
57 |
+
labels = labels.reshape(h, w)
|
58 |
+
|
59 |
+
# Analyze clusters
|
60 |
+
cluster_stats = {}
|
61 |
+
for i in range(num_clusters):
|
62 |
+
cluster_mask = (labels == i)
|
63 |
+
cluster_magnitude = magnitude[cluster_mask]
|
64 |
+
cluster_stats[i] = {
|
65 |
+
'mean_magnitude': np.mean(cluster_magnitude),
|
66 |
+
'std_magnitude': np.std(cluster_magnitude),
|
67 |
+
'pixel_count': np.sum(cluster_mask),
|
68 |
+
'is_static': np.mean(cluster_magnitude) < 0.1 # Threshold for static regions
|
69 |
+
}
|
70 |
+
|
71 |
+
return labels, cluster_stats
|
72 |
+
|
73 |
+
def process_video(self, video_path: str, output_path: str = None) -> List[Tuple[np.ndarray, Dict]]:
|
74 |
+
"""Process a video and return motion analysis results for each frame pair."""
|
75 |
+
cap = cv2.VideoCapture(video_path)
|
76 |
+
if not cap.isOpened():
|
77 |
+
raise ValueError(f"Could not open video: {video_path}")
|
78 |
+
|
79 |
+
results = []
|
80 |
+
ret, prev_frame = cap.read()
|
81 |
+
if not ret:
|
82 |
+
raise ValueError("Could not read first frame")
|
83 |
+
|
84 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
85 |
+
pbar = tqdm(total=total_frames-1, desc="Processing video")
|
86 |
+
|
87 |
+
while True:
|
88 |
+
ret, curr_frame = cap.read()
|
89 |
+
if not ret:
|
90 |
+
break
|
91 |
+
|
92 |
+
flow = self.compute_optical_flow(prev_frame, curr_frame)
|
93 |
+
labels, stats = self.analyze_motion_regions(flow)
|
94 |
+
|
95 |
+
if output_path:
|
96 |
+
# Visualize results
|
97 |
+
vis_frame = curr_frame.copy()
|
98 |
+
for i, stat in stats.items():
|
99 |
+
if not stat['is_static']:
|
100 |
+
mask = (labels == i).astype(np.uint8) * 255
|
101 |
+
print("mask:",mask.shape)
|
102 |
+
print("vis_frame:",vis_frame.shape)
|
103 |
+
mask = np.expand_dims(mask, axis=-1).repeat(3, axis=-1)
|
104 |
+
print("mask:",mask.shape)
|
105 |
+
|
106 |
+
vis_frame[mask > 0] = cv2.addWeighted(vis_frame[mask > 0], 0.7, 255, 0.3, 0)
|
107 |
+
|
108 |
+
cv2.imwrite(f"{output_path}/frame_{len(results):04d}.jpg", vis_frame)
|
109 |
+
|
110 |
+
results.append((labels, stats))
|
111 |
+
prev_frame = curr_frame
|
112 |
+
pbar.update(1)
|
113 |
+
|
114 |
+
cap.release()
|
115 |
+
pbar.close()
|
116 |
+
return results
|
117 |
+
|
118 |
+
def main():
|
119 |
+
parser = argparse.ArgumentParser(description='Analyze motion regions in a video using RAFT optical flow')
|
120 |
+
parser.add_argument('--video', type=str, required=True, help='Path to input video')
|
121 |
+
parser.add_argument('--output', type=str, help='Path to output directory for visualization')
|
122 |
+
parser.add_argument('--clusters', type=int, default=3, help='Number of motion clusters')
|
123 |
+
args = parser.parse_args()
|
124 |
+
|
125 |
+
analyzer = OpticalFlowAnalyzer()
|
126 |
+
results = analyzer.process_video(args.video, args.output)
|
127 |
+
|
128 |
+
# Print summary statistics
|
129 |
+
print("\nMotion Analysis Summary:")
|
130 |
+
for i, (_, stats) in enumerate(results):
|
131 |
+
print(f"\nFrame {i+1}:")
|
132 |
+
for cluster_id, stat in stats.items():
|
133 |
+
motion_type = "Static" if stat['is_static'] else "Moving"
|
134 |
+
print(f" Cluster {cluster_id} ({motion_type}):")
|
135 |
+
print(f" Mean magnitude: {stat['mean_magnitude']:.4f}")
|
136 |
+
print(f" Pixel count: {stat['pixel_count']}")
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
main()
|
140 |
+
|
util/stylesheets.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
util/training_util.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
def create_random_mask(batch_size, num_frames, height, width, device, dtype, shape_type=None):
|
10 |
+
"""
|
11 |
+
Create random masks for sketch frames.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
batch_size: Batch size
|
15 |
+
num_frames: Number of frames to mask
|
16 |
+
height, width: Image dimensions
|
17 |
+
device: Device for tensor
|
18 |
+
dtype: Data type for tensor
|
19 |
+
mask_area_ratio: Ratio of area to mask (0-1)
|
20 |
+
shape_type: Type of shape for masking ('square', 'circle', 'random'). If None, one is randomly selected.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Mask tensor in [b, 1, num_frames, height, width] where 0 indicates areas to mask (inverse of previous implementation)
|
24 |
+
"""
|
25 |
+
# Initialize with ones (unmasked)
|
26 |
+
masks = torch.ones(batch_size, 1, num_frames, height, width, device=device, dtype=dtype)
|
27 |
+
|
28 |
+
for b in range(batch_size):
|
29 |
+
for f in range(num_frames):
|
30 |
+
# Randomly select shape type if not specified
|
31 |
+
if shape_type is None:
|
32 |
+
shape_type = random.choice(['square', 'circle', 'random'])
|
33 |
+
|
34 |
+
# Create numpy mask for easier shape drawing
|
35 |
+
mask = np.zeros((height, width), dtype=np.float32)
|
36 |
+
|
37 |
+
if shape_type == 'square':
|
38 |
+
# Random squares
|
39 |
+
num_squares = random.randint(1, 5)
|
40 |
+
for _ in range(num_squares):
|
41 |
+
# Random square size (proportional to image dimensions)
|
42 |
+
max_size = min(height, width)
|
43 |
+
size = random.randint(max_size // 4, max_size)
|
44 |
+
|
45 |
+
# Random position
|
46 |
+
x = random.randint(0, width - size)
|
47 |
+
y = random.randint(0, height - size)
|
48 |
+
|
49 |
+
# Draw square
|
50 |
+
mask[y:y+size, x:x+size] = 1.0
|
51 |
+
|
52 |
+
elif shape_type == 'circle':
|
53 |
+
# Random circles
|
54 |
+
num_circles = random.randint(1, 5)
|
55 |
+
for _ in range(num_circles):
|
56 |
+
# Random radius (proportional to image dimensions)
|
57 |
+
max_radius = min(height, width) // 2
|
58 |
+
radius = random.randint(max_radius // 4, max_radius)
|
59 |
+
|
60 |
+
# Random center
|
61 |
+
center_x = random.randint(radius, width - radius)
|
62 |
+
center_y = random.randint(radius, height - radius)
|
63 |
+
|
64 |
+
# Draw circle
|
65 |
+
cv2.circle(mask, (center_x, center_y), radius, 1.0, -1)
|
66 |
+
|
67 |
+
elif shape_type == 'random':
|
68 |
+
# Create connected random shape with cv2
|
69 |
+
num_points = random.randint(5, 16)
|
70 |
+
points = []
|
71 |
+
|
72 |
+
# Generate random points
|
73 |
+
for _ in range(num_points):
|
74 |
+
x = random.randint(0, width - 1)
|
75 |
+
y = random.randint(0, height - 1)
|
76 |
+
points.append([x, y])
|
77 |
+
|
78 |
+
# Convert to numpy array for cv2
|
79 |
+
points = np.array(points, dtype=np.int32)
|
80 |
+
|
81 |
+
# Draw filled polygon
|
82 |
+
cv2.fillPoly(mask, [points], 1.0)
|
83 |
+
|
84 |
+
# Convert numpy mask to tensor and subtract from ones (inverse the mask)
|
85 |
+
masks[b, 0, f] = 1.0 - torch.from_numpy(mask).to(device=device, dtype=dtype)
|
86 |
+
|
87 |
+
return masks
|
88 |
+
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def extract_img_to_sketch(_sketch_model, _img, model_name="random"):
|
92 |
+
"""
|
93 |
+
Return sketch: [-1, 1]
|
94 |
+
"""
|
95 |
+
orig_shape = (_img.shape[-2], _img.shape[-1])
|
96 |
+
with torch.amp.autocast(dtype=torch.float32, device_type="cuda"):
|
97 |
+
reshaped_img = torch.nn.functional.interpolate(_img, (2048, 2048))
|
98 |
+
sketch = _sketch_model(reshaped_img, model_name=model_name)
|
99 |
+
sketch = torch.nn.functional.interpolate(sketch, orig_shape)
|
100 |
+
if sketch.shape[1] == 1:
|
101 |
+
sketch = sketch.repeat(1, 3, 1, 1)
|
102 |
+
return sketch
|
103 |
+
|
104 |
+
|
105 |
+
def video_to_frame_and_sketch(
|
106 |
+
sketch_model,
|
107 |
+
original_video,
|
108 |
+
max_num_preserved_sketch_frames=2,
|
109 |
+
max_num_preserved_image_frames=1,
|
110 |
+
min_num_preserved_sketch_frames=2,
|
111 |
+
min_num_preserved_image_frames=1,
|
112 |
+
model_name=None,
|
113 |
+
detach_image_and_sketch=False,
|
114 |
+
equally_spaced_preserve_sketch=False,
|
115 |
+
apply_sketch_mask=False,
|
116 |
+
sketch_mask_ratio=0.2,
|
117 |
+
sketch_mask_shape=None,
|
118 |
+
no_first_sketch: Union[bool, float] = False,
|
119 |
+
video_clip_names=None,
|
120 |
+
is_flux_sketch_available=None,
|
121 |
+
is_evaluation=False,
|
122 |
+
):
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
sketch_model: torch.nn.Module, a sketch pool for extracting sketches from images
|
126 |
+
original_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
|
127 |
+
max_num_preserved_sketch_frames: int, maximum number of preserved sketch frames
|
128 |
+
max_num_preserved_image_frames: int, maximum number of preserved image frames
|
129 |
+
min_num_preserved_sketch_frames: int, minimum number of preserved sketch frames
|
130 |
+
min_num_preserved_image_frames: int, minimum number of preserved image frames
|
131 |
+
model_name: str, name of the sketch model. If None, randomly select from ["lineart", "lineart_anime", "anime2sketch"]. Default: None.
|
132 |
+
equally_spaced_preserve_sketch: bool, whether to preserve sketches at equally spaced intervals. Default: False.
|
133 |
+
apply_sketch_mask: bool, whether to apply random masking to sketch frames. Default: False.
|
134 |
+
sketch_mask_ratio: float, ratio of frames to mask (0-1). Default: 0.2.
|
135 |
+
sketch_mask_shape: str, shape type for masking ('square', 'circle', 'random'). If None, randomly selected. Default: None.
|
136 |
+
Returns:
|
137 |
+
conditional_image: torch.Tensor, shape=(batch_size, num_frames, num_channels, height, width)
|
138 |
+
preserving_image_mask: torch.Tensor, shape=(batch_size, num_frames, 1, height, width)
|
139 |
+
full_sketch_frames: torch.Tensor, shape=(batch_size, num_frames, num_channels, height, width)
|
140 |
+
sketch_local_mask: torch.Tensor, shape=(batch_size, 1, num_frames, height, width) or None if apply_sketch_mask=False
|
141 |
+
"""
|
142 |
+
video_shape = original_video.shape
|
143 |
+
video_dtype = original_video.dtype
|
144 |
+
video_device = original_video.device
|
145 |
+
|
146 |
+
if min_num_preserved_sketch_frames is None or min_num_preserved_sketch_frames < 2:
|
147 |
+
min_num_preserved_sketch_frames = 2 # Minimum num: 2 (the first and the last)
|
148 |
+
num_preserved_sketch_frames = random.randint(min_num_preserved_sketch_frames, max_num_preserved_sketch_frames)
|
149 |
+
num_preserved_sketch_frames = min(num_preserved_sketch_frames, video_shape[2])
|
150 |
+
|
151 |
+
# Always include first and last frames
|
152 |
+
if video_clip_names is not None and is_flux_sketch_available is not None:
|
153 |
+
if is_flux_sketch_available[0]:
|
154 |
+
num_preserved_sketch_frames = 2
|
155 |
+
|
156 |
+
if isinstance(no_first_sketch, float):
|
157 |
+
no_first_sketch = random.random() < no_first_sketch
|
158 |
+
|
159 |
+
if equally_spaced_preserve_sketch:
|
160 |
+
preserved_sketch_indices = torch.linspace(0, video_shape[2] - 1, num_preserved_sketch_frames).long().tolist()
|
161 |
+
if no_first_sketch:
|
162 |
+
preserved_sketch_indices = preserved_sketch_indices[1:]
|
163 |
+
else:
|
164 |
+
if no_first_sketch:
|
165 |
+
preserved_sketch_indices = [video_shape[2] - 1]
|
166 |
+
else:
|
167 |
+
preserved_sketch_indices = [0, video_shape[2] - 1]
|
168 |
+
# If we need more frames than just first and last
|
169 |
+
if num_preserved_sketch_frames > 2 and video_shape[2] > 4:
|
170 |
+
# Create set of all valid candidates (excluding first, last and their adjacent frames)
|
171 |
+
# Exclude indices adjacent to first and last
|
172 |
+
candidates = set(range(2, video_shape[2] - 2))
|
173 |
+
|
174 |
+
# Determine how many additional frames to select
|
175 |
+
additional_frames_needed = min(num_preserved_sketch_frames - 2, len(candidates))
|
176 |
+
|
177 |
+
# Keep selecting frames until we have enough or run out of candidates
|
178 |
+
additional_indices = []
|
179 |
+
while len(additional_indices) < additional_frames_needed and candidates:
|
180 |
+
# Convert set to list for random selection
|
181 |
+
candidate_list = list(candidates)
|
182 |
+
# Select a random candidate
|
183 |
+
idx = random.choice(candidate_list)
|
184 |
+
additional_indices.append(idx)
|
185 |
+
|
186 |
+
# Remove selected index and adjacent indices from candidates
|
187 |
+
candidates.remove(idx)
|
188 |
+
if idx - 1 in candidates:
|
189 |
+
candidates.remove(idx - 1)
|
190 |
+
if idx + 1 in candidates:
|
191 |
+
candidates.remove(idx + 1)
|
192 |
+
|
193 |
+
preserved_sketch_indices.extend(additional_indices)
|
194 |
+
preserved_sketch_indices.sort()
|
195 |
+
|
196 |
+
# Indices to preserve has been determined.
|
197 |
+
# Later code will not care the number of preserved frames but rely on the indices only.
|
198 |
+
preserved_image_indices = [0]
|
199 |
+
if max_num_preserved_image_frames is not None and max_num_preserved_image_frames > 1:
|
200 |
+
max_num_preserved_image_frames -= 1
|
201 |
+
if min_num_preserved_image_frames is None or min_num_preserved_image_frames < 1:
|
202 |
+
min_num_preserved_image_frames = 1
|
203 |
+
min_num_preserved_image_frames -= 1
|
204 |
+
other_indices = torch.tensor([i for i in range(video_shape[2]) if i not in preserved_sketch_indices])
|
205 |
+
max_num_preserved_image_frames = min(max_num_preserved_image_frames, len(other_indices))
|
206 |
+
min_num_preserved_image_frames = min(min_num_preserved_image_frames, max_num_preserved_image_frames)
|
207 |
+
num_preserved_image_frames = random.randint(min_num_preserved_image_frames, max_num_preserved_image_frames)
|
208 |
+
other_indices = other_indices[torch.randperm(len(other_indices))]
|
209 |
+
if num_preserved_image_frames > 0:
|
210 |
+
preserved_image_indices.extend(other_indices[:num_preserved_image_frames])
|
211 |
+
|
212 |
+
preserved_condition_mask = torch.zeros(size=(video_shape[0], video_shape[2]), dtype=video_dtype, device=video_device) # [b, t]
|
213 |
+
masked_condition_video = torch.zeros_like(original_video) # [b, c, t, h, w]
|
214 |
+
full_sketch_frames = torch.zeros_like(original_video) # [b, c, t, h, w]
|
215 |
+
|
216 |
+
if detach_image_and_sketch:
|
217 |
+
preserved_condition_mask_sketch = torch.zeros_like(preserved_condition_mask)
|
218 |
+
masked_condition_video_sketch = torch.zeros_like(masked_condition_video)
|
219 |
+
if 0 not in preserved_sketch_indices and not no_first_sketch:
|
220 |
+
preserved_sketch_indices.append(0)
|
221 |
+
else:
|
222 |
+
preserved_condition_mask_sketch = None
|
223 |
+
masked_condition_video_sketch = None
|
224 |
+
|
225 |
+
for _idx in preserved_image_indices:
|
226 |
+
preserved_condition_mask[:, _idx] = 1.0
|
227 |
+
masked_condition_video[:, :, _idx, :, :] = original_video[:, :, _idx, :, :]
|
228 |
+
|
229 |
+
# Set up sketch_local_mask if masking is applied
|
230 |
+
sketch_local_mask = None
|
231 |
+
|
232 |
+
if apply_sketch_mask:
|
233 |
+
# Create a full-sized mask initialized to all ones (unmasked)
|
234 |
+
sketch_local_mask = torch.ones(
|
235 |
+
video_shape[0], video_shape[2], video_shape[3], video_shape[4],
|
236 |
+
device=video_device,
|
237 |
+
dtype=video_dtype
|
238 |
+
).unsqueeze(1) # Add channel dimension to get [b, 1, t, h, w]
|
239 |
+
|
240 |
+
if not is_evaluation and random.random() < sketch_mask_ratio:
|
241 |
+
# For preserved frames, apply random masking
|
242 |
+
for i, frame_idx in enumerate(preserved_sketch_indices):
|
243 |
+
if i == 0:
|
244 |
+
# First frame is not masked
|
245 |
+
continue
|
246 |
+
# Create masks only for preserved frames
|
247 |
+
frame_masks = create_random_mask(
|
248 |
+
batch_size=video_shape[0],
|
249 |
+
num_frames=1, # Just one frame at a time
|
250 |
+
height=video_shape[3],
|
251 |
+
width=video_shape[4],
|
252 |
+
device=video_device,
|
253 |
+
dtype=video_dtype,
|
254 |
+
# mask_area_ratio=0.4 * random.random() + 0.1,
|
255 |
+
shape_type=sketch_mask_shape
|
256 |
+
)
|
257 |
+
|
258 |
+
# Set the mask for this preserved frame
|
259 |
+
sketch_local_mask[:, :, frame_idx:frame_idx+1, :, :] = frame_masks
|
260 |
+
|
261 |
+
# Produce sketches for preserved frames
|
262 |
+
# Sketches can either be 1) calculated from sketch pool or 2) loaded from the flux sketch directory
|
263 |
+
if is_flux_sketch_available is not None and is_flux_sketch_available[0]:
|
264 |
+
should_use_flux_sketch = random.random() < 0.75 if not is_evaluation else True
|
265 |
+
else:
|
266 |
+
should_use_flux_sketch = False
|
267 |
+
|
268 |
+
cur_model_name = "flux" if should_use_flux_sketch else random.choice(["lineart", "lineart_anime", "anime2sketch"]) if model_name is None else model_name # "anime2sketch"
|
269 |
+
# cur_model_name = "anyline"
|
270 |
+
for _idx in preserved_sketch_indices:
|
271 |
+
sketch_frame = None
|
272 |
+
if should_use_flux_sketch:
|
273 |
+
# Load flux sketch
|
274 |
+
sketech_path = f"/group/40005/gzhiwang/iclora/linearts/{video_clip_names[0]}/{_idx}.lineart.png"
|
275 |
+
print(f"Loading flux sketch from {sketech_path}...")
|
276 |
+
if os.path.exists(sketech_path):
|
277 |
+
sketch_frame = cv2.imread(sketech_path)
|
278 |
+
sketch_frame = cv2.cvtColor(sketch_frame, cv2.COLOR_BGR2RGB)
|
279 |
+
# resize to 480p
|
280 |
+
sketch_frame = cv2.resize(sketch_frame, (video_shape[4], video_shape[3]))
|
281 |
+
sketch_frame = torch.from_numpy(sketch_frame).to(video_device, dtype=video_dtype)
|
282 |
+
# Normalize to [-1, 1]
|
283 |
+
sketch_frame = sketch_frame / 255.0 * 2.0 - 1.0
|
284 |
+
sketch_frame = sketch_frame.permute(2, 0, 1)
|
285 |
+
sketch_frame = sketch_frame.unsqueeze(0)
|
286 |
+
else:
|
287 |
+
print(f"FLUX Sketch path {sketech_path} does not exist. Falling back to sketch pool.")
|
288 |
+
# raise ValueError(f"FLUX Sketch path {sketech_path} does not exist.")
|
289 |
+
if sketch_frame is None:
|
290 |
+
# Calculate sketch from sketch pool
|
291 |
+
sketch_frame = extract_img_to_sketch(
|
292 |
+
sketch_model, original_video[:, :, _idx, :, :].float(),
|
293 |
+
model_name=cur_model_name).to(video_device, dtype=video_dtype)
|
294 |
+
# Convert white BG (from sketch pool or loaded from flux sketch files) to black BG (for training)
|
295 |
+
sketch_frame = -torch.clip(sketch_frame, -1, 1)
|
296 |
+
full_sketch_frames[:, :, _idx, :, :] = sketch_frame
|
297 |
+
|
298 |
+
if len(preserved_sketch_indices) > 0:
|
299 |
+
_mask_to_add = preserved_condition_mask_sketch if detach_image_and_sketch else preserved_condition_mask
|
300 |
+
_video_to_add = masked_condition_video_sketch if detach_image_and_sketch else masked_condition_video
|
301 |
+
if not detach_image_and_sketch:
|
302 |
+
preserved_sketch_indices = preserved_sketch_indices[1:]
|
303 |
+
|
304 |
+
# Apply masking to sketch frames if required
|
305 |
+
if apply_sketch_mask and sketch_local_mask is not None:
|
306 |
+
# sketch_local_mask: [b, 1, t, h, w]
|
307 |
+
for _idx in preserved_sketch_indices:
|
308 |
+
_mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
|
309 |
+
_video_to_add[:, :, _idx, :, :] = torch.where(sketch_local_mask[:, 0:1, _idx, :, :] == 0, -1.0, full_sketch_frames[:, :, _idx, :, :])
|
310 |
+
else:
|
311 |
+
for _idx in preserved_sketch_indices:
|
312 |
+
_mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
|
313 |
+
_video_to_add[:, :, _idx, :, :] = full_sketch_frames[:, :, _idx, :, :]
|
314 |
+
|
315 |
+
return masked_condition_video, preserved_condition_mask, masked_condition_video_sketch, preserved_condition_mask_sketch, full_sketch_frames, sketch_local_mask, cur_model_name
|
316 |
+
|
317 |
+
|