l-li committed on
Commit 00274d1 · 1 Parent(s): dcac773

update requirements.

.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samples/1_out.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ samples/2_out.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ samples/3_out.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ samples/1_image1.png filter=lfs diff=lfs merge=lfs -text
40
+ samples/3_image1.png filter=lfs diff=lfs merge=lfs -text
41
+ samples/ToonComposer-Icon.png filter=lfs diff=lfs merge=lfs -text
42
+ samples/1_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
43
+ samples/1_sketch3.jpg filter=lfs diff=lfs merge=lfs -text
44
+ samples/2_image1.jpg filter=lfs diff=lfs merge=lfs -text
45
+ samples/1_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
46
+ samples/2_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
47
+ samples/2_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
48
+ samples/3_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
49
+ samples/ToonComposer-Method.jpg filter=lfs diff=lfs merge=lfs -text
50
+ samples/ToonComposer-TLDR.jpg filter=lfs diff=lfs merge=lfs -text
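
The entries added above route the sample media included in this commit (preview MP4s, keyframe PNGs/JPGs, and sketch JPGs) through Git LFS instead of regular Git storage. As a rough illustration (not part of the commit), the sketch below checks whether a path is covered by these rules; the patterns in this hunk are literal paths, and fnmatch would also accept glob-style rules such as *.mp4. The file samples/notes.txt is hypothetical.

```python
# Illustrative sketch: which paths are covered by the LFS rules added above?
from fnmatch import fnmatch

lfs_patterns = [
    "samples/1_out.mp4",
    "samples/1_image1.png",
    "samples/ToonComposer-TLDR.jpg",
    # ...the remaining paths from the .gitattributes hunk above
]

def is_lfs_tracked(path: str) -> bool:
    # The patterns above include the samples/ prefix, so match the full
    # repository-relative path rather than just the basename.
    return any(fnmatch(path, pattern) for pattern in lfs_patterns)

print(is_lfs_tracked("samples/1_out.mp4"))  # True
print(is_lfs_tracked("samples/notes.txt"))  # False (hypothetical file)
```
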
LICENSE ADDED
@@ -0,0 +1,233 @@
1
+ Tencent is pleased to support the open source community by making ToonComposer available.
2
+
3
+ Copyright (C) 2025 Tencent. All rights reserved.
4
+
5
+ ToonComposer is licensed under the MIT License except for the third-party components listed below, which is licensed under different terms. ToonComposer does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
6
+
7
+ For avoidance of doubts, ToonComposer refers to the inference code, parameters and weights made publicly available by Tencent in accordance with the MIT License in this repository.
8
+
9
+ Terms of the MIT License:
10
+ --------------------------------------------------------------------
11
+ Copyright (C) 2025 Tencent. All rights reserved.
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the " Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice (including the next paragraph) shall be included in all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18
+
19
+
20
+ The ToonComposer model was developed by Tencent based on the following Open Models.
21
+ The ToonComposer inference code was developed by Tencent based on the code of the following Open Models.The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
22
+
23
+ Open Models Licensed under the Apache-2.0 License:
24
+
25
+ --------------------------------------------------------------------
26
+ 1.Wan2.1
27
+ Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
28
+ The code of this model was modified by Tencent.
29
+
30
+ --------------------------------------------------------------------
31
+ Terms of the Apache-2.0 License:
32
+ --------------------------------------------------------------------
33
+ Apache License
34
+ Version 2.0, January 2004
35
+ http://www.apache.org/licenses/
36
+
37
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
38
+
39
+ 1. Definitions.
40
+
41
+ "License" shall mean the terms and conditions for use, reproduction,
42
+ and distribution as defined by Sections 1 through 9 of this document.
43
+
44
+ "Licensor" shall mean the copyright owner or entity authorized by
45
+ the copyright owner that is granting the License.
46
+
47
+ "Legal Entity" shall mean the union of the acting entity and all
48
+ other entities that control, are controlled by, or are under common
49
+ control with that entity. For the purposes of this definition,
50
+ "control" means (i) the power, direct or indirect, to cause the
51
+ direction or management of such entity, whether by contract or
52
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
53
+ outstanding shares, or (iii) beneficial ownership of such entity.
54
+
55
+ "You" (or "Your") shall mean an individual or Legal Entity
56
+ exercising permissions granted by this License.
57
+
58
+ "Source" form shall mean the preferred form for making modifications,
59
+ including but not limited to software source code, documentation
60
+ source, and configuration files.
61
+
62
+ "Object" form shall mean any form resulting from mechanical
63
+ transformation or translation of a Source form, including but
64
+ not limited to compiled object code, generated documentation,
65
+ and conversions to other media types.
66
+
67
+ "Work" shall mean the work of authorship, whether in Source or
68
+ Object form, made available under the License, as indicated by a
69
+ copyright notice that is included in or attached to the work
70
+ (an example is provided in the Appendix below).
71
+
72
+ "Derivative Works" shall mean any work, whether in Source or Object
73
+ form, that is based on (or derived from) the Work and for which the
74
+ editorial revisions, annotations, elaborations, or other modifications
75
+ represent, as a whole, an original work of authorship. For the purposes
76
+ of this License, Derivative Works shall not include works that remain
77
+ separable from, or merely link (or bind by name) to the interfaces of,
78
+ the Work and Derivative Works thereof.
79
+
80
+ "Contribution" shall mean any work of authorship, including
81
+ the original version of the Work and any modifications or additions
82
+ to that Work or Derivative Works thereof, that is intentionally
83
+ submitted to Licensor for inclusion in the Work by the copyright owner
84
+ or by an individual or Legal Entity authorized to submit on behalf of
85
+ the copyright owner. For the purposes of this definition, "submitted"
86
+ means any form of electronic, verbal, or written communication sent
87
+ to the Licensor or its representatives, including but not limited to
88
+ communication on electronic mailing lists, source code control systems,
89
+ and issue tracking systems that are managed by, or on behalf of, the
90
+ Licensor for the purpose of discussing and improving the Work, but
91
+ excluding communication that is conspicuously marked or otherwise
92
+ designated in writing by the copyright owner as "Not a Contribution."
93
+
94
+ "Contributor" shall mean Licensor and any individual or Legal Entity
95
+ on behalf of whom a Contribution has been received by Licensor and
96
+ subsequently incorporated within the Work.
97
+
98
+ 2. Grant of Copyright License. Subject to the terms and conditions of
99
+ this License, each Contributor hereby grants to You a perpetual,
100
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
101
+ copyright license to reproduce, prepare Derivative Works of,
102
+ publicly display, publicly perform, sublicense, and distribute the
103
+ Work and such Derivative Works in Source or Object form.
104
+
105
+ 3. Grant of Patent License. Subject to the terms and conditions of
106
+ this License, each Contributor hereby grants to You a perpetual,
107
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
108
+ (except as stated in this section) patent license to make, have made,
109
+ use, offer to sell, sell, import, and otherwise transfer the Work,
110
+ where such license applies only to those patent claims licensable
111
+ by such Contributor that are necessarily infringed by their
112
+ Contribution(s) alone or by combination of their Contribution(s)
113
+ with the Work to which such Contribution(s) was submitted. If You
114
+ institute patent litigation against any entity (including a
115
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
116
+ or a Contribution incorporated within the Work constitutes direct
117
+ or contributory patent infringement, then any patent licenses
118
+ granted to You under this License for that Work shall terminate
119
+ as of the date such litigation is filed.
120
+
121
+ 4. Redistribution. You may reproduce and distribute copies of the
122
+ Work or Derivative Works thereof in any medium, with or without
123
+ modifications, and in Source or Object form, provided that You
124
+ meet the following conditions:
125
+
126
+ (a) You must give any other recipients of the Work or
127
+ Derivative Works a copy of this License; and
128
+
129
+ (b) You must cause any modified files to carry prominent notices
130
+ stating that You changed the files; and
131
+
132
+ (c) You must retain, in the Source form of any Derivative Works
133
+ that You distribute, all copyright, patent, trademark, and
134
+ attribution notices from the Source form of the Work,
135
+ excluding those notices that do not pertain to any part of
136
+ the Derivative Works; and
137
+
138
+ (d) If the Work includes a "NOTICE" text file as part of its
139
+ distribution, then any Derivative Works that You distribute must
140
+ include a readable copy of the attribution notices contained
141
+ within such NOTICE file, excluding those notices that do not
142
+ pertain to any part of the Derivative Works, in at least one
143
+ of the following places: within a NOTICE text file distributed
144
+ as part of the Derivative Works; within the Source form or
145
+ documentation, if provided along with the Derivative Works; or,
146
+ within a display generated by the Derivative Works, if and
147
+ wherever such third-party notices normally appear. The contents
148
+ of the NOTICE file are for informational purposes only and
149
+ do not modify the License. You may add Your own attribution
150
+ notices within Derivative Works that You distribute, alongside
151
+ or as an addendum to the NOTICE text from the Work, provided
152
+ that such additional attribution notices cannot be construed
153
+ as modifying the License.
154
+
155
+ You may add Your own copyright statement to Your modifications and
156
+ may provide additional or different license terms and conditions
157
+ for use, reproduction, or distribution of Your modifications, or
158
+ for any such Derivative Works as a whole, provided Your use,
159
+ reproduction, and distribution of the Work otherwise complies with
160
+ the conditions stated in this License.
161
+
162
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
163
+ any Contribution intentionally submitted for inclusion in the Work
164
+ by You to the Licensor shall be under the terms and conditions of
165
+ this License, without any additional terms or conditions.
166
+ Notwithstanding the above, nothing herein shall supersede or modify
167
+ the terms of any separate license agreement you may have executed
168
+ with Licensor regarding such Contributions.
169
+
170
+ 6. Trademarks. This License does not grant permission to use the trade
171
+ names, trademarks, service marks, or product names of the Licensor,
172
+ except as required for reasonable and customary use in describing the
173
+ origin of the Work and reproducing the content of the NOTICE file.
174
+
175
+ 7. Disclaimer of Warranty. Unless required by applicable law or
176
+ agreed to in writing, Licensor provides the Work (and each
177
+ Contributor provides its Contributions) on an "AS IS" BASIS,
178
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
179
+ implied, including, without limitation, any warranties or conditions
180
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
181
+ PARTICULAR PURPOSE. You are solely responsible for determining the
182
+ appropriateness of using or redistributing the Work and assume any
183
+ risks associated with Your exercise of permissions under this License.
184
+
185
+ 8. Limitation of Liability. In no event and under no legal theory,
186
+ whether in tort (including negligence), contract, or otherwise,
187
+ unless required by applicable law (such as deliberate and grossly
188
+ negligent acts) or agreed to in writing, shall any Contributor be
189
+ liable to You for damages, including any direct, indirect, special,
190
+ incidental, or consequential damages of any character arising as a
191
+ result of this License or out of the use or inability to use the
192
+ Work (including but not limited to damages for loss of goodwill,
193
+ work stoppage, computer failure or malfunction, or any and all
194
+ other commercial damages or losses), even if such Contributor
195
+ has been advised of the possibility of such damages.
196
+
197
+ 9. Accepting Warranty or Additional Liability. While redistributing
198
+ the Work or Derivative Works thereof, You may choose to offer,
199
+ and charge a fee for, acceptance of support, warranty, indemnity,
200
+ or other liability obligations and/or rights consistent with this
201
+ License. However, in accepting such obligations, You may act only
202
+ on Your own behalf and on Your sole responsibility, not on behalf
203
+ of any other Contributor, and only if You agree to indemnify,
204
+ defend, and hold each Contributor harmless for any liability
205
+ incurred by, or claims asserted against, such Contributor by reason
206
+ of your accepting any such warranty or additional liability.
207
+
208
+ END OF TERMS AND CONDITIONS
209
+
210
+ APPENDIX: How to apply the Apache License to your work.
211
+
212
+ To apply the Apache License to your work, attach the following
213
+ boilerplate notice, with the fields enclosed by brackets "[]"
214
+ replaced with your own identifying information. (Don't include
215
+ the brackets!) The text should be enclosed in the appropriate
216
+ comment syntax for the file format. We also recommend that a
217
+ file or class name and description of purpose be included on the
218
+ same "printed page" as the copyright notice for easier
219
+ identification within third-party archives.
220
+
221
+ Copyright [yyyy] [name of copyright owner]
222
+
223
+ Licensed under the Apache License, Version 2.0 (the "License");
224
+ you may not use this file except in compliance with the License.
225
+ You may obtain a copy of the License at
226
+
227
+ http://www.apache.org/licenses/LICENSE-2.0
228
+
229
+ Unless required by applicable law or agreed to in writing, software
230
+ distributed under the License is distributed on an "AS IS" BASIS,
231
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
232
+ See the License for the specific language governing permissions and
233
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: ToonComposer
3
- emoji: 🚀
4
  colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: ToonComposer
3
+ emoji: 🎨
4
  colorFrom: gray
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.25.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1066 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tooncomposer import ToonComposer, get_base_model_paths
5
+ import argparse
6
+ import json
7
+ from util.training_util import extract_img_to_sketch
8
+ import os
9
+ import tempfile
10
+ import cv2
11
+ import gradio as gr
12
+ from einops import rearrange
13
+ from datetime import datetime
14
+ from typing import Optional, List, Dict
15
+ from huggingface_hub import snapshot_download
16
+
17
+ os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Weights resolution and download helpers
21
+ # -----------------------------------------------------------------------------
22
+
23
+ WAN_REPO_ID = "Wan-AI/Wan2.1-I2V-14B-480P"
24
+ TOONCOMPOSER_REPO_ID = "TencentARC/ToonComposer"
25
+
26
+ def _path_is_dir_with_files(dir_path: str, required_files: List[str]) -> bool:
27
+ if not dir_path or not os.path.isdir(dir_path):
28
+ return False
29
+ for f in required_files:
30
+ if not os.path.exists(os.path.join(dir_path, f)):
31
+ return False
32
+ return True
33
+
34
+ def resolve_wan_model_root(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
35
+ """Return a directory containing Wan2.1-I2V-14B-480P weights.
36
+
37
+ Resolution order:
38
+ 1) preferred_dir arg (if valid)
39
+ 2) WAN21_I2V_DIR env var (if valid)
40
+ 3) HF local cache (no download) via snapshot_download(local_files_only=True)
41
+ 4) HF download to cache via snapshot_download()
42
+ """
43
+ # Required filenames relative to the model root
44
+ expected = get_base_model_paths("Wan2.1-I2V-14B-480P", format='dict', model_root=".")
45
+ required_files = []
46
+ required_files.extend([os.path.basename(p) for p in expected["dit"]])
47
+ required_files.append(os.path.basename(expected["image_encoder"]))
48
+ required_files.append(os.path.basename(expected["text_encoder"]))
49
+ required_files.append(os.path.basename(expected["vae"]))
50
+
51
+ # 1) preferred_dir arg
52
+ if _path_is_dir_with_files(preferred_dir or "", required_files):
53
+ return os.path.abspath(preferred_dir)
54
+
55
+ # 2) environment variable
56
+ env_dir = os.environ.get("WAN21_I2V_DIR")
57
+ if _path_is_dir_with_files(env_dir or "", required_files):
58
+ return os.path.abspath(env_dir)
59
+
60
+ # 3) try local cache without network
61
+ try:
62
+ cached_dir = snapshot_download(repo_id=WAN_REPO_ID, local_files_only=True)
63
+ return cached_dir
64
+ except Exception:
65
+ pass
66
+
67
+ # 4) download (may be large)
68
+ cached_dir = snapshot_download(repo_id=WAN_REPO_ID, token=hf_token)
69
+ return cached_dir
70
+
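# Note (illustrative only, not part of the committed file): minimal usage of the
# resolution order documented above, assuming a local copy of the Wan2.1 weights;
# the directory path is hypothetical.
#
#     import os
#     os.environ["WAN21_I2V_DIR"] = "/data/models/Wan2.1-I2V-14B-480P"  # hypothetical path
#     wan_root = resolve_wan_model_root()  # resolved via step 2 (env var)
#
#     # Without a local copy, steps 3-4 fall back to the Hugging Face cache and
#     # download Wan-AI/Wan2.1-I2V-14B-480P if it is not cached yet:
#     wan_root = resolve_wan_model_root(hf_token=os.environ.get("HF_TOKEN"))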
71
+ def resolve_tooncomposer_repo_dir(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
72
+ """Return a directory containing ToonComposer repo with 480p/608p subdirs."""
73
+ # Quick validity check: ensure either a subdir 480p or 608p exists with required files
74
+ def has_resolution_dirs(base_dir: str) -> bool:
75
+ if not base_dir or not os.path.isdir(base_dir):
76
+ return False
77
+ ok = False
78
+ for res in ["480p", "608p"]:
79
+ d = os.path.join(base_dir, res)
80
+ if os.path.isdir(d):
81
+ ckpt = os.path.join(d, "tooncomposer.ckpt")
82
+ cfg = os.path.join(d, "config.json")
83
+ if os.path.exists(ckpt) and os.path.exists(cfg):
84
+ ok = True
85
+ return ok
86
+
87
+ # 1) preferred_dir arg
88
+ if has_resolution_dirs(preferred_dir or ""):
89
+ return os.path.abspath(preferred_dir)
90
+
91
+ # 2) environment variable
92
+ env_dir = os.environ.get("TOONCOMPOSER_DIR")
93
+ if has_resolution_dirs(env_dir or ""):
94
+ return os.path.abspath(env_dir)
95
+
96
+ # 3) try local cache first
97
+ try:
98
+ cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, local_files_only=True)
99
+ return cached_dir
100
+ except Exception:
101
+ pass
102
+
103
+ # 4) download repo to cache
104
+ cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, token=hf_token)
105
+ return cached_dir
106
+
107
+ def build_checkpoints_by_resolution(tooncomposer_base_dir: str) -> Dict[str, Dict[str, object]]:
108
+ """Construct resolution mapping from a base repo dir that contains 480p/608p.
109
+
110
+ The ToonComposer HF repo stores, inside each resolution dir:
111
+ - tooncomposer.ckpt
112
+ - config.json (model configuration)
113
+ """
114
+ mapping = {}
115
+ # Known target sizes
116
+ res_to_hw = {
117
+ "480p": (480, 832),
118
+ "608p": (608, 1088),
119
+ }
120
+ for res, (h, w) in res_to_hw.items():
121
+ res_dir = os.path.join(tooncomposer_base_dir, res)
122
+ mapping[res] = {
123
+ "target_height": h,
124
+ "target_width": w,
125
+ "snapshot_args_path": os.path.join(res_dir, "config.json"),
126
+ "checkpoint_path": os.path.join(res_dir, "tooncomposer.ckpt"),
127
+ }
128
+ return mapping
129
+
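# Note (illustrative only, not part of the committed file): with the layout the
# docstring above describes, i.e.
#     <base_dir>/480p/tooncomposer.ckpt    <base_dir>/480p/config.json
#     <base_dir>/608p/tooncomposer.ckpt    <base_dir>/608p/config.json
# build_checkpoints_by_resolution(base_dir)["480p"] yields
#     {"target_height": 480, "target_width": 832,
#      "snapshot_args_path": "<base_dir>/480p/config.json",
#      "checkpoint_path": "<base_dir>/480p/tooncomposer.ckpt"}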
130
+ # Will be populated in main() after resolving ToonComposer repo directory
131
+ checkpoints_by_resolution = {}
132
+
133
+ def tensor2video(frames):
134
+ frames = rearrange(frames, "C T H W -> T H W C")
135
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
136
+ frames = [Image.fromarray(frame) for frame in frames]
137
+ return frames
138
+
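# Note (illustrative only, not part of the committed file): tensor2video maps
# frame values from the model's [-1, 1] range to uint8 via (x + 1) * 127.5,
# so -1.0 -> 0, 0.0 -> 127 (127.5 truncated), 1.0 -> 255, and rearranges the
# C T H W tensor into a list of T RGB PIL images.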
139
+ def _load_model_config(config_path: str) -> Dict[str, object]:
140
+ with open(config_path, "r") as f:
141
+ data = json.load(f)
142
+ return data
143
+
144
+ def _merge_with_defaults(cfg: Dict[str, object]) -> Dict[str, object]:
145
+ # Provide safe defaults for optional fields used at inference-time
146
+ defaults = {
147
+ "base_model_name": "Wan2.1-I2V-14B-480P",
148
+ "learning_rate": 1e-5,
149
+ "train_architecture": "lora",
150
+ "lora_rank": 4,
151
+ "lora_alpha": 4,
152
+ "lora_target_modules": "q,k,v,o,ffn.0,ffn.2",
153
+ "init_lora_weights": "kaiming",
154
+ "use_gradient_checkpointing": True,
155
+ "tiled": False,
156
+ "tile_size_height": 34,
157
+ "tile_size_width": 34,
158
+ "tile_stride_height": 18,
159
+ "tile_stride_width": 16,
160
+ "output_path": "./",
161
+ "use_local_lora": False,
162
+ "use_dera": False,
163
+ "dera_rank": None,
164
+ "use_dera_spatial": True,
165
+ "use_dera_temporal": True,
166
+ "use_sequence_cond": True,
167
+ "sequence_cond_mode": "sparse",
168
+ "use_channel_cond": False,
169
+ "use_sequence_cond_position_aware_residual": True,
170
+ "use_sequence_cond_loss": False,
171
+ "fast_dev": False,
172
+ "max_num_cond_images": 1,
173
+ "max_num_cond_sketches": 2,
174
+ "visualize_attention": False,
175
+ "random_spaced_cond_frames": False,
176
+ "use_sketch_mask": True,
177
+ "sketch_mask_ratio": 0.2,
178
+ "no_first_sketch": False,
179
+ }
180
+ merged = defaults.copy()
181
+ merged.update(cfg)
182
+ return merged
183
+
184
+ def initialize_model(resolution="480p", fast_dev=False, device="cuda:0", dtype=torch.bfloat16,
185
+ wan_model_dir: Optional[str] = None, tooncomposer_dir: Optional[str] = None,
186
+ hf_token: Optional[str] = None):
187
+ # Initialize model components
188
+ if resolution not in checkpoints_by_resolution:
189
+ raise ValueError(f"Resolution '{resolution}' is not available. Found: {list(checkpoints_by_resolution.keys())}")
190
+
191
+ # 1) resolve config and checkpoint from ToonComposer repo (local or HF)
192
+ snapshot_args_path = checkpoints_by_resolution[resolution]["snapshot_args_path"]
193
+ checkpoint_path = checkpoints_by_resolution[resolution]["checkpoint_path"]
194
+
195
+ # 2) load model config
196
+ snapshot_args_raw = _load_model_config(snapshot_args_path)
197
+ snapshot_args = _merge_with_defaults(snapshot_args_raw)
198
+ snapshot_args["checkpoint_path"] = checkpoint_path
199
+
200
+ # 3) resolve Wan2.1 model root
201
+ snapshot_args["model_root"] = resolve_wan_model_root(preferred_dir=wan_model_dir, hf_token=hf_token)
202
+
203
+ # Backward-compat fields
204
+ if "training_max_frame_stride" not in snapshot_args:
205
+ snapshot_args["training_max_frame_stride"] = 4
206
+ snapshot_args["random_spaced_cond_frames"] = False
207
+ args = argparse.Namespace(**snapshot_args)
208
+ if not fast_dev:
209
+ model = ToonComposer(
210
+ base_model_name=args.base_model_name,
211
+ model_root=args.model_root,
212
+ learning_rate=args.learning_rate,
213
+ train_architecture=args.train_architecture,
214
+ lora_rank=args.lora_rank,
215
+ lora_alpha=args.lora_alpha,
216
+ lora_target_modules=args.lora_target_modules,
217
+ init_lora_weights=args.init_lora_weights,
218
+ use_gradient_checkpointing=args.use_gradient_checkpointing,
219
+ checkpoint_path=args.checkpoint_path,
220
+ tiled=args.tiled,
221
+ tile_size=(args.tile_size_height, args.tile_size_width),
222
+ tile_stride=(args.tile_stride_height, args.tile_stride_width),
223
+ output_path=args.output_path,
224
+ use_local_lora=args.use_local_lora,
225
+ use_dera=args.use_dera,
226
+ dera_rank=args.dera_rank,
227
+ use_dera_spatial=args.use_dera_spatial,
228
+ use_dera_temporal=args.use_dera_temporal,
229
+ use_sequence_cond=args.use_sequence_cond,
230
+ sequence_cond_mode=args.sequence_cond_mode,
231
+ use_channel_cond=args.use_channel_cond,
232
+ use_sequence_cond_position_aware_residual=args.use_sequence_cond_position_aware_residual,
233
+ use_sequence_cond_loss=args.use_sequence_cond_loss,
234
+ fast_dev=args.fast_dev,
235
+ max_num_cond_images=args.max_num_cond_images,
236
+ max_num_cond_sketches=args.max_num_cond_sketches,
237
+ visualize_attention=args.visualize_attention,
238
+ random_spaced_cond_frames=args.random_spaced_cond_frames,
239
+ use_sketch_mask=args.use_sketch_mask,
240
+ sketch_mask_ratio=args.sketch_mask_ratio,
241
+ no_first_sketch=args.no_first_sketch,
242
+ )
243
+ model = model.to(device, dtype=dtype).eval()
244
+ else:
245
+ print("Fast dev mode. Models will not be loaded.")
246
+ model = None
247
+ print("Models initialized.")
248
+ return model, device, dtype
249
+
250
+ # -----------------------------------------------------------------------------
251
+ # CLI args and global initialization
252
+ # -----------------------------------------------------------------------------
253
+
254
+ def _parse_args():
255
+ parser = argparse.ArgumentParser()
256
+ parser.add_argument("--resolution", type=str, default=os.environ.get("TOONCOMPOSER_RESOLUTION", "480p"), choices=["480p", "608p"], help="Target resolution to load by default.")
257
+ parser.add_argument("--device", type=str, default=os.environ.get("DEVICE", "cuda"))
258
+ parser.add_argument("--dtype", type=str, default=os.environ.get("DTYPE", "bfloat16"), choices=["bfloat16", "float32"])
259
+ parser.add_argument("--wan_model_dir", type=str, default=os.environ.get("WAN21_I2V_DIR"), help="Local directory containing Wan2.1 model files. If not provided, will try HF cache and download if needed.")
260
+ parser.add_argument("--tooncomposer_dir", type=str, default=os.environ.get("TOONCOMPOSER_DIR"), help="Local directory containing ToonComposer weights with 480p/608p subdirectories. If not provided, will try HF cache and download if needed.")
261
+ parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="Hugging Face token (if needed for gated models).")
262
+ parser.add_argument("--fast_dev", action="store_true", help="Run in fast dev mode without loading heavy models.")
263
+ return parser.parse_args()
264
+
265
+ _cli_args = _parse_args()
266
+
267
+ # Resolve ToonComposer repo dir and build resolution mapping
268
+ _toon_dir = resolve_tooncomposer_repo_dir(preferred_dir=_cli_args.tooncomposer_dir, hf_token=_cli_args.hf_token)
269
+ checkpoints_by_resolution = build_checkpoints_by_resolution(_toon_dir)
270
+
271
+ _dtype_map = {
272
+ "bfloat16": torch.bfloat16,
273
+ "float32": torch.float32,
274
+ }
275
+ fast_dev = bool(_cli_args.fast_dev)
276
+ model, device, dtype = initialize_model(
277
+ resolution=_cli_args.resolution,
278
+ fast_dev=fast_dev,
279
+ device=_cli_args.device,
280
+ dtype=_dtype_map[_cli_args.dtype],
281
+ wan_model_dir=_cli_args.wan_model_dir,
282
+ tooncomposer_dir=_cli_args.tooncomposer_dir,
283
+ hf_token=_cli_args.hf_token,
284
+ )
285
+
286
+ def process_conditions(num_items, item_inputs, num_frames, is_sketch=False, target_height=480, target_width=832):
287
+ """Process condition images/sketches into masked video tensor and mask"""
288
+ # Create empty condition video and per-frame mask tensors (zero-initialized; conditioned frames are filled in below)
289
+ video = torch.zeros((1, 3, num_frames, target_height, target_width), device=device)
290
+ mask = torch.zeros((1, num_frames), device=device)
291
+
292
+ for i in range(num_items):
293
+ img, frame_idx = item_inputs[i]
294
+ if img is None or frame_idx is None:
295
+ continue
296
+
297
+ # Convert PIL image to tensor
298
+ img_tensor = torch.from_numpy(np.array(img)).permute(2,0,1).float() / 127.5 - 1.0
299
+ if is_sketch:
300
+ img_tensor = -img_tensor
301
+ img_tensor = img_tensor.unsqueeze(0).to(device)
302
+
303
+ # Resize to model's expected resolution while preserving aspect ratio
304
+ # Get original dimensions
305
+ _, _, h, w = img_tensor.shape
306
+
307
+ # Resize based on short edge while maintaining aspect ratio
308
+ if h/w < target_height/target_width:
309
+ new_h = target_height
310
+ new_w = int(w * (new_h / h))
311
+ else: # Width is the short edge
312
+ new_w = target_width
313
+ new_h = int(h * (new_w / w))
314
+
315
+ # Resize with the calculated dimensions
316
+ img_tensor = torch.nn.functional.interpolate(img_tensor, size=(new_h, new_w), mode="bilinear")
317
+
318
+ # Center crop to target resolution if needed
319
+ if new_h > target_height or new_w > target_width:
320
+ # Calculate starting positions for crop
321
+ start_h = max(0, (new_h - target_height) // 2)
322
+ start_w = max(0, (new_w - target_width) // 2)
323
+ # Crop
324
+ img_tensor = img_tensor[:, :, start_h:start_h+target_height, start_w:start_w+target_width]
325
+
326
+ # Place in video tensor
327
+ frame_idx = min(max(int(frame_idx), 0), num_frames-1)
328
+ if is_sketch:
329
+ video[:, :, frame_idx] = img_tensor[:, :3] # Handle RGBA sketches
330
+ else:
331
+ video[:, :, frame_idx] = img_tensor
332
+ mask[:, frame_idx] = 1.0
333
+ return video, mask
334
+
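# Note (illustrative only, not part of the committed file): worked example of the
# resize-and-crop logic above for a hypothetical 720x1280 (h x w) keyframe and a
# 480x832 target: h/w = 0.5625 < 480/832 ≈ 0.577, so new_h = 480 and
# new_w = int(1280 * 480 / 720) = 853; the frame is then center-cropped in width
# from 853 to 832 (start_w = (853 - 832) // 2 = 10).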
335
+ def process_sketch_masks(num_sketch_masks, sketch_mask_inputs, num_frames, target_height=480, target_width=832):
336
+ """Process sketch masks into a single tensor"""
337
+ # Create empty tensor filled with 1s (1 means no mask, keep original)
338
+ sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
339
+
340
+ for i in range(num_sketch_masks):
341
+ editor_value, frame_idx = sketch_mask_inputs[i]
342
+ if editor_value is None or frame_idx is None:
343
+ continue
344
+
345
+ # For ImageMask, we need to extract the mask from the editor_value dictionary
346
+ # editor_value is a dict with 'background', 'layers', and 'composite' keys from ImageEditor
347
+ if isinstance(editor_value, dict):
348
+ if "composite" in editor_value and editor_value["composite"] is not None:
349
+ # The 'composite' is the image with mask drawn on it
350
+ # Since we're using ImageMask with fixed black brush, the black areas are the mask
351
+ # Convert the composite to a binary mask (0=masked, 1=not masked)
352
+ # sketch = editor_value["background"] # This is the sketch
353
+ mask = editor_value["layers"][0] if editor_value["layers"] else None # This is the mask layer
354
+ if mask is not None:
355
+ # Convert mask to tensor and normalize
356
+ mask_array = np.array(mask)
357
+ mask_array = np.max(mask_array, axis=2)
358
+
359
+ # Convert to tensor, normalize to [0, 1]
360
+ mask_tensor = torch.from_numpy(mask_array).float()
361
+ if mask_tensor.max() > 1.0:
362
+ mask_tensor = mask_tensor / 255.0
363
+
364
+ # Resize to model's expected resolution
365
+ mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, h, w]
366
+ mask_tensor = torch.nn.functional.interpolate(mask_tensor, size=(target_height, target_width), mode="nearest")
367
+
368
+ # Invert the mask: black (0) = masked area, white (1) = keep original
369
+ # We need to invert because in the UI black means "masked"
370
+ mask_tensor = 1.0 - mask_tensor
371
+
372
+ # Place in sketch_local_mask tensor
373
+ frame_idx = min(max(int(frame_idx), 0), num_frames-1)
374
+ sketch_local_mask[:, :, frame_idx] = mask_tensor
375
+
376
+ sketch_mask_vis = torch.ones((1, 3, num_frames, target_height, target_width), device=device)
377
+ for t in range(sketch_local_mask.shape[2]):
378
+ for c in range(3):
379
+ sketch_mask_vis[0, c, t, :, :] = torch.where(
380
+ sketch_local_mask[0, 0, t] > 0.5,
381
+ 1.0, # White for unmasked areas
382
+ -1.0 # Black for masked areas
383
+ )
384
+ return sketch_local_mask
385
+
386
+
387
+ def invert_sketch(image):
388
+ """Invert the colors of an image (black to white, white to black)"""
389
+ if image is None:
390
+ return None
391
+
392
+ # Handle input from ImageMask component (EditorValue dictionary)
393
+ if isinstance(image, dict) and "background" in image:
394
+ # Extract the background image
395
+ bg_image = image["background"]
396
+
397
+ # Invert the background
398
+ inverted_bg = invert_sketch_internal(bg_image)
399
+
400
+ # Return updated editor value
401
+ return gr.update(value=inverted_bg)
402
+
403
+ # Original function for regular images
404
+ return invert_sketch_internal(image)
405
+
406
+ def invert_sketch_internal(image):
407
+ """Internal function to invert an image"""
408
+ if image is None:
409
+ return None
410
+
411
+ # Convert to PIL image if needed
412
+ if isinstance(image, str): # If it's a filepath
413
+ image = Image.open(image)
414
+ elif isinstance(image, np.ndarray):
415
+ image = Image.fromarray(image)
416
+
417
+ # Ensure it's a PIL image now
418
+ if not isinstance(image, Image.Image):
419
+ try:
420
+ image = Image.fromarray(np.array(image))
421
+ except Exception:
422
+ print(f"Warning: Could not convert image of type {type(image)} to PIL Image")
423
+ return image
424
+
425
+ # Invert the image
426
+ inverted = Image.fromarray(255 - np.array(image))
427
+ return inverted
428
+
429
+ def create_blank_mask(canvas_width=832, canvas_height=480):
430
+ """Create a blank white mask image"""
431
+ return Image.new('RGB', (canvas_width, canvas_height), color='white')
432
+
433
+ def create_mask_with_sketch(sketch, canvas_width=832, canvas_height=480):
434
+ """Create a mask image with sketch as background"""
435
+ if sketch is None:
436
+ return create_blank_mask(canvas_width, canvas_height)
437
+
438
+ # Convert sketch to PIL if needed
439
+ if not isinstance(sketch, Image.Image):
440
+ sketch = Image.fromarray(np.array(sketch))
441
+
442
+ # Resize sketch to fit the canvas
443
+ sketch = sketch.resize((canvas_width, canvas_height))
444
+
445
+ # Create a semi-transparent white layer over the sketch
446
+ overlay = Image.new('RGBA', (canvas_width, canvas_height), (255, 255, 255, 128))
447
+
448
+ # Ensure sketch has alpha channel
449
+ if sketch.mode != 'RGBA':
450
+ sketch = sketch.convert('RGBA')
451
+
452
+ # Overlay the semi-transparent white layer on the sketch
453
+ result = Image.alpha_composite(sketch, overlay)
454
+
455
+ # Convert back to RGB for Gradio
456
+ return result.convert('RGB')
457
+
458
+ def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args):
459
+ """Validate user inputs and return error messages if any"""
460
+ errors = []
461
+
462
+ # Check text prompt
463
+ if not text_prompt or text_prompt.strip() == "":
464
+ errors.append("❌ Text prompt is required. Please enter a description for your video.")
465
+
466
+ # Check condition images
467
+ cond_images_count = 0
468
+ for i in range(int(num_cond_images)):
469
+ img = args[i*2]
470
+ frame_idx = args[i*2+1]
471
+
472
+ if img is None:
473
+ errors.append(f"❌ Image #{i+1} is missing. Please upload an image or reduce the number of keyframe images.")
474
+ else:
475
+ cond_images_count += 1
476
+
477
+ if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
478
+ errors.append(f"❌ Frame index for Image #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
479
+
480
+ # Check condition sketches
481
+ num_cond_sketches_index = 8 # Starting index for sketch inputs
482
+ cond_sketches_count = 0
483
+ sketch_frame_indices = []
484
+
485
+ for i in range(int(num_cond_sketches)):
486
+ sketch_idx = num_cond_sketches_index + i*2
487
+ frame_idx_idx = num_cond_sketches_index + 1 + i*2
488
+
489
+ if sketch_idx < len(args) and frame_idx_idx < len(args):
490
+ sketch = args[sketch_idx]
491
+ frame_idx = args[frame_idx_idx]
492
+
493
+ # Check if sketch is provided
494
+ if sketch is None:
495
+ errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch or reduce the number of keyframe sketches.")
496
+ else:
497
+ # For ImageMask components, check if background is provided
498
+ if isinstance(sketch, dict):
499
+ if "background" not in sketch or sketch["background"] is None:
500
+ errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch image.")
501
+ else:
502
+ cond_sketches_count += 1
503
+ else:
504
+ cond_sketches_count += 1
505
+
506
+ # Check frame index
507
+ if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
508
+ errors.append(f"❌ Frame index for Sketch #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
509
+ elif frame_idx is not None:
510
+ sketch_frame_indices.append(frame_idx)
511
+
512
+ # Check for duplicate frame indices
513
+ image_frame_indices = []
514
+ for i in range(int(num_cond_images)):
515
+ frame_idx = args[i*2+1]
516
+ if frame_idx is not None:
517
+ image_frame_indices.append(frame_idx)
518
+
519
+ all_frame_indices = image_frame_indices + sketch_frame_indices
520
+ if len(all_frame_indices) != len(set(all_frame_indices)):
521
+ errors.append("❌ Duplicate frame indices detected. Each image and sketch must be placed at a different frame.")
522
+
523
+ # Check minimum requirements
524
+ if cond_images_count == 0:
525
+ errors.append("❌ At least one input image is required.")
526
+
527
+ return errors
528
+
529
+ def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
530
+ # Validate inputs first
531
+ validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
532
+
533
+ if validation_errors:
534
+ error_message = "\n".join(validation_errors)
535
+ return gr.update(value=None), error_message
536
+
537
+ try:
538
+ # Parse inputs
539
+ # Get the condition images
540
+ cond_images = []
541
+ for i in range(int(num_cond_images)):
542
+ img = args[i*2]
543
+ frame_idx = args[i*2+1]
544
+ if img is not None and frame_idx is not None:
545
+ cond_images.append((img, frame_idx))
546
+
547
+ # Get num_cond_sketches
548
+ if num_cond_sketches is None:
549
+ num_cond_sketches = 0
550
+ else:
551
+ num_cond_sketches = int(num_cond_sketches)
552
+
553
+ # Get condition sketches and masks
554
+ cond_sketches = []
555
+ sketch_masks = []
556
+ num_cond_sketches_index = 8 # Starting index for sketch inputs
557
+
558
+ for i in range(num_cond_sketches):
559
+ sketch_idx = num_cond_sketches_index + i*2
560
+ frame_idx_idx = num_cond_sketches_index + 1 + i*2
561
+
562
+ if sketch_idx < len(args) and frame_idx_idx < len(args):
563
+ editor_value = args[sketch_idx]
564
+ frame_idx = args[frame_idx_idx]
565
+
566
+ if editor_value is not None and frame_idx is not None:
567
+ # Extract the sketch from the background of the editor value
568
+ if isinstance(editor_value, dict) and "background" in editor_value:
569
+ sketch = editor_value["background"]
570
+ if sketch is not None:
571
+ cond_sketches.append((sketch, frame_idx))
572
+ # Also add to sketch_masks for mask processing
573
+ sketch_masks.append((editor_value, frame_idx))
574
+ else:
575
+ # For regular image inputs (first sketch)
576
+ if editor_value is not None:
577
+ cond_sketches.append((editor_value, frame_idx))
578
+
579
+ # Set target resolution based on selection
580
+ target_height, target_width = checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"]
581
+
582
+ # Update model resolution
583
+ if not fast_dev:
584
+ model.update_height_width(target_height, target_width)
585
+
586
+ # Process conditions
587
+ with torch.no_grad():
588
+ # Process image conditions
589
+ masked_cond_video, preserved_cond_mask = process_conditions(
590
+ num_cond_images, cond_images, num_frames, target_height=target_height, target_width=target_width
591
+ )
592
+
593
+ # Process sketch conditions
594
+ masked_cond_sketch, preserved_sketch_mask = process_conditions(
595
+ len(cond_sketches), cond_sketches, num_frames, is_sketch=True, target_height=target_height, target_width=target_width
596
+ )
597
+
598
+ # Process sketch masks (if any)
599
+ sketch_local_mask = None
600
+ if len(sketch_masks) > 0:
601
+ sketch_local_mask = process_sketch_masks(
602
+ len(sketch_masks), sketch_masks, num_frames, target_height=target_height, target_width=target_width
603
+ )
604
+ else:
605
+ sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
606
+
607
+ if fast_dev:
608
+ print("Fast dev mode, returning dummy video")
609
+ # Create a simple dummy video for testing
610
+ temp_dir = tempfile.mkdtemp()
611
+ video_path = os.path.join(temp_dir, "dummy_video.mp4")
612
+
613
+ # Create a simple test video
614
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
615
+ video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (target_width, target_height))
616
+
617
+ for i in range(30): # 30 frames
618
+ # Create a simple colored frame
619
+ frame = np.full((target_height, target_width, 3), (i * 8) % 255, dtype=np.uint8)
620
+ video_writer.write(frame)
621
+
622
+ video_writer.release()
623
+ return video_path, "✅ Dummy video generated successfully in fast dev mode!"
624
+
625
+ masked_cond_video = masked_cond_video.to(device=device, dtype=dtype)
626
+ preserved_cond_mask = preserved_cond_mask.to(device=device, dtype=dtype)
627
+ masked_cond_sketch = masked_cond_sketch.to(device=device, dtype=dtype)
628
+ preserved_sketch_mask = preserved_sketch_mask.to(device=device, dtype=dtype)
629
+
630
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(device).type):
631
+ # Generate video
632
+ model.pipe.device = device
633
+ generated_video = model.pipe(
634
+ prompt=[text_prompt],
635
+ negative_prompt=[model.negative_prompt],
636
+ input_image=None,
637
+ num_inference_steps=15,
638
+ num_frames=num_frames,
639
+ seed=42, tiled=True,
640
+ input_condition_video=masked_cond_video,
641
+ input_condition_preserved_mask=preserved_cond_mask,
642
+ input_condition_video_sketch=masked_cond_sketch,
643
+ input_condition_preserved_mask_sketch=preserved_sketch_mask,
644
+ sketch_local_mask=sketch_local_mask,
645
+ cfg_scale=cfg_scale,
646
+ sequence_cond_residual_scale=sequence_cond_residual_scale,
647
+ height=target_height,
648
+ width=target_width,
649
+ )
650
+
651
+ # Convert to PIL images
652
+ video_frames = model.pipe.tensor2video(generated_video[0].cpu())
653
+
654
+ # Convert PIL images to an MP4 video
655
+ temp_dir = tempfile.mkdtemp()
656
+ video_path = os.path.join(temp_dir, "generated_video.mp4")
657
+
658
+ width, height = video_frames[0].size
659
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for MP4 video
660
+ video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (width, height)) # 20 fps
661
+
662
+ for frame in video_frames:
663
+ # Convert PIL image to OpenCV BGR format
664
+ frame_bgr = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
665
+ video_writer.write(frame_bgr)
666
+
667
+ video_writer.release()
668
+ print(f"Generated video saved to {video_path}. Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
669
+
670
+ return video_path, f"✅ Video generated successfully! (with {len(cond_images)} keyframe images, {len(cond_sketches)} keyframe sketches)"
671
+
672
+ except Exception as e:
673
+ error_msg = f"❌ Error during generation: {str(e)}"
674
+ print(error_msg)
675
+ return gr.update(value=None), error_msg
676
+
677
+ def create_sample_gallery():
678
+ """Create gallery items for samples"""
679
+ import os
680
+
681
+ gallery_items = []
682
+ sample_info = [
683
+ {
684
+ "id": 1,
685
+ "title": "Sample 1",
686
+ "description": "Man playing with blue fish underwater (3 sketches)",
687
+ "preview": "samples/1_image1.png"
688
+ },
689
+ {
690
+ "id": 2,
691
+ "title": "Sample 2",
692
+ "description": "Girl and boy planting a growing flower (2 sketches)",
693
+ "preview": "samples/2_image1.jpg"
694
+ },
695
+ {
696
+ "id": 3,
697
+ "title": "Sample 3",
698
+ "description": "Ancient Chinese boy giving apple to elder (1 sketch)",
699
+ "preview": "samples/3_image1.png"
700
+ }
701
+ ]
702
+
703
+ for sample in sample_info:
704
+ if os.path.exists(sample["preview"]):
705
+ gallery_items.append((sample["preview"], f"{sample['title']}: {sample['description']}"))
706
+
707
+ return gallery_items
708
+
709
+ def handle_gallery_select(evt: gr.SelectData):
710
+ """Handle gallery selection and load the corresponding sample"""
711
+ sample_id = evt.index + 1 # Gallery index starts from 0, sample IDs start from 1
712
+ return apply_sample_to_ui(sample_id)
713
+
714
+ def load_sample_data(sample_id):
715
+ """Load sample data based on the selected sample"""
716
+ import os
717
+
718
+ samples_dir = "samples"
719
+
720
+ # Sample configurations
721
+ sample_configs = {
722
+ 1: {
723
+ "prompt": "Underwater scene: A shirtless man plays with a spiraling blue fish. A whale follows a bag in the man's hand, swimming in circles as the man uses the bag to lure the blue fish forward. Anime. High quality.",
724
+ "num_sketches": 3,
725
+ "image_frame": 0,
726
+ "sketch_frames": [20, 40, 60],
727
+ "num_frames": 61
728
+ },
729
+ 2: {
730
+ "prompt": "A girl and a silver-haired boy plant a huge flower. As the camera slowly moves up, the huge flower continues to grow and bloom. Anime. High quality.",
731
+ "num_sketches": 2,
732
+ "image_frame": 0,
733
+ "sketch_frames": [30, 60],
734
+ "num_frames": 61
735
+ },
736
+ 3: {
737
+ "prompt": "An ancient Chinese boy holds an apple and smiles as he gives it to an elderly man nearby. Anime. High quality.",
738
+ "num_sketches": 1,
739
+ "image_frame": 0,
740
+ "sketch_frames": [30],
741
+ "num_frames": 33
742
+ }
743
+ }
744
+
745
+ if sample_id not in sample_configs:
746
+ return None
747
+
748
+ config = sample_configs[sample_id]
749
+
750
+ # Load image
751
+ image_path = os.path.join(samples_dir, f"{sample_id}_image1.png")
752
+ if not os.path.exists(image_path):
753
+ image_path = os.path.join(samples_dir, f"{sample_id}_image1.jpg")
754
+
755
+ # Load sketches
756
+ sketches = []
757
+ for i in range(config["num_sketches"]):
758
+ sketch_path = os.path.join(samples_dir, f"{sample_id}_sketch{i+1}.jpg")
759
+ if os.path.exists(sketch_path):
760
+ sketches.append(sketch_path)
761
+
762
+ # Load output video
763
+ output_path = os.path.join(samples_dir, f"{sample_id}_out.mp4")
764
+
765
+ return {
766
+ "prompt": config["prompt"],
767
+ "image": image_path if os.path.exists(image_path) else None,
768
+ "sketches": sketches,
769
+ "image_frame": config["image_frame"],
770
+ "sketch_frames": config["sketch_frames"][:len(sketches)],
771
+ "output_video": output_path if os.path.exists(output_path) else None,
772
+ "num_sketches": len(sketches),
773
+ "num_frames": config["num_frames"]
774
+ }
775
+
776
+ def apply_sample_to_ui(sample_id):
777
+ """Apply sample data to UI components"""
778
+ sample_data = load_sample_data(sample_id)
779
+
780
+ if not sample_data:
781
+ return [gr.update() for _ in range(20)] # Return no updates if sample not found
782
+
783
+ updates = [gr.update(value=sample_data["num_frames"])]
784
+
785
+ # Update prompt
786
+ updates.append(gr.update(value=sample_data["prompt"]))
787
+
788
+ # Update number of sketches
789
+ updates.append(gr.update(value=sample_data["num_sketches"]))
790
+
791
+ # Update condition image
792
+ updates.append(gr.update(value=sample_data["image"]))
793
+ updates.append(gr.update(value=sample_data["image_frame"]))
794
+
795
+ # Update sketches (up to 4)
796
+ for i in range(4):
797
+ if i < len(sample_data["sketches"]):
798
+ # Load sketch image
799
+ sketch_img = Image.open(sample_data["sketches"][i])
800
+ # Create ImageMask format
801
+ sketch_dict = {
802
+ "background": sketch_img,
803
+ "layers": [],
804
+ "composite": sketch_img
805
+ }
806
+ updates.append(gr.update(value=sketch_dict))
807
+ updates.append(gr.update(value=sample_data["sketch_frames"][i]))
808
+ else:
809
+ updates.append(gr.update(value=None))
810
+ updates.append(gr.update(value=30))
811
+
812
+ # Update output video
813
+ updates.append(gr.update(value=sample_data["output_video"]))
814
+
815
+ # Update status
816
+ updates.append(gr.update(value=f"✅ Loaded Sample {sample_id}: {sample_data['prompt'][:50]}..."))
817
+
818
+ return updates
819
+
820
+ if __name__ == "__main__":
821
+ from util.stylesheets import css, pre_js, banner_image
822
+ with gr.Blocks(title="🎨 ToonComposer Demo", css=css, js=pre_js) as iface:
823
+ with gr.Row():
824
+ with gr.Column(scale=1):
825
+ gr.HTML(banner_image)
826
+ with gr.Column(scale=1):
827
+ gr.Markdown("""
828
+ 💡 **Quick Guide**
829
+ 1. Set the prompt, the number of target frames, the number of keyframe images/sketches, etc.
830
+ 2. Upload a keyframe image as the first frame (with its frame index set to 0).
831
+ 3. Upload sketches with optional motion masks for controlled generation at specified frame indices.
832
+ 4. Click the *Generate* button to create your cartoon video.
833
+ """)
834
+
835
+ max_num_frames = 61
836
+ cond_images_inputs = []
837
+ cond_sketches_inputs = []
838
+ with gr.Row():
839
+ with gr.Column(scale=1):
840
+ with gr.Accordion("Video Settings", open=True):
841
+ num_frames = gr.Slider(
842
+ minimum=17, maximum=max_num_frames, value=max_num_frames, step=1, label="🎥 Number of Frames",
843
+ info="Select the total number of frames for the generated video. Should be 4N+1."
844
+ )
845
+
846
+ resolution = gr.Radio(
847
+ choices=["480p", "608p"],
848
+ value="480p",
849
+ label="🎥 Resolution",
850
+ info="Select the resolution for the generated video."
851
+ )
852
+
853
+ text_prompt = gr.Textbox(
854
+ label="📝 Text Prompt",
855
+ placeholder="Enter a description for the video.",
856
+ info="Describe what you want to generate in the video.",
857
+ lines=5
858
+ )
859
+ cfg_scale = gr.Slider(
860
+ minimum=1.0, maximum=15.0, value=7.5, label="⚙️ CFG Scale",
861
+ info="Adjust the classifier-free guidance scale for generation."
862
+ )
863
+ sequence_cond_residual_scale = gr.Slider(
864
+ minimum=0.0, maximum=1.2, value=1.0, label="⚙️ Pos-aware Residual Scale",
865
+ info="Adjust the residual scale for the position-aware sequence condition."
866
+ )
867
+
868
+ with gr.Column(scale=3):
869
+ with gr.Accordion("Keyframe Image(s)", open=True):
870
+ num_cond_images = gr.Slider(
871
+ minimum=1, maximum=4, value=1, step=1, label="🖼️ Number of Keyframe Images",
872
+ info="Specify how many keyframe color images to use (max 4 images)."
873
+ )
874
+ for i in range(4): # Max 4 condition images
875
+ with gr.Tab(label=f"Image {i+1}", interactive=i==0) as tab:
876
+ gr.Markdown("At least one image is required. \n Each image or sketch will be used to control the cartoon generation at the given frame index.")
877
+ image_input = gr.Image(
878
+ label=f"Image {i+1}", type="pil",
879
+ placeholder=f"Upload a keyframe image {i+1}..."
880
+ )
881
+ frame_index_input = gr.Slider(
882
+ label=f"Frame Index for Image #{i+1}", minimum=0, maximum=max_num_frames - 1,
883
+ value=i * (max_num_frames-1) // 3, step=1,
884
+ info=f"Frame position for Image {i+1} (0 to {max_num_frames-1})"
885
+ )
886
+ cond_images_inputs.append((image_input, frame_index_input, tab))
887
+
888
+
889
+ with gr.Column(scale=3):
890
+ with gr.Accordion("Keyframe Sketch(es)", open=True):
891
+ num_cond_sketches = gr.Slider(
892
+ minimum=1, maximum=4, value=1, step=1, label="✏️ Number of Keyframe Sketch(es)",
893
+ info="Specify how many keyframe sketches to use (max 4 sketches)."
894
+ )
895
+ for i in range(4): # Max 4 condition sketches
896
+ with gr.Tab(label=f"Sketch {i + 1}", interactive=i==0) as tab:
897
+
898
+ gr.Markdown("At least one sketch is required. \n You can optionally draw black areas using the brush tool to mark regions where motion can be generated freely.")
899
+
900
+ # Use ImageMask which allows uploading an image and drawing a mask
901
+ sketch_input = gr.ImageMask(
902
+ label=f"Sketch {i + 1} with Motion Mask",
903
+ type="pil",
904
+ elem_id=f"sketch_mask_{i + 1}"
905
+ )
906
+
907
+ # All sketches have a frame index input
908
+ _frame_index_input = gr.Slider(
909
+ label=f"Frame Index for Sketch #{i + 1}", minimum=0, maximum=max_num_frames - 1,
910
+ value=max_num_frames-1, step=1,
911
+ info=f"Frame position for Sketch {i + 1} (0 to {max_num_frames-1})"
912
+ )
913
+
914
+ cond_sketches_inputs.append((sketch_input, _frame_index_input, tab))
915
+
916
+ with gr.Row():
917
+ with gr.Column(scale=1):
918
+ # Sample Gallery Section
919
+ with gr.Accordion("🔍 Sample Gallery", open=True):
920
+ gr.Markdown("Click on any sample image below to load the sample inputs.")
921
+ sample_gallery = gr.Gallery(
922
+ value=create_sample_gallery(),
923
+ label="Sample Examples",
924
+ show_label=False,
925
+ elem_id="sample-gallery",
926
+ columns=3,
927
+ rows=1,
928
+ height=200,
929
+ allow_preview=True,
930
+ object_fit="contain")
931
+
932
+ with gr.Accordion("🛠️ Tools", open=False):
933
+ tool_input = gr.Image(
934
+ label=f"Input Image", type="pil",
935
+ placeholder=f"Upload an image."
936
+ )
937
+ invert_btn = gr.Button(f"Invert Colors")
938
+ invert_btn.click(
939
+ fn=invert_sketch,
940
+ inputs=[tool_input],
941
+ outputs=[tool_input]
942
+ )
943
+
944
+ with gr.Column(scale=1):
945
+ status_text = gr.Textbox(
946
+ label="📊 Status",
947
+ value="Ready to generate. Please check your inputs and click Run.",
948
+ interactive=False,
949
+ lines=5
950
+ )
951
+
952
+ with gr.Accordion("🎬 Generated Video", open=True):
953
+ output_video = gr.Video(
954
+ label="Video Output",
955
+ show_label=True
956
+ )
957
+ run_button = gr.Button("🚀 Generate Video", variant="primary", size="lg")
958
+
959
+ def update_visibility(num_items, num_frames):
960
+ # Update visibility for columns
961
+ updates_images = []
962
+ updates_indices = []
963
+ for i in range(4):
964
+ is_visible = i < num_items
965
+ # is_visible = True
966
+ updates_images.append(gr.update(interactive=is_visible))
967
+ updates_indices.append(gr.update(
968
+ value=((num_frames - 1) // max(num_items, 1)) * (i + 1),
969
+ minimum=0, maximum=num_frames-1,
970
+ ))
971
+ return updates_images + updates_indices
972
+
973
+ def update_visibility_images(num_items, num_frames):
974
+ # Update visibility for columns
975
+ updates_images = []
976
+ updates_indices = []
977
+ for i in range(4):
978
+ is_visible = i < num_items
979
+ updates_images.append(gr.update(interactive=is_visible))
980
+ updates_indices.append(gr.update(
981
+ value=((num_frames - 1) // max(num_items, 1)) * i,
982
+ minimum=0, maximum=num_frames-1,
983
+ ))
984
+ return updates_images + updates_indices
985
+
986
+ def update_frame_ranges(num_items_images, num_items_sketches, num_frames):
987
+ """Update the maximum values for all frame index sliders"""
988
+ updates = []
989
+ for i in range(4): # Images
990
+ updates.append(gr.update(
991
+ value=((num_frames - 1) // max(num_items_images, 1)) * i,
992
+ maximum=num_frames-1
993
+ ))
994
+ for i in range(4): # Sketches
995
+ updates.append(gr.update(
996
+ value=((num_frames - 1) // max(num_items_sketches, 1)) * (i + 1),
997
+ maximum=num_frames-1))
998
+ return updates
999
+
1000
+ num_cond_images.change(
1001
+ fn=update_visibility_images,
1002
+ inputs=[num_cond_images, num_frames],
1003
+ outputs=[tab for _, _, tab in cond_images_inputs] \
1004
+ + [frame_index_input for _, frame_index_input, _ in cond_images_inputs],
1005
+ )
1006
+
1007
+ num_cond_sketches.change(
1008
+ fn=update_visibility,
1009
+ inputs=[num_cond_sketches, num_frames],
1010
+ outputs=[tab for _, _, tab in cond_sketches_inputs] \
1011
+ + [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs],
1012
+ )
1013
+
1014
+ num_frames.change(
1015
+ fn=update_frame_ranges,
1016
+ inputs=[num_cond_images, num_cond_sketches, num_frames],
1017
+ outputs=[frame_index_input for _, frame_index_input, _ in cond_images_inputs] + \
1018
+ [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs]
1019
+ )
1020
+
1021
+ def update_resolution(resolution):
1022
+ model.update_height_width(checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"])
1023
+ model.load_tooncomposer_checkpoint(checkpoints_by_resolution[resolution]["checkpoint_path"])
1024
+ return gr.update(), gr.update()
1025
+
1026
+ resolution.change(
1027
+ fn=update_resolution,
1028
+ inputs=[resolution],
1029
+ outputs=[output_video, run_button]
1030
+ )
1031
+
1032
+ sample_outputs = [
1033
+ num_frames, text_prompt, num_cond_sketches,
1034
+ cond_images_inputs[0][0], cond_images_inputs[0][1], # Image 1
1035
+ cond_sketches_inputs[0][0], cond_sketches_inputs[0][1], # Sketch 1
1036
+ cond_sketches_inputs[1][0], cond_sketches_inputs[1][1], # Sketch 2
1037
+ cond_sketches_inputs[2][0], cond_sketches_inputs[2][1], # Sketch 3
1038
+ cond_sketches_inputs[3][0], cond_sketches_inputs[3][1], # Sketch 4
1039
+ output_video, status_text
1040
+ ]
1041
+
1042
+ sample_gallery.select(
1043
+ fn=handle_gallery_select,
1044
+ outputs=sample_outputs
1045
+ )
1046
+
1047
+ inputs = [num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution]
1048
+ run_button.click(
1049
+ fn=tooncomposer_inference,
1050
+ inputs=inputs,
1051
+ outputs=[output_video, status_text]
1052
+ )
1053
+
1054
+ # Add condition image inputs
1055
+ for image_input, frame_index_input, _ in cond_images_inputs:
1056
+ inputs.append(image_input)
1057
+ inputs.append(frame_index_input)
1058
+
1059
+ # Add sketch inputs (both regular and ImageMask)
1060
+ for sketch_input, frame_index_input, _ in cond_sketches_inputs:
1061
+ inputs.append(sketch_input)
1062
+ inputs.append(frame_index_input)
1063
+
1064
+ iface.launch(server_port=7860, server_name="0.0.0.0",
1065
+ allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
1066
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "samples"))])
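The frame-index sliders above default to evenly spaced positions over the clip. A minimal illustrative sketch of that spacing rule follows; the helper name `default_frame_indices` is ours and not part of the app, which computes the same expression inline in `update_visibility_images` and `update_frame_ranges`.

# Editorial sketch, not part of the commit: mirrors the slider defaults used above.
def default_frame_indices(num_items, num_frames, offset=0):
    """Evenly spread condition slots over [0, num_frames - 1].

    offset=0 reproduces the keyframe-image defaults, offset=1 the sketch defaults.
    """
    step = (num_frames - 1) // max(num_items, 1)
    return [step * (i + offset) for i in range(num_items)]

# e.g. 61 frames, 2 keyframe images -> [0, 30]; 2 sketches (offset=1) -> [30, 60]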
model/__init__.py ADDED
File without changes
model/dera.py ADDED
@@ -0,0 +1,195 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ from .dit import flash_attention
5
+ import torch.amp as amp
6
+
7
+
8
+ class DeRAAttention(nn.Module):
9
+
10
+ def __init__(self,
11
+ dim,
12
+ num_heads,
13
+ window_size=(-1, -1),
14
+ mode="spatial"):
15
+ assert dim % num_heads == 0
16
+ super().__init__()
17
+ self.dim = dim
18
+ self.num_heads = num_heads
19
+ self.head_dim = dim // num_heads
20
+ self.window_size = window_size
21
+
22
+ self.q = nn.Linear(dim, dim)
23
+ self.k = nn.Linear(dim, dim)
24
+ self.v = nn.Linear(dim, dim)
25
+ self.o = nn.Linear(dim, dim)
26
+ self.visualize_attention = False
27
+
28
+ if mode == 'spatial':
29
+ self.rope_apply = self.rope_apply_spatial
30
+ elif mode == 'temporal':
31
+ self.rope_apply = self.rope_apply_temporal
32
+ elif mode == 'spatial_temporal':
33
+ self.rope_apply = self.rope_apply_spatial_temporal
34
+ else:
35
+ raise ValueError("Invalid mode: {}".format(mode))
36
+
37
+ @staticmethod
38
+ @amp.autocast(enabled=False, device_type="cuda")
39
+ def rope_apply_spatial(x, grid_size, freqs, sequence_cond_compressed_indices=None):
40
+ batch, _, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
41
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
42
+ assert len(grid_size) == 2, "grid_size must be [h, w]"
43
+ h, w = grid_size[0], grid_size[1]
44
+ seq_len = h * w
45
+ x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
46
+ batch, seq_len, n, -1, 2))
47
+ freqs_i = torch.cat([
48
+ freqs[1][:h].view(1, h, 1, -1).expand(1, h, w, -1),
49
+ freqs[2][:w].view(1, 1, w, -1).expand(1, h, w, -1)
50
+ ], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
51
+ freqs_i = torch.concat([freqs_i.new_ones(batch, seq_len, 1, c//3), freqs_i], dim=3)
52
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
53
+ return x_i.float()
54
+
55
+ @staticmethod
56
+ @amp.autocast(enabled=False, device_type="cuda")
57
+ def rope_apply_temporal(x, grid_size, freqs, sequence_cond_compressed_indices=None):
58
+ batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
59
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
60
+ assert len(grid_size) == 1, "grid_size must be [t]"
61
+ seq_len = grid_size[0]
62
+ x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(batch, seq_len, n, -1, 2))
63
+ freqs_i = torch.cat([
64
+ freqs[0][:seq_len].view(seq_len, 1, 1, -1)
65
+ ], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
66
+ freqs_i = torch.concat([freqs_i, freqs_i.new_ones(batch, seq_len, 1, 2 * c//3)], dim=3)
67
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
68
+ if seq_len_actual > seq_len:
69
+ sequence_cond_seq_length = seq_len_actual - seq_len
70
+ if sequence_cond_seq_length == seq_len:
71
+ x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
72
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
73
+ else:
74
+ sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
75
+ sequence_cond_t_length = len(sequence_cond_compressed_index)
76
+ assert sequence_cond_t_length == sequence_cond_seq_length, "`sequence_cond_t_length` must be equal to `sequence_cond_seq_length`"
77
+ x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
78
+ freqs_i_sequence_cond = torch.cat([
79
+ freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1),
80
+ ], dim=-1).reshape(sequence_cond_seq_length, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
81
+ freqs_i_sequence_cond = torch.concat([freqs_i_sequence_cond, freqs_i_sequence_cond.new_ones(batch, sequence_cond_t_length, 1, 2 * c//3)], dim=3)
82
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
83
+ x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
84
+
85
+ return x_i.float()
86
+
87
+ @staticmethod
88
+ @amp.autocast(enabled=False, device_type="cuda")
89
+ def rope_apply_spatial_temporal(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
90
+ batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
91
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
92
+ assert len(grid_sizes) == 3, "grid_sizes must be ([f, h, w])"
93
+ f, h, w = grid_sizes[0], grid_sizes[1], grid_sizes[2]
94
+ seq_len = f * h * w
95
+ x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
96
+ batch, seq_len, n, -1, 2))
97
+ freqs_i = torch.cat([
98
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
99
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
100
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
101
+ ], dim=-1).reshape(seq_len, 1, -1)
102
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
103
+ if seq_len_actual > seq_len:
104
+ sequence_cond_seq_length = seq_len_actual - seq_len
105
+ if sequence_cond_seq_length == seq_len:
106
+ x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
107
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
108
+ else:
109
+ sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
110
+ sequence_cond_t_length = len(sequence_cond_compressed_index)
111
+ assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
112
+ x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
113
+ freqs_i_sequence_cond = torch.cat([
114
+ freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
115
+ freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
116
+ freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
117
+ ], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
118
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
119
+ x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
120
+ return x_i.float()
121
+
122
+
123
+ def forward(self, x, seq_lens, grid_size, freqs, sequence_cond_compressed_indices):
124
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
125
+ def qkv_fn(x):
126
+ q = self.q(x).view(b, s, n, d)
127
+ k = self.k(x).view(b, s, n, d)
128
+ v = self.v(x).view(b, s, n, d)
129
+ return q, k, v
130
+
131
+ q, k, v = qkv_fn(x)
132
+ q_rope = self.rope_apply(q, grid_size, freqs, sequence_cond_compressed_indices)
133
+ k_rope = self.rope_apply(k, grid_size, freqs, sequence_cond_compressed_indices)
134
+ if self.visualize_attention:
135
+ with torch.no_grad():
136
+ self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope) # CPU tensor of [S, S]
137
+ self._last_grid_sizes = grid_size
138
+ self._last_seq_lens = seq_lens
139
+ x = flash_attention(
140
+ q=q_rope,
141
+ k=k_rope,
142
+ v=v,
143
+ k_lens=None,
144
+ window_size=self.window_size)
145
+ x = x.flatten(2)
146
+ x = self.o(x)
147
+ return x
148
+
149
+
150
+ class DeRA(nn.Module):
151
+ def __init__(self, dim, rank, use_spatial=True, use_temporal=True):
152
+ super(DeRA, self).__init__()
153
+ self.dim = dim
154
+ self.rank = rank
155
+ self.use_spatial = use_spatial
156
+ self.use_temporal = use_temporal
157
+
158
+ if not use_spatial and not use_temporal:
159
+ self.attention_mode = "none"
160
+ else:
161
+ self.attention_mode = "spatial_temporal" if use_spatial and use_temporal else "spatial" if use_spatial else "temporal"
162
+
163
+ self.spatial_down_proj = nn.Linear(self.dim, rank, bias=False)
164
+ self.spatial_up_proj = nn.Linear(rank, self.dim, bias=False)
165
+ self.spatial_up_proj.weight.data.zero_()
166
+ if self.attention_mode != "none":
167
+ self.spatial_attn = DeRAAttention(dim=rank, num_heads=4, window_size=(-1, -1),
168
+ mode=self.attention_mode)
169
+ else:
170
+ self.spatial_attn = None
171
+
172
+ def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
173
+ _, actual_seq, _ = x.shape
174
+ if isinstance(grid_sizes, torch.Tensor):
175
+ grid_sizes = tuple(grid_sizes[0].tolist())
176
+
177
+ if len(grid_sizes) != 3:
178
+ raise ValueError("`grid_sizes` should contain time, spatial height, and width dimensions")
179
+ _, orig_h, orig_w = grid_sizes
180
+ actual_t = actual_seq // (orig_h * orig_w)
181
+
182
+ x_low = self.spatial_down_proj(x)
183
+ if self.attention_mode == "spatial":
184
+ x_low_spatial = rearrange(x_low, 'b (t h w) r -> (b t) (h w) r', t=actual_t, h=orig_h, w=orig_w)
185
+ x_low_spatial = self.spatial_attn(x_low_spatial, seq_lens, grid_sizes[1:], freqs, sequence_cond_compressed_indices)
186
+ x_low = rearrange(x_low_spatial, '(b t) (h w) r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
187
+ elif self.attention_mode == "temporal":
188
+ x_low_temporal = rearrange(x_low, 'b (t h w) r -> (b h w) t r', t=actual_t, h=orig_h, w=orig_w)
189
+ x_low_temporal = self.spatial_attn(x_low_temporal, seq_lens, grid_sizes[:1], freqs, sequence_cond_compressed_indices)
190
+ x_low = rearrange(x_low_temporal, '(b h w) t r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
191
+ elif self.attention_mode == "spatial_temporal":
192
+ x_low = self.spatial_attn(x_low, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
193
+ x_out = self.spatial_up_proj(x_low)
194
+ return x_out
195
+
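A minimal usage sketch for the DeRA adapter defined above. It assumes a CUDA device, imports matching this repo layout (model/dera.py, model/dit.py), and builds the RoPE table the same way `WanModel` builds `dera_freqs`; the toy shapes are illustrative only.

import torch
from model.dera import DeRA
from model.dit import rope_params

dim, rank = 1536, 64
dera = DeRA(dim, rank, use_spatial=True, use_temporal=True).cuda()

d = rank // dera.spatial_attn.num_heads      # per-head dim inside the adapter
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),     # temporal frequencies
    rope_params(1024, 2 * (d // 6)),         # height frequencies
    rope_params(1024, 2 * (d // 6)),         # width frequencies
], dim=1).cuda()

f, h, w = 4, 8, 8                            # latent grid: frames, height, width
x = torch.randn(1, f * h * w, dim).cuda()    # token sequence entering a DiT block
residual = dera(x, seq_lens=None, grid_sizes=(f, h, w), freqs=freqs,
                sequence_cond_compressed_indices=None)
# residual: [1, 256, 1536]; zero at initialization (the up-projection starts at zero)
# and is added to the block's self-attention output.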
model/dit.py ADDED
@@ -0,0 +1,1090 @@
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.amp as amp
5
+ import torch.nn as nn
6
+ from util.model_util import hash_state_dict_keys
7
+ from einops import rearrange
8
+
9
+ try:
10
+ import flash_attn_interface
11
+ FLASH_ATTN_3_AVAILABLE = True
12
+ except ModuleNotFoundError:
13
+ FLASH_ATTN_3_AVAILABLE = False
14
+
15
+ try:
16
+ import flash_attn
17
+ FLASH_ATTN_2_AVAILABLE = True
18
+ except ModuleNotFoundError:
19
+ FLASH_ATTN_2_AVAILABLE = False
20
+
21
+ try:
22
+ from sageattention import sageattn
23
+ SAGE_ATTN_AVAILABLE = True
24
+ except ModuleNotFoundError:
25
+ SAGE_ATTN_AVAILABLE = False
26
+
27
+ import warnings
28
+
29
+
30
+ __all__ = ['WanModel']
31
+
32
+
33
+ def flash_attention(
34
+ q,
35
+ k,
36
+ v,
37
+ q_lens=None,
38
+ k_lens=None,
39
+ dropout_p=0.,
40
+ softmax_scale=None,
41
+ q_scale=None,
42
+ causal=False,
43
+ window_size=(-1, -1),
44
+ deterministic=False,
45
+ dtype=torch.bfloat16,
46
+ version=None,
47
+ ):
48
+ """
49
+ q: [B, Lq, Nq, C1].
50
+ k: [B, Lk, Nk, C1].
51
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
52
+ q_lens: [B].
53
+ k_lens: [B].
54
+ dropout_p: float. Dropout probability.
55
+ softmax_scale: float. The scaling of QK^T before applying softmax.
56
+ causal: bool. Whether to apply causal attention mask.
57
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
58
+ deterministic: bool. If True, slightly slower and uses more memory.
59
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
60
+ """
61
+ half_dtypes = (torch.float16, torch.bfloat16)
62
+ assert dtype in half_dtypes
63
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
64
+
65
+ # params
66
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
67
+
68
+ def half(x):
69
+ return x if x.dtype in half_dtypes else x.to(dtype)
70
+
71
+ # preprocess query
72
+ if q_lens is None:
73
+ q = half(q.flatten(0, 1))
74
+ q_lens = torch.tensor(
75
+ [lq] * b, dtype=torch.int32).to(
76
+ device=q.device, non_blocking=True)
77
+ else:
78
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
79
+
80
+ # preprocess key, value
81
+ if k_lens is None:
82
+ k = half(k.flatten(0, 1))
83
+ v = half(v.flatten(0, 1))
84
+ k_lens = torch.tensor(
85
+ [lk] * b, dtype=torch.int32).to(
86
+ device=k.device, non_blocking=True)
87
+ else:
88
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
89
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
90
+
91
+ q = q.to(v.dtype)
92
+ k = k.to(v.dtype)
93
+
94
+ if q_scale is not None:
95
+ q = q * q_scale
96
+
97
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
98
+ warnings.warn(
99
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
100
+ )
101
+
102
+ # apply attention
103
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
104
+ # Note: dropout_p, window_size are not supported in FA3 now.
105
+ x = flash_attn_interface.flash_attn_varlen_func(
106
+ q=q,
107
+ k=k,
108
+ v=v,
109
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
110
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
111
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
112
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
113
+ seqused_q=None,
114
+ seqused_k=None,
115
+ max_seqlen_q=lq,
116
+ max_seqlen_k=lk,
117
+ softmax_scale=softmax_scale,
118
+ causal=causal,
119
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
120
+ elif FLASH_ATTN_2_AVAILABLE:
121
+ x = flash_attn.flash_attn_varlen_func(
122
+ q=q,
123
+ k=k,
124
+ v=v,
125
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
126
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
127
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
128
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
129
+ max_seqlen_q=lq,
130
+ max_seqlen_k=lk,
131
+ dropout_p=dropout_p,
132
+ softmax_scale=softmax_scale,
133
+ causal=causal,
134
+ window_size=window_size,
135
+ deterministic=deterministic).unflatten(0, (b, lq))
136
+ elif SAGE_ATTN_AVAILABLE:
137
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
138
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
139
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
140
+ x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
141
+ x = x.transpose(1, 2).contiguous()
142
+ else:
143
+ q = q.unsqueeze(0).transpose(1, 2).to(dtype)
144
+ k = k.unsqueeze(0).transpose(1, 2).to(dtype)
145
+ v = v.unsqueeze(0).transpose(1, 2).to(dtype)
146
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
147
+ x = x.transpose(1, 2).contiguous()
148
+
149
+ # output
150
+ return x.type(out_dtype)
151
+
152
+
153
+ def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
154
+ b, lq, lk = q.size(0), q.size(1), k.size(1)
155
+ if q_lens is None:
156
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
157
+ if k_lens is None:
158
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
159
+ attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
160
+ for i in range(b):
161
+ q_len, k_len = q_lens[i], k_lens[i]
162
+ attn_mask[i, q_len:, :] = True
163
+ attn_mask[i, :, k_len:] = True
164
+
165
+ if causal:
166
+ causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
167
+ attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
168
+
169
+ attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
170
+ return attn_mask
171
+
172
+
173
+ def attention(
174
+ q,
175
+ k,
176
+ v,
177
+ q_lens=None,
178
+ k_lens=None,
179
+ dropout_p=0.,
180
+ softmax_scale=None,
181
+ q_scale=None,
182
+ causal=False,
183
+ window_size=(-1, -1),
184
+ deterministic=False,
185
+ dtype=torch.bfloat16,
186
+ fa_version=None,
187
+ ):
188
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
189
+ return flash_attention(
190
+ q=q,
191
+ k=k,
192
+ v=v,
193
+ q_lens=q_lens,
194
+ k_lens=k_lens,
195
+ dropout_p=dropout_p,
196
+ softmax_scale=softmax_scale,
197
+ q_scale=q_scale,
198
+ causal=causal,
199
+ window_size=window_size,
200
+ deterministic=deterministic,
201
+ dtype=dtype,
202
+ version=fa_version,
203
+ )
204
+ else:
205
+ if q_lens is not None or k_lens is not None:
206
+ warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
207
+ attn_mask = None
208
+
209
+ q = q.transpose(1, 2).to(dtype)
210
+ k = k.transpose(1, 2).to(dtype)
211
+ v = v.transpose(1, 2).to(dtype)
212
+
213
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
214
+
215
+ out = out.transpose(1, 2).contiguous()
216
+ return out
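A quick usage sketch for the `attention` wrapper above (editorial note, not part of the file): the wrapper keeps the [B, L, num_heads, head_dim] layout on every backend and falls back to PyTorch scaled_dot_product_attention when neither flash-attn 2 nor 3 is installed. The shapes below are illustrative and assume a CUDA device.

def _attention_usage_sketch():
    # Hypothetical demo helper; never called by the model code.
    q = torch.randn(1, 128, 8, 64, device="cuda")   # [B, L, num_heads, head_dim]
    k = torch.randn(1, 128, 8, 64, device="cuda")
    v = torch.randn(1, 128, 8, 64, device="cuda")
    out = attention(q, k, v)                        # same layout regardless of backend
    assert out.shape == (1, 128, 8, 64)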
217
+
218
+
219
+
220
+ def sinusoidal_embedding_1d(dim, position):
221
+ # preprocess
222
+ assert dim % 2 == 0
223
+ half = dim // 2
224
+ position = position.type(torch.float64)
225
+
226
+ # calculation
227
+ sinusoid = torch.outer(
228
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
229
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
230
+ return x
231
+
232
+
233
+ @amp.autocast(enabled=False, device_type="cuda")
234
+ def rope_params(max_seq_len, dim, theta=10000):
235
+ assert dim % 2 == 0
236
+ freqs = torch.outer(
237
+ torch.arange(max_seq_len),
238
+ 1.0 / torch.pow(theta,
239
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
240
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
241
+ return freqs
242
+
243
+
244
+ @amp.autocast(enabled=False, device_type="cuda")
245
+ def rope_apply(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
246
+ batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
247
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
248
+ output = []
249
+ assert len(grid_sizes) == batch, "grid_sizes must have the same length as the batch size (shape [B, 3] = [f, h, w])"
250
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
251
+ seq_len = f * h * w
252
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
253
+ seq_len, n, -1, 2))
254
+ freqs_i = torch.cat([
255
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
256
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
257
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
258
+ ], dim=-1).reshape(seq_len, 1, -1)
259
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
260
+
261
+ if seq_len_actual > seq_len:
262
+ sequence_cond_seq_length = seq_len_actual - seq_len
263
+ if sequence_cond_seq_length == seq_len:
264
+ x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(seq_len_actual - seq_len, n, -1, 2))
265
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(2)
266
+ else:
267
+ sequence_cond_compressed_index = sequence_cond_compressed_indices[i]
268
+ sequence_cond_t_length = len(sequence_cond_compressed_index)
269
+ assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
270
+ x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(sequence_cond_seq_length, n, -1, 2))
271
+ freqs_i_sequence_cond = torch.cat([
272
+ freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
273
+ freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
274
+ freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
275
+ ], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
276
+ x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(2)
277
+ x_i = torch.cat([x_i, x_i_sequence_cond])
278
+
279
+ output.append(x_i)
280
+ return torch.stack(output).float()
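An illustrative sketch (editorial note, not part of the file) of the 3D RoPE layout consumed by `rope_apply` above: the per-head channels are split into a temporal part of width d - 4*(d//6) and two spatial parts of width 2*(d//6) each, then indexed by the latent grid (f, h, w).

def _rope_usage_sketch():
    # Hypothetical demo helper; toy sizes only.
    d = 64                                       # per-head dim (dim // num_heads)
    freqs = torch.cat([
        rope_params(1024, d - 4 * (d // 6)),     # temporal frequencies
        rope_params(1024, 2 * (d // 6)),         # height frequencies
        rope_params(1024, 2 * (d // 6)),         # width frequencies
    ], dim=1)                                    # complex table of shape [1024, d // 2]
    f, h, w = 4, 8, 8
    q = torch.randn(1, f * h * w, 8, d)          # [B, L, num_heads, head_dim]
    out = rope_apply(q, torch.tensor([[f, h, w]]), freqs)
    assert out.shape == q.shape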
281
+
282
+
283
+ class WanRMSNorm(nn.Module):
284
+
285
+ def __init__(self, dim, eps=1e-5):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.eps = eps
289
+ self.weight = nn.Parameter(torch.ones(dim))
290
+
291
+ def forward(self, x):
292
+ return self._norm(x.float()).type_as(x) * self.weight
293
+
294
+ def _norm(self, x):
295
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
296
+
297
+
298
+ class WanLayerNorm(nn.LayerNorm):
299
+
300
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
301
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
302
+
303
+ def forward(self, x):
304
+ return super().forward(x.float()).type_as(x)
305
+
306
+
307
+ class WanSelfAttention(nn.Module):
308
+
309
+ def __init__(self,
310
+ dim,
311
+ num_heads,
312
+ window_size=(-1, -1),
313
+ qk_norm=True,
314
+ eps=1e-6):
315
+ assert dim % num_heads == 0
316
+ super().__init__()
317
+ self.dim = dim
318
+ self.num_heads = num_heads
319
+ self.head_dim = dim // num_heads
320
+ self.window_size = window_size
321
+ self.qk_norm = qk_norm
322
+ self.eps = eps
323
+
324
+ self.q = nn.Linear(dim, dim)
325
+ self.k = nn.Linear(dim, dim)
326
+ self.v = nn.Linear(dim, dim)
327
+ self.o = nn.Linear(dim, dim)
328
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
329
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
330
+ self.visualize_attention = False
331
+
332
+ def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
333
+ """
334
+ Args:
335
+ x: [B, L, C].
336
+ seq_lens: [B].
337
+ grid_sizes: [B, 3=[f, h, w]].
338
+ freqs: [L, 2].
339
+ sequence_cond_compressed_indices: [B, T_sequence_cond].
340
+
341
+ `f` in `grid_sizes` can be less than the actual seq_lens (L),
342
+ which indicates full in-context condition (when L=2*f) or
343
+ sparse in-context condition (when `f` < L < 2*f and `sequence_cond_compressed_indices` is not None) is used.
344
+ """
345
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
346
+
347
+ def qkv_fn(x):
348
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
349
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
350
+ v = self.v(x).view(b, s, n, d)
351
+ return q, k, v
352
+
353
+ q, k, v = qkv_fn(x)
354
+
355
+ q_rope = rope_apply(q, grid_sizes, freqs, sequence_cond_compressed_indices)
356
+ k_rope = rope_apply(k, grid_sizes, freqs, sequence_cond_compressed_indices)
357
+
358
+ if self.visualize_attention:
359
+ with torch.no_grad():
360
+ self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope) # CPU tensor of [S, S]
361
+ self._last_grid_sizes = grid_sizes
362
+ self._last_seq_lens = seq_lens
363
+
364
+ x = flash_attention(
365
+ q=q_rope,
366
+ k=k_rope,
367
+ v=v,
368
+ k_lens=None,
369
+ window_size=self.window_size)
370
+
371
+ # output
372
+ x = x.flatten(2)
373
+ x = self.o(x)
374
+ return x
375
+
376
+ def _compute_attention_for_visualization(self, q, k):
377
+ """Compute attention maps for visualization purposes"""
378
+ # b, _, n, d = q.shape
379
+ print("Computing attention maps for visualization")
380
+ # Reshape for attention computation
381
+ q = q.permute(0, 2, 1, 3) # [b, n, s, d]
382
+ k = k.permute(0, 2, 1, 3) # [b, n, s, d]
383
+ # query: b, n, s, d
384
+ print("q.shape=", q.shape)
385
+ print("k.shape=", k.shape)
386
+ attention_probs_list = []
387
+ for i in range(0, q.shape[1], 20):
388
+ print(f"Computing attention for head {i} to {i+20}")
389
+ query_attention = q[-1][i : i + 20]
390
+ key_attention = k[-1][i : i + 20]
391
+ identity_matrix = torch.eye(
392
+ query_attention.shape[-2],
393
+ device=query_attention.device,
394
+ dtype=query_attention.dtype,
395
+ ) # shape=[s]
396
+ attention_probs_temp = torch.nn.functional.scaled_dot_product_attention(
397
+ query_attention,
398
+ key_attention,
399
+ identity_matrix,
400
+ attn_mask=None,
401
+ dropout_p=0.0,
402
+ is_causal=False,
403
+ )
404
+ attention_probs_list.append(attention_probs_temp.detach().cpu())
405
+ del (
406
+ query_attention,
407
+ key_attention,
408
+ identity_matrix,
409
+ attention_probs_temp,
410
+ )
411
+ attention_probs = torch.mean(torch.cat(attention_probs_list), dim=0).float().numpy()
412
+ print("Attention maps computed. Shape=", attention_probs.shape)
413
+ # Only keep attention maps, don't compute the output
414
+ return attention_probs # [s, s]
415
+
416
+
417
+ class WanT2VCrossAttention(WanSelfAttention):
418
+
419
+ def forward(self, x, context, context_lens):
420
+ """
421
+ x: [B, L1, C].
422
+ context: [B, L2, C].
423
+ context_lens: [B].
424
+ """
425
+ b, n, d = x.size(0), self.num_heads, self.head_dim
426
+
427
+ # compute query, key, value
428
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
429
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
430
+ v = self.v(context).view(b, -1, n, d)
431
+
432
+ # compute attention
433
+ x = flash_attention(q, k, v, k_lens=context_lens)
434
+
435
+ # output
436
+ x = x.flatten(2)
437
+ x = self.o(x)
438
+ return x
439
+
440
+
441
+ class WanI2VCrossAttention(WanSelfAttention):
442
+
443
+ def __init__(self,
444
+ dim,
445
+ num_heads,
446
+ window_size=(-1, -1),
447
+ qk_norm=True,
448
+ eps=1e-6):
449
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
450
+
451
+ self.k_img = nn.Linear(dim, dim)
452
+ self.v_img = nn.Linear(dim, dim)
453
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
454
+ self.norm_k_img = WanRMSNorm(
455
+ dim, eps=eps) if qk_norm else nn.Identity()
456
+
457
+ def forward(self, x, context, context_lens):
458
+ """
459
+ x: [B, L1, C].
460
+ context: [B, L2, C].
461
+ context_lens: [B].
462
+ """
463
+ context_img = context[:, :257]
464
+ context = context[:, 257:]
465
+ b, n, d = x.size(0), self.num_heads, self.head_dim
466
+
467
+ # compute query, key, value
468
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
469
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
470
+ v = self.v(context).view(b, -1, n, d)
471
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
472
+ v_img = self.v_img(context_img).view(b, -1, n, d)
473
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
474
+ # compute attention
475
+ x = flash_attention(q, k, v, k_lens=context_lens)
476
+
477
+ # output
478
+ x = x.flatten(2)
479
+ img_x = img_x.flatten(2)
480
+ x = x + img_x
481
+ x = self.o(x)
482
+ return x
483
+
484
+
485
+ WANX_CROSSATTENTION_CLASSES = {
486
+ 't2v_cross_attn': WanT2VCrossAttention,
487
+ 'i2v_cross_attn': WanI2VCrossAttention,
488
+ }
489
+
490
+
491
+ class WanAttentionBlock(nn.Module):
492
+
493
+ def __init__(self,
494
+ cross_attn_type,
495
+ dim,
496
+ ffn_dim,
497
+ num_heads,
498
+ window_size=(-1, -1),
499
+ qk_norm=True,
500
+ cross_attn_norm=False,
501
+ eps=1e-6,
502
+ use_local_lora=False,
503
+ use_dera=False,
504
+ dera_rank=None,
505
+ use_dera_spatial=True,
506
+ use_dera_temporal=True):
507
+ super().__init__()
508
+ self.dim = dim
509
+ self.ffn_dim = ffn_dim
510
+ self.num_heads = num_heads
511
+ self.window_size = window_size
512
+ self.qk_norm = qk_norm
513
+ self.cross_attn_norm = cross_attn_norm
514
+ self.eps = eps
515
+
516
+ # layers
517
+ self.norm1 = WanLayerNorm(dim, eps)
518
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
519
+ self.norm3 = WanLayerNorm(
520
+ dim, eps,
521
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
522
+ self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
523
+ dim, num_heads, (-1, -1), qk_norm, eps)
524
+ self.norm2 = WanLayerNorm(dim, eps)
525
+ self.ffn = nn.Sequential(
526
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
527
+ nn.Linear(ffn_dim, dim))
528
+
529
+ # modulation
530
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
531
+
532
+ self.use_local_lora = use_local_lora
533
+ if use_local_lora:
534
+ from .local_lora import LocalLoRA
535
+ self.local_lora = LocalLoRA(dim=dim, rank=64, kernel_size=(3, 3), stride=(1, 1))
536
+
537
+ self.use_dera = use_dera
538
+ if use_dera:
539
+ from .dera import DeRA
540
+ self.dera = DeRA(dim, rank=dera_rank, use_spatial=use_dera_spatial, use_temporal=use_dera_temporal)
541
+
542
+ def forward(
543
+ self,
544
+ x,
545
+ e,
546
+ seq_lens,
547
+ grid_sizes,
548
+ freqs,
549
+ context,
550
+ context_lens,
551
+ sequence_cond_compressed_indices,
552
+ dera_freqs=None
553
+ ):
554
+ assert e.dtype == torch.float32
555
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
556
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
557
+ assert e[0].dtype == torch.float32
558
+
559
+ # self-attention
560
+ x_self_attn_input = self.norm1(x).float() * (1 + e[1]) + e[0]
561
+ y = self.self_attn(x_self_attn_input, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
562
+ if self.use_local_lora:
563
+ y = y + self.local_lora(x_self_attn_input, grid_sizes)
564
+
565
+ if self.use_dera:
566
+ y = y + self.dera(x_self_attn_input, seq_lens, grid_sizes, dera_freqs, sequence_cond_compressed_indices)
567
+
568
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
569
+ x = x + y * e[2]
570
+
571
+ def cross_attn_ffn(x, context, context_lens, e):
572
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
573
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
574
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
575
+ x = x + y * e[5]
576
+ return x
577
+
578
+ x = cross_attn_ffn(x, context, context_lens, e)
579
+ return x
580
+
581
+
582
+ class Head(nn.Module):
583
+
584
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
585
+ super().__init__()
586
+ self.dim = dim
587
+ self.out_dim = out_dim
588
+ self.patch_size = patch_size
589
+ self.eps = eps
590
+
591
+ # layers
592
+ out_dim = math.prod(patch_size) * out_dim
593
+ self.norm = WanLayerNorm(dim, eps)
594
+ self.head = nn.Linear(dim, out_dim)
595
+
596
+ # modulation
597
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
598
+
599
+ def forward(self, x, e):
600
+ assert e.dtype == torch.float32
601
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
602
+ e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
603
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
604
+ return x
605
+
606
+
607
+ class MLPProj(torch.nn.Module):
608
+
609
+ def __init__(self, in_dim, out_dim):
610
+ super().__init__()
611
+
612
+ self.proj = torch.nn.Sequential(
613
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
614
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
615
+ torch.nn.LayerNorm(out_dim))
616
+
617
+ def forward(self, image_embeds):
618
+ clip_extra_context_tokens = self.proj(image_embeds)
619
+ return clip_extra_context_tokens
620
+
621
+
622
+ class WanModel(nn.Module):
623
+
624
+ def __init__(self,
625
+ model_type='t2v',
626
+ patch_size=(1, 2, 2),
627
+ text_len=512,
628
+ in_dim=16,
629
+ dim=2048,
630
+ ffn_dim=8192,
631
+ freq_dim=256,
632
+ text_dim=4096,
633
+ out_dim=16,
634
+ num_heads=16,
635
+ num_layers=32,
636
+ window_size=(-1, -1),
637
+ qk_norm=True,
638
+ cross_attn_norm=False,
639
+ eps=1e-6,
640
+ use_local_lora=False,
641
+ use_dera=False,
642
+ dera_rank=None,
643
+ use_dera_spatial=True,
644
+ use_dera_temporal=True,
645
+ use_sequence_cond=False,
646
+ sequence_cond_in_dim=None,
647
+ sequence_cond_mode=None,
648
+ use_channel_cond=False,
649
+ channel_cond_in_dim=None,
650
+ use_sequence_cond_position_aware_residual=False,
651
+ use_sequence_cond_loss=False
652
+ ):
653
+ super().__init__()
654
+
655
+ assert model_type in ['t2v', 'i2v']
656
+ self.model_type = model_type
657
+
658
+ self.patch_size = patch_size
659
+ self.text_len = text_len
660
+ self.in_dim = in_dim
661
+ self.dim = dim
662
+ self.ffn_dim = ffn_dim
663
+ self.freq_dim = freq_dim
664
+ self.text_dim = text_dim
665
+ self.out_dim = out_dim
666
+ self.num_heads = num_heads
667
+ self.num_layers = num_layers
668
+ self.window_size = window_size
669
+ self.qk_norm = qk_norm
670
+ self.cross_attn_norm = cross_attn_norm
671
+ self.eps = eps
672
+
673
+ self.use_local_lora = use_local_lora
674
+ self.use_dera = use_dera
675
+
676
+ # embeddings
677
+ self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
678
+ self.text_embedding = nn.Sequential(
679
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
680
+ nn.Linear(dim, dim))
681
+
682
+ self.time_embedding = nn.Sequential(
683
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
684
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
685
+
686
+ # blocks
687
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
688
+ self.blocks = nn.ModuleList([
689
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
690
+ window_size, qk_norm, cross_attn_norm, eps, use_local_lora=use_local_lora,
691
+ use_dera=use_dera, dera_rank=dera_rank, use_dera_spatial=use_dera_spatial, use_dera_temporal=use_dera_temporal)
692
+ for _ in range(num_layers)
693
+ ])
694
+
695
+ # head
696
+ self.head = Head(dim, out_dim, patch_size, eps)
697
+
698
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
699
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
700
+ d = dim // num_heads
701
+ self.freqs = torch.cat([
702
+ rope_params(1024, d - 4 * (d // 6)),
703
+ rope_params(1024, 2 * (d // 6)),
704
+ rope_params(1024, 2 * (d // 6))
705
+ ], dim=1)
706
+
707
+ if self.use_dera:
708
+ dera_d = dera_rank // 4 # (18)
709
+ self.dera_freqs = torch.cat([
710
+ rope_params(1024, dera_d - 4 * (dera_d // 6)),
711
+ rope_params(1024, 2 * (dera_d // 6)),
712
+ rope_params(1024, 2 * (dera_d // 6))
713
+ ], dim=1)
714
+ else:
715
+ self.dera_freqs = None
716
+
717
+ if model_type == 'i2v':
718
+ self.img_emb = MLPProj(1280, dim)
719
+
720
+ self.init_weights()
721
+
722
+ self.use_sequence_cond = use_sequence_cond
723
+ self.sequence_cond_in_dim = sequence_cond_in_dim
724
+ self.sequence_cond_mode = sequence_cond_mode
725
+ if use_sequence_cond:
726
+ assert sequence_cond_in_dim is not None, "`sequence_cond_in_dim` must be provided when `use_sequence_cond` is True"
727
+ self.sequence_cond_patch_embedding = nn.Conv3d(sequence_cond_in_dim, dim, kernel_size=patch_size, stride=patch_size)
728
+ self.sequence_cond_identifier = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5)
729
+
730
+ self.use_channel_cond = use_channel_cond
731
+ self.channel_cond_in_dim = channel_cond_in_dim
732
+ if use_channel_cond:
733
+ assert channel_cond_in_dim is not None, "`channel_cond_in_dim` must be provided when `use_channel_cond` is True"
734
+ self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
735
+ if use_sequence_cond_position_aware_residual:
736
+ self.sequence_cond_residual_proj = nn.Linear(dim, dim, bias=False)
737
+ self.sequence_cond_residual_proj.weight.data.zero_()
738
+
739
+ self.use_sequence_cond_loss = use_sequence_cond_loss
740
+ if self.use_sequence_cond_loss:
741
+ self.sequence_latent_to_cond_proj = nn.Linear(dim, dim, bias=False)
742
+ self.sequence_latent_to_cond_proj.weight.data.zero_()
743
+ self.head_sequence_cond_out = nn.Linear(dim, math.prod(patch_size) * out_dim)
744
+
745
+ def copy_sequence_cond_patch_embedding_weights(self):
746
+ size_patch_embedding = self.patch_embedding.weight.size(1)
747
+ size_sequence_cond_patch_embedding = self.sequence_cond_patch_embedding.weight.size(1)
748
+ self.sequence_cond_patch_embedding.weight.data = self.patch_embedding.weight.data[:, size_patch_embedding - size_sequence_cond_patch_embedding:, :, :, :].clone()
749
+ if self.patch_embedding.bias is not None:
750
+ self.sequence_cond_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
751
+
752
+ def copy_patch_embedding_weights_for_channel_cond(self):
753
+ original_patch_in_channels = self.patch_embedding.in_channels
754
+ new_patch_embedding = nn.Conv3d(in_channels=original_patch_in_channels + self.channel_cond_in_dim,
755
+ out_channels=self.dim, kernel_size=self.patch_size, stride=self.patch_size)
756
+ new_patch_embedding.weight.data[:, :original_patch_in_channels, :, :, :] = self.patch_embedding.weight.data.clone()
757
+ if self.patch_embedding.bias is not None:
758
+ new_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
759
+ del self.patch_embedding
760
+ self.patch_embedding = new_patch_embedding
761
+
762
+ def forward(
763
+ self,
764
+ x,
765
+ timestep,
766
+ context,
767
+ seq_len,
768
+ clip_fea=None,
769
+ y=None,
770
+ use_gradient_checkpointing=False,
771
+ sequence_cond=None,
772
+ sequence_cond_compressed_indices=None,
773
+ channel_cond=None,
774
+ sequence_cond_residual_scale=1.0,
775
+ **kwargs,
776
+ ):
777
+ """
778
+ x: A list of videos each with shape [C, T, H, W].
779
+ timestep: [B].
780
+ context: A list of text embeddings each with shape [L, C].
781
+ sequence_cond: A list of conditional frames each with shape [C, T_sequence_cond, H, W].
782
+ sequence_cond_compressed_indices: [B, T_sequence_cond] Indices for any additional conditional information, where T_sequence_cond < T. For sparse mode only.
783
+
784
+
785
+ Note:
786
+ sequence_cond will be injected into the model as an additional input sequence, i.e., along the sequence dimension.
787
+ channel_cond will be injected into the model along the input's channel dimension.
788
+
789
+ Examples:
790
+ 1) for extra cond case:
791
+ # given x: [B, C, T, H, W] ----> [B, L=T*H*W, C] --patch_embedding--> [B, L, D]
792
+ # sequence_cond: [B, C_sequence_cond, T_sequence_cond, H, W] ----> [B, L_sequence_cond=T_sequence_cond*H*W, C_sequence_cond] --sequence_cond_embedding--> [B, L_sequence_cond, D]
793
+ x = torch.concat([x, sequence_cond], dim=2) # Concat on sequence dimension after patch/extra cond embedding
794
+ # after concat, x: [B, L+L_sequence_cond, D]
795
+ 2) for channel cond case:
796
+ given x: [B, C, T, H, W]
797
+ channel_cond: [B, C_CHANNEL_COND, T, H, W]
798
+ x = torch.concat([x, channel_cond], dim=1) # Concat on channel dimension before patch/extra cond embedding
799
+ # x: [B, C + C_CHANNEL_COND, T, H, W] --patch_embedding(requires param copy and tuning)--> [B, L=T*H*W, D]
800
+ """
801
+ if self.model_type == 'i2v':
802
+ assert clip_fea is not None and y is not None
803
+ # params
804
+ device = x[0].device
805
+ if self.freqs.device != device:
806
+ self.freqs = self.freqs.to(device)
807
+ if self.dera_freqs is not None and self.dera_freqs.device != device:
808
+ self.dera_freqs = self.dera_freqs.to(device)
809
+
810
+ if y is not None:
811
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
812
+
813
+ if channel_cond is not None:
814
+ assert self.use_channel_cond, "forward argument `channel_cond` is provided but model property `self.use_channel_cond` is False"
815
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, channel_cond)]
816
+
817
+ # embeddings
818
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
819
+ grid_sizes = torch.stack(
820
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
821
+ x = [u.flatten(2).transpose(1, 2) for u in x]
822
+ x = torch.cat(x, dim=0)
823
+
824
+ if sequence_cond is not None:
825
+ assert self.use_sequence_cond, "forward argument `sequence_cond` is provided but model property `self.use_sequence_cond` is False"
826
+ sequence_cond = [self.sequence_cond_patch_embedding(u.unsqueeze(0)) for u in sequence_cond]
827
+ sequence_cond = [u.flatten(2).transpose(1, 2) + self.sequence_cond_identifier for u in sequence_cond]
828
+ sequence_cond = torch.concat(sequence_cond, dim=0)
829
+
830
+ x = torch.concat([x, sequence_cond], dim=1)
831
+
832
+ actual_seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
833
+
834
+ # time embeddings
835
+ with amp.autocast(dtype=torch.float32, device_type="cuda"):
836
+ e = self.time_embedding(
837
+ sinusoidal_embedding_1d(self.freq_dim, timestep).float())
838
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
839
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
840
+
841
+ # context
842
+ context_lens = None
843
+ context = self.text_embedding(
844
+ torch.stack([
845
+ torch.cat(
846
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
847
+ for u in context
848
+ ]))
849
+
850
+ if clip_fea is not None:
851
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
852
+ context = torch.concat([context_clip, context], dim=1)
853
+
854
+ # arguments
855
+ kwargs = dict(e=e0, seq_lens=actual_seq_lens, grid_sizes=grid_sizes,
856
+ freqs=self.freqs, context=context, context_lens=context_lens,
857
+ sequence_cond_compressed_indices=sequence_cond_compressed_indices, dera_freqs=self.dera_freqs)
858
+
859
+ def create_custom_forward(module):
860
+ def custom_forward(*inputs, **kwargs):
861
+ return module(*inputs, **kwargs)
862
+ return custom_forward
863
+
864
+ for block_idx, block in enumerate(self.blocks):
865
+ if self.training and use_gradient_checkpointing:
866
+ x = torch.utils.checkpoint.checkpoint(
867
+ create_custom_forward(block),
868
+ x, **kwargs,
869
+ use_reentrant=False,
870
+ )
871
+ else:
872
+ x = block(x, **kwargs)
873
+
874
+ if self.use_sequence_cond_loss and block_idx == len(self.blocks) - 3:
875
+ # In this function, the context length will be extended from (N+C) to 2N, where C is the length of the sparse sequence cond.
876
+ x_ori = x[:, :seq_len, :]
877
+ x_ori_projected = self.sequence_latent_to_cond_proj(x_ori)
878
+ x_seq_cond = x[:, seq_len:, :]
879
+ seq_cond_length = len(sequence_cond_compressed_indices[0])
880
+ x_ori_projected = rearrange(x_ori_projected, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
881
+ x_seq_cond = rearrange(x_seq_cond, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
882
+ x_ori_projected[:, :, sequence_cond_compressed_indices[0], :, :] += x_seq_cond
883
+ x_ori_projected = rearrange(x_ori_projected, 'b c t h w -> b (t h w) c')
884
+ x = torch.concat([x_ori, x_ori_projected], dim=1)
885
+ # Let the later blocks generate sketches at the full sequence length
886
+
887
+ if self.use_sequence_cond_position_aware_residual and block_idx < len(self.blocks) - 1:
888
+ # Apply the sequence condition position-aware residual for all blocks except the last one
889
+ x_ori = x[:, :seq_len, :]
890
+ x_seq_cond = x[:, seq_len:, :]
891
+ x_seq_cond_projected = self.sequence_cond_residual_proj(x_seq_cond)
892
+ assert x_ori.shape[0] == 1, "Only support batch size 1 for `sequence_cond_position_aware_residual`."
893
+ seq_cond_length = len(sequence_cond_compressed_indices[0])
894
+ x_ori = rearrange(x_ori, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
895
+ x_seq_cond_projected = rearrange(x_seq_cond_projected, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
896
+
897
+ x_ori[:, :, sequence_cond_compressed_indices[0], :, :] = x_ori[:, :, sequence_cond_compressed_indices[0], :, :] + x_seq_cond_projected * sequence_cond_residual_scale
898
+ x_ori = rearrange(x_ori, 'b c t h w -> b (t h w) c')
899
+ x = torch.concat([x_ori, x_seq_cond], dim=1)
900
+
901
+ if sequence_cond is not None:
902
+ if self.use_sequence_cond_loss:
903
+ sequence_cond_out = x[:, seq_len:, :]
904
+ sequence_cond_out = self.unpatchify(sequence_cond_out, grid_sizes) # sequence_cond_grid_sizes
905
+ sequence_cond_out = torch.stack(sequence_cond_out).float() # b, c, t, h, w
906
+ else:
907
+ sequence_cond_out = None
908
+ x = x[:, :seq_len, :]
909
+ # head
910
+ x = self.head(x, e)
911
+
912
+ # unpatchify
913
+ x = self.unpatchify(x, grid_sizes)
914
+ x = torch.stack(x).float()
915
+ if sequence_cond is not None and self.use_sequence_cond_loss:
916
+ return x, sequence_cond_out
917
+ return x
918
+
919
+ def unpatchify(self, x, grid_sizes):
920
+ c = self.out_dim
921
+ out = []
922
+ for u, v in zip(x, grid_sizes.tolist()):
923
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
924
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
925
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
926
+ out.append(u)
927
+ return out
928
+
929
+ def init_weights(self):
930
+ for m in self.modules():
931
+ if isinstance(m, nn.Linear):
932
+ nn.init.xavier_uniform_(m.weight)
933
+ if m.bias is not None:
934
+ nn.init.zeros_(m.bias)
935
+
936
+ # init embeddings
937
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
938
+ for m in self.text_embedding.modules():
939
+ if isinstance(m, nn.Linear):
940
+ nn.init.normal_(m.weight, std=.02)
941
+ for m in self.time_embedding.modules():
942
+ if isinstance(m, nn.Linear):
943
+ nn.init.normal_(m.weight, std=.02)
944
+
945
+ # init output layer
946
+ nn.init.zeros_(self.head.head.weight)
947
+
948
+ @staticmethod
949
+ def state_dict_converter():
950
+ return WanModelStateDictConverter()
951
+
952
+
953
+ class WanModelStateDictConverter:
954
+ def __init__(self):
955
+ pass
956
+
957
+ def from_diffusers(self, state_dict):
958
+ rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
959
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
960
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
961
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
962
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
963
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
964
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
965
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
966
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
967
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
968
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
969
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
970
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
971
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
972
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
973
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
974
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
975
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
976
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
977
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
978
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
979
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
980
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
981
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
982
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
983
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
984
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
985
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
986
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
987
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
988
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
989
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
990
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
991
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
992
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
993
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
994
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
995
+ "patch_embedding.bias": "patch_embedding.bias",
996
+ "patch_embedding.weight": "patch_embedding.weight",
997
+ "scale_shift_table": "head.modulation",
998
+ "proj_out.bias": "head.head.bias",
999
+ "proj_out.weight": "head.head.weight",
1000
+ }
1001
+ state_dict_ = {}
1002
+ for name, param in state_dict.items():
1003
+ if name in rename_dict:
1004
+ state_dict_[rename_dict[name]] = param
1005
+ else:
1006
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
1007
+ if name_ in rename_dict:
1008
+ name_ = rename_dict[name_]
1009
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
1010
+ state_dict_[name_] = param
1011
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
1012
+ config = {
1013
+ "model_type": "t2v",
1014
+ "patch_size": (1, 2, 2),
1015
+ "text_len": 512,
1016
+ "in_dim": 16,
1017
+ "dim": 5120,
1018
+ "ffn_dim": 13824,
1019
+ "freq_dim": 256,
1020
+ "text_dim": 4096,
1021
+ "out_dim": 16,
1022
+ "num_heads": 40,
1023
+ "num_layers": 40,
1024
+ "window_size": (-1, -1),
1025
+ "qk_norm": True,
1026
+ "cross_attn_norm": True,
1027
+ "eps": 1e-6,
1028
+ }
1029
+ else:
1030
+ config = {}
1031
+ return state_dict_, config
1032
+
1033
+ def from_civitai(self, state_dict):
1034
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
1035
+ config = {
1036
+ "model_type": "t2v",
1037
+ "patch_size": (1, 2, 2),
1038
+ "text_len": 512,
1039
+ "in_dim": 16,
1040
+ "dim": 1536,
1041
+ "ffn_dim": 8960,
1042
+ "freq_dim": 256,
1043
+ "text_dim": 4096,
1044
+ "out_dim": 16,
1045
+ "num_heads": 12,
1046
+ "num_layers": 30,
1047
+ "window_size": (-1, -1),
1048
+ "qk_norm": True,
1049
+ "cross_attn_norm": True,
1050
+ "eps": 1e-6,
1051
+ }
1052
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
1053
+ config = {
1054
+ "model_type": "t2v",
1055
+ "patch_size": (1, 2, 2),
1056
+ "text_len": 512,
1057
+ "in_dim": 16,
1058
+ "dim": 5120,
1059
+ "ffn_dim": 13824,
1060
+ "freq_dim": 256,
1061
+ "text_dim": 4096,
1062
+ "out_dim": 16,
1063
+ "num_heads": 40,
1064
+ "num_layers": 40,
1065
+ "window_size": (-1, -1),
1066
+ "qk_norm": True,
1067
+ "cross_attn_norm": True,
1068
+ "eps": 1e-6,
1069
+ }
1070
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
1071
+ config = {
1072
+ "model_type": "i2v",
1073
+ "patch_size": (1, 2, 2),
1074
+ "text_len": 512,
1075
+ "in_dim": 36,
1076
+ "dim": 5120,
1077
+ "ffn_dim": 13824,
1078
+ "freq_dim": 256,
1079
+ "text_dim": 4096,
1080
+ "out_dim": 16,
1081
+ "num_heads": 40,
1082
+ "num_layers": 40,
1083
+ "window_size": (-1, -1),
1084
+ "qk_norm": True,
1085
+ "cross_attn_norm": True,
1086
+ "eps": 1e-6,
1087
+ }
1088
+ else:
1089
+ config = {}
1090
+ return state_dict, config
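For reference, a converter like the one above is normally consumed while loading a checkpoint: the raw state dict is renamed to the internal parameter layout, and the returned config (selected by hashing the key set) drives model construction. The sketch below is illustrative only; `load_wan_dit` and `model_cls` are hypothetical names, not part of this repository.

```python
import torch

def load_wan_dit(checkpoint_path, converter, model_cls, device="cpu"):
    # Raw checkpoint in either diffusers- or civitai-style naming.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Rename parameters and infer the architecture config from the key hash
    # (see from_diffusers / from_civitai above).
    state_dict, config = converter.from_civitai(state_dict)

    # Build the model with the detected config and load the converted weights.
    model = model_cls(**config)
    model.load_state_dict(state_dict)
    return model.to(device)
```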
model/image_encoder.py ADDED
@@ -0,0 +1,903 @@
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from .dit import flash_attention
12
+
13
+
14
+ class SelfAttention(nn.Module):
15
+
16
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
17
+ assert dim % num_heads == 0
18
+ super().__init__()
19
+ self.dim = dim
20
+ self.num_heads = num_heads
21
+ self.head_dim = dim // num_heads
22
+ self.eps = eps
23
+
24
+ # layers
25
+ self.q = nn.Linear(dim, dim)
26
+ self.k = nn.Linear(dim, dim)
27
+ self.v = nn.Linear(dim, dim)
28
+ self.o = nn.Linear(dim, dim)
29
+ self.dropout = nn.Dropout(dropout)
30
+
31
+ def forward(self, x, mask):
32
+ """
33
+ x: [B, L, C].
34
+ """
35
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
36
+
37
+ # compute query, key, value
38
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
39
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
40
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
41
+
42
+ # compute attention
43
+ p = self.dropout.p if self.training else 0.0
44
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
45
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
46
+
47
+ # output
48
+ x = self.o(x)
49
+ x = self.dropout(x)
50
+ return x
51
+
52
+
53
+ class AttentionBlock(nn.Module):
54
+
55
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.num_heads = num_heads
59
+ self.post_norm = post_norm
60
+ self.eps = eps
61
+
62
+ # layers
63
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
64
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
65
+ self.ffn = nn.Sequential(
66
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
67
+ nn.Dropout(dropout))
68
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
69
+
70
+ def forward(self, x, mask):
71
+ if self.post_norm:
72
+ x = self.norm1(x + self.attn(x, mask))
73
+ x = self.norm2(x + self.ffn(x))
74
+ else:
75
+ x = x + self.attn(self.norm1(x), mask)
76
+ x = x + self.ffn(self.norm2(x))
77
+ return x
78
+
79
+
80
+ class XLMRoberta(nn.Module):
81
+ """
82
+ XLMRobertaModel with no pooler and no LM head.
83
+ """
84
+
85
+ def __init__(self,
86
+ vocab_size=250002,
87
+ max_seq_len=514,
88
+ type_size=1,
89
+ pad_id=1,
90
+ dim=1024,
91
+ num_heads=16,
92
+ num_layers=24,
93
+ post_norm=True,
94
+ dropout=0.1,
95
+ eps=1e-5):
96
+ super().__init__()
97
+ self.vocab_size = vocab_size
98
+ self.max_seq_len = max_seq_len
99
+ self.type_size = type_size
100
+ self.pad_id = pad_id
101
+ self.dim = dim
102
+ self.num_heads = num_heads
103
+ self.num_layers = num_layers
104
+ self.post_norm = post_norm
105
+ self.eps = eps
106
+
107
+ # embeddings
108
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
109
+ self.type_embedding = nn.Embedding(type_size, dim)
110
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
111
+ self.dropout = nn.Dropout(dropout)
112
+
113
+ # blocks
114
+ self.blocks = nn.ModuleList([
115
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
116
+ for _ in range(num_layers)
117
+ ])
118
+
119
+ # norm layer
120
+ self.norm = nn.LayerNorm(dim, eps=eps)
121
+
122
+ def forward(self, ids):
123
+ """
124
+ ids: [B, L] of torch.LongTensor.
125
+ """
126
+ b, s = ids.shape
127
+ mask = ids.ne(self.pad_id).long()
128
+
129
+ # embeddings
130
+ x = self.token_embedding(ids) + \
131
+ self.type_embedding(torch.zeros_like(ids)) + \
132
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
133
+ if self.post_norm:
134
+ x = self.norm(x)
135
+ x = self.dropout(x)
136
+
137
+ # blocks
138
+ mask = torch.where(
139
+ mask.view(b, 1, 1, s).gt(0), 0.0,
140
+ torch.finfo(x.dtype).min)
141
+ for block in self.blocks:
142
+ x = block(x, mask)
143
+
144
+ # output
145
+ if not self.post_norm:
146
+ x = self.norm(x)
147
+ return x
148
+
149
+
150
+ def xlm_roberta_large(pretrained=False,
151
+ return_tokenizer=False,
152
+ device='cpu',
153
+ **kwargs):
154
+ """
155
+ XLMRobertaLarge adapted from Huggingface.
156
+ """
157
+ # params
158
+ cfg = dict(
159
+ vocab_size=250002,
160
+ max_seq_len=514,
161
+ type_size=1,
162
+ pad_id=1,
163
+ dim=1024,
164
+ num_heads=16,
165
+ num_layers=24,
166
+ post_norm=True,
167
+ dropout=0.1,
168
+ eps=1e-5)
169
+ cfg.update(**kwargs)
170
+
171
+ # init model
172
+ if pretrained:
173
+ from sora import DOWNLOAD_TO_CACHE
174
+
175
+ # init a meta model
176
+ with torch.device('meta'):
177
+ model = XLMRoberta(**cfg)
178
+
179
+ # load checkpoint
180
+ model.load_state_dict(
181
+ torch.load(
182
+ DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
183
+ map_location=device),
184
+ assign=True)
185
+ else:
186
+ # init a model on device
187
+ with torch.device(device):
188
+ model = XLMRoberta(**cfg)
189
+
190
+ # init tokenizer
191
+ if return_tokenizer:
192
+ from sora.data import HuggingfaceTokenizer
193
+ tokenizer = HuggingfaceTokenizer(
194
+ name='xlm-roberta-large',
195
+ seq_len=model.text_len,
196
+ clean='whitespace')
197
+ return model, tokenizer
198
+ else:
199
+ return model
200
+
201
+
202
+
203
+ def pos_interpolate(pos, seq_len):
204
+ if pos.size(1) == seq_len:
205
+ return pos
206
+ else:
207
+ src_grid = int(math.sqrt(pos.size(1)))
208
+ tar_grid = int(math.sqrt(seq_len))
209
+ n = pos.size(1) - src_grid * src_grid
210
+ return torch.cat([
211
+ pos[:, :n],
212
+ F.interpolate(
213
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
214
+ 0, 3, 1, 2),
215
+ size=(tar_grid, tar_grid),
216
+ mode='bicubic',
217
+ align_corners=False).flatten(2).transpose(1, 2)
218
+ ],
219
+ dim=1)
220
+
221
+
222
+ class QuickGELU(nn.Module):
223
+
224
+ def forward(self, x):
225
+ return x * torch.sigmoid(1.702 * x)
226
+
227
+
228
+ class LayerNorm(nn.LayerNorm):
229
+
230
+ def forward(self, x):
231
+ return super().forward(x.float()).type_as(x)
232
+
233
+
234
+ class SelfAttention(nn.Module):
235
+
236
+ def __init__(self,
237
+ dim,
238
+ num_heads,
239
+ causal=False,
240
+ attn_dropout=0.0,
241
+ proj_dropout=0.0):
242
+ assert dim % num_heads == 0
243
+ super().__init__()
244
+ self.dim = dim
245
+ self.num_heads = num_heads
246
+ self.head_dim = dim // num_heads
247
+ self.causal = causal
248
+ self.attn_dropout = attn_dropout
249
+ self.proj_dropout = proj_dropout
250
+
251
+ # layers
252
+ self.to_qkv = nn.Linear(dim, dim * 3)
253
+ self.proj = nn.Linear(dim, dim)
254
+
255
+ def forward(self, x):
256
+ """
257
+ x: [B, L, C].
258
+ """
259
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
260
+
261
+ # compute query, key, value
262
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
263
+
264
+ # compute attention
265
+ p = self.attn_dropout if self.training else 0.0
266
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
267
+ x = x.reshape(b, s, c)
268
+
269
+ # output
270
+ x = self.proj(x)
271
+ x = F.dropout(x, self.proj_dropout, self.training)
272
+ return x
273
+
274
+
275
+ class SwiGLU(nn.Module):
276
+
277
+ def __init__(self, dim, mid_dim):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.mid_dim = mid_dim
281
+
282
+ # layers
283
+ self.fc1 = nn.Linear(dim, mid_dim)
284
+ self.fc2 = nn.Linear(dim, mid_dim)
285
+ self.fc3 = nn.Linear(mid_dim, dim)
286
+
287
+ def forward(self, x):
288
+ x = F.silu(self.fc1(x)) * self.fc2(x)
289
+ x = self.fc3(x)
290
+ return x
291
+
292
+
293
+ class AttentionBlock(nn.Module):
294
+
295
+ def __init__(self,
296
+ dim,
297
+ mlp_ratio,
298
+ num_heads,
299
+ post_norm=False,
300
+ causal=False,
301
+ activation='quick_gelu',
302
+ attn_dropout=0.0,
303
+ proj_dropout=0.0,
304
+ norm_eps=1e-5):
305
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.mlp_ratio = mlp_ratio
309
+ self.num_heads = num_heads
310
+ self.post_norm = post_norm
311
+ self.causal = causal
312
+ self.norm_eps = norm_eps
313
+
314
+ # layers
315
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
316
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
317
+ proj_dropout)
318
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
319
+ if activation == 'swi_glu':
320
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
321
+ else:
322
+ self.mlp = nn.Sequential(
323
+ nn.Linear(dim, int(dim * mlp_ratio)),
324
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
325
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
326
+
327
+ def forward(self, x):
328
+ if self.post_norm:
329
+ x = x + self.norm1(self.attn(x))
330
+ x = x + self.norm2(self.mlp(x))
331
+ else:
332
+ x = x + self.attn(self.norm1(x))
333
+ x = x + self.mlp(self.norm2(x))
334
+ return x
335
+
336
+
337
+ class AttentionPool(nn.Module):
338
+
339
+ def __init__(self,
340
+ dim,
341
+ mlp_ratio,
342
+ num_heads,
343
+ activation='gelu',
344
+ proj_dropout=0.0,
345
+ norm_eps=1e-5):
346
+ assert dim % num_heads == 0
347
+ super().__init__()
348
+ self.dim = dim
349
+ self.mlp_ratio = mlp_ratio
350
+ self.num_heads = num_heads
351
+ self.head_dim = dim // num_heads
352
+ self.proj_dropout = proj_dropout
353
+ self.norm_eps = norm_eps
354
+
355
+ # layers
356
+ gain = 1.0 / math.sqrt(dim)
357
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
358
+ self.to_q = nn.Linear(dim, dim)
359
+ self.to_kv = nn.Linear(dim, dim * 2)
360
+ self.proj = nn.Linear(dim, dim)
361
+ self.norm = LayerNorm(dim, eps=norm_eps)
362
+ self.mlp = nn.Sequential(
363
+ nn.Linear(dim, int(dim * mlp_ratio)),
364
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
365
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
366
+
367
+ def forward(self, x):
368
+ """
369
+ x: [B, L, C].
370
+ """
371
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
372
+
373
+ # compute query, key, value
374
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
375
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
376
+
377
+ # compute attention
378
+ x = flash_attention(q, k, v, version=2)
379
+ x = x.reshape(b, 1, c)
380
+
381
+ # output
382
+ x = self.proj(x)
383
+ x = F.dropout(x, self.proj_dropout, self.training)
384
+
385
+ # mlp
386
+ x = x + self.mlp(self.norm(x))
387
+ return x[:, 0]
388
+
389
+
390
+ class VisionTransformer(nn.Module):
391
+
392
+ def __init__(self,
393
+ image_size=224,
394
+ patch_size=16,
395
+ dim=768,
396
+ mlp_ratio=4,
397
+ out_dim=512,
398
+ num_heads=12,
399
+ num_layers=12,
400
+ pool_type='token',
401
+ pre_norm=True,
402
+ post_norm=False,
403
+ activation='quick_gelu',
404
+ attn_dropout=0.0,
405
+ proj_dropout=0.0,
406
+ embedding_dropout=0.0,
407
+ norm_eps=1e-5):
408
+ if image_size % patch_size != 0:
409
+ print(
410
+ '[WARNING] image_size is not divisible by patch_size',
411
+ flush=True)
412
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
413
+ out_dim = out_dim or dim
414
+ super().__init__()
415
+ self.image_size = image_size
416
+ self.patch_size = patch_size
417
+ self.num_patches = (image_size // patch_size)**2
418
+ self.dim = dim
419
+ self.mlp_ratio = mlp_ratio
420
+ self.out_dim = out_dim
421
+ self.num_heads = num_heads
422
+ self.num_layers = num_layers
423
+ self.pool_type = pool_type
424
+ self.post_norm = post_norm
425
+ self.norm_eps = norm_eps
426
+
427
+ # embeddings
428
+ gain = 1.0 / math.sqrt(dim)
429
+ self.patch_embedding = nn.Conv2d(
430
+ 3,
431
+ dim,
432
+ kernel_size=patch_size,
433
+ stride=patch_size,
434
+ bias=not pre_norm)
435
+ if pool_type in ('token', 'token_fc'):
436
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
437
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
438
+ 1, self.num_patches +
439
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
440
+ self.dropout = nn.Dropout(embedding_dropout)
441
+
442
+ # transformer
443
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
444
+ self.transformer = nn.Sequential(*[
445
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
446
+ activation, attn_dropout, proj_dropout, norm_eps)
447
+ for _ in range(num_layers)
448
+ ])
449
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
450
+
451
+ # head
452
+ if pool_type == 'token':
453
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
454
+ elif pool_type == 'token_fc':
455
+ self.head = nn.Linear(dim, out_dim)
456
+ elif pool_type == 'attn_pool':
457
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
458
+ proj_dropout, norm_eps)
459
+
460
+ def forward(self, x, interpolation=False, use_31_block=False):
461
+ b = x.size(0)
462
+
463
+ # embeddings
464
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
465
+ if self.pool_type in ('token', 'token_fc'):
466
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
467
+ if interpolation:
468
+ e = pos_interpolate(self.pos_embedding, x.size(1))
469
+ else:
470
+ e = self.pos_embedding
471
+ e = e.to(dtype=x.dtype, device=x.device)
472
+ x = self.dropout(x + e)
473
+ if self.pre_norm is not None:
474
+ x = self.pre_norm(x)
475
+
476
+ # transformer
477
+ if use_31_block:
478
+ x = self.transformer[:-1](x)
479
+ return x
480
+ else:
481
+ x = self.transformer(x)
482
+ return x
483
+
484
+
485
+ class CLIP(nn.Module):
486
+
487
+ def __init__(self,
488
+ embed_dim=512,
489
+ image_size=224,
490
+ patch_size=16,
491
+ vision_dim=768,
492
+ vision_mlp_ratio=4,
493
+ vision_heads=12,
494
+ vision_layers=12,
495
+ vision_pool='token',
496
+ vision_pre_norm=True,
497
+ vision_post_norm=False,
498
+ vocab_size=49408,
499
+ text_len=77,
500
+ text_dim=512,
501
+ text_mlp_ratio=4,
502
+ text_heads=8,
503
+ text_layers=12,
504
+ text_causal=True,
505
+ text_pool='argmax',
506
+ text_head_bias=False,
507
+ logit_bias=None,
508
+ activation='quick_gelu',
509
+ attn_dropout=0.0,
510
+ proj_dropout=0.0,
511
+ embedding_dropout=0.0,
512
+ norm_eps=1e-5):
513
+ super().__init__()
514
+ self.embed_dim = embed_dim
515
+ self.image_size = image_size
516
+ self.patch_size = patch_size
517
+ self.vision_dim = vision_dim
518
+ self.vision_mlp_ratio = vision_mlp_ratio
519
+ self.vision_heads = vision_heads
520
+ self.vision_layers = vision_layers
521
+ self.vision_pool = vision_pool
522
+ self.vision_pre_norm = vision_pre_norm
523
+ self.vision_post_norm = vision_post_norm
524
+ self.vocab_size = vocab_size
525
+ self.text_len = text_len
526
+ self.text_dim = text_dim
527
+ self.text_mlp_ratio = text_mlp_ratio
528
+ self.text_heads = text_heads
529
+ self.text_layers = text_layers
530
+ self.text_causal = text_causal
531
+ self.text_pool = text_pool
532
+ self.text_head_bias = text_head_bias
533
+ self.norm_eps = norm_eps
534
+
535
+ # models
536
+ self.visual = VisionTransformer(
537
+ image_size=image_size,
538
+ patch_size=patch_size,
539
+ dim=vision_dim,
540
+ mlp_ratio=vision_mlp_ratio,
541
+ out_dim=embed_dim,
542
+ num_heads=vision_heads,
543
+ num_layers=vision_layers,
544
+ pool_type=vision_pool,
545
+ pre_norm=vision_pre_norm,
546
+ post_norm=vision_post_norm,
547
+ activation=activation,
548
+ attn_dropout=attn_dropout,
549
+ proj_dropout=proj_dropout,
550
+ embedding_dropout=embedding_dropout,
551
+ norm_eps=norm_eps)
552
+ self.textual = TextTransformer(
553
+ vocab_size=vocab_size,
554
+ text_len=text_len,
555
+ dim=text_dim,
556
+ mlp_ratio=text_mlp_ratio,
557
+ out_dim=embed_dim,
558
+ num_heads=text_heads,
559
+ num_layers=text_layers,
560
+ causal=text_causal,
561
+ pool_type=text_pool,
562
+ head_bias=text_head_bias,
563
+ activation=activation,
564
+ attn_dropout=attn_dropout,
565
+ proj_dropout=proj_dropout,
566
+ embedding_dropout=embedding_dropout,
567
+ norm_eps=norm_eps)
568
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
569
+ if logit_bias is not None:
570
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
571
+
572
+ # initialize weights
573
+ self.init_weights()
574
+
575
+ def forward(self, imgs, txt_ids):
576
+ """
577
+ imgs: [B, 3, H, W] of torch.float32.
578
+ - mean: [0.48145466, 0.4578275, 0.40821073]
579
+ - std: [0.26862954, 0.26130258, 0.27577711]
580
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
581
+ """
582
+ xi = self.visual(imgs)
583
+ xt = self.textual(txt_ids)
584
+ return xi, xt
585
+
586
+ def init_weights(self):
587
+ # embeddings
588
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
589
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
590
+
591
+ # attentions
592
+ for modality in ['visual', 'textual']:
593
+ dim = self.vision_dim if modality == 'visual' else self.text_dim
594
+ transformer = getattr(self, modality).transformer
595
+ proj_gain = (1.0 / math.sqrt(dim)) * (
596
+ 1.0 / math.sqrt(2 * len(transformer)))
597
+ attn_gain = 1.0 / math.sqrt(dim)
598
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
599
+ for block in transformer:
600
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
601
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
602
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
603
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
604
+
605
+ def param_groups(self):
606
+ groups = [{
607
+ 'params': [
608
+ p for n, p in self.named_parameters()
609
+ if 'norm' in n or n.endswith('bias')
610
+ ],
611
+ 'weight_decay': 0.0
612
+ }, {
613
+ 'params': [
614
+ p for n, p in self.named_parameters()
615
+ if not ('norm' in n or n.endswith('bias'))
616
+ ]
617
+ }]
618
+ return groups
619
+
620
+
621
+ class XLMRobertaWithHead(XLMRoberta):
622
+
623
+ def __init__(self, **kwargs):
624
+ self.out_dim = kwargs.pop('out_dim')
625
+ super().__init__(**kwargs)
626
+
627
+ # head
628
+ mid_dim = (self.dim + self.out_dim) // 2
629
+ self.head = nn.Sequential(
630
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
631
+ nn.Linear(mid_dim, self.out_dim, bias=False))
632
+
633
+ def forward(self, ids):
634
+ # xlm-roberta
635
+ x = super().forward(ids)
636
+
637
+ # average pooling
638
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
639
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
640
+
641
+ # head
642
+ x = self.head(x)
643
+ return x
644
+
645
+
646
+ class XLMRobertaCLIP(nn.Module):
647
+
648
+ def __init__(self,
649
+ embed_dim=1024,
650
+ image_size=224,
651
+ patch_size=14,
652
+ vision_dim=1280,
653
+ vision_mlp_ratio=4,
654
+ vision_heads=16,
655
+ vision_layers=32,
656
+ vision_pool='token',
657
+ vision_pre_norm=True,
658
+ vision_post_norm=False,
659
+ activation='gelu',
660
+ vocab_size=250002,
661
+ max_text_len=514,
662
+ type_size=1,
663
+ pad_id=1,
664
+ text_dim=1024,
665
+ text_heads=16,
666
+ text_layers=24,
667
+ text_post_norm=True,
668
+ text_dropout=0.1,
669
+ attn_dropout=0.0,
670
+ proj_dropout=0.0,
671
+ embedding_dropout=0.0,
672
+ norm_eps=1e-5):
673
+ super().__init__()
674
+ self.embed_dim = embed_dim
675
+ self.image_size = image_size
676
+ self.patch_size = patch_size
677
+ self.vision_dim = vision_dim
678
+ self.vision_mlp_ratio = vision_mlp_ratio
679
+ self.vision_heads = vision_heads
680
+ self.vision_layers = vision_layers
681
+ self.vision_pre_norm = vision_pre_norm
682
+ self.vision_post_norm = vision_post_norm
683
+ self.activation = activation
684
+ self.vocab_size = vocab_size
685
+ self.max_text_len = max_text_len
686
+ self.type_size = type_size
687
+ self.pad_id = pad_id
688
+ self.text_dim = text_dim
689
+ self.text_heads = text_heads
690
+ self.text_layers = text_layers
691
+ self.text_post_norm = text_post_norm
692
+ self.norm_eps = norm_eps
693
+
694
+ # models
695
+ self.visual = VisionTransformer(
696
+ image_size=image_size,
697
+ patch_size=patch_size,
698
+ dim=vision_dim,
699
+ mlp_ratio=vision_mlp_ratio,
700
+ out_dim=embed_dim,
701
+ num_heads=vision_heads,
702
+ num_layers=vision_layers,
703
+ pool_type=vision_pool,
704
+ pre_norm=vision_pre_norm,
705
+ post_norm=vision_post_norm,
706
+ activation=activation,
707
+ attn_dropout=attn_dropout,
708
+ proj_dropout=proj_dropout,
709
+ embedding_dropout=embedding_dropout,
710
+ norm_eps=norm_eps)
711
+ self.textual = None
712
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
713
+
714
+ def forward(self, imgs, txt_ids):
715
+ """
716
+ imgs: [B, 3, H, W] of torch.float32.
717
+ - mean: [0.48145466, 0.4578275, 0.40821073]
718
+ - std: [0.26862954, 0.26130258, 0.27577711]
719
+ txt_ids: [B, L] of torch.long.
720
+ Encoded by data.CLIPTokenizer.
721
+ """
722
+ xi = self.visual(imgs)
723
+ xt = self.textual(txt_ids)
724
+ return xi, xt
725
+
726
+ def param_groups(self):
727
+ groups = [{
728
+ 'params': [
729
+ p for n, p in self.named_parameters()
730
+ if 'norm' in n or n.endswith('bias')
731
+ ],
732
+ 'weight_decay': 0.0
733
+ }, {
734
+ 'params': [
735
+ p for n, p in self.named_parameters()
736
+ if not ('norm' in n or n.endswith('bias'))
737
+ ]
738
+ }]
739
+ return groups
740
+
741
+
742
+ def _clip(pretrained=False,
743
+ pretrained_name=None,
744
+ model_cls=CLIP,
745
+ return_transforms=False,
746
+ return_tokenizer=False,
747
+ tokenizer_padding='eos',
748
+ dtype=torch.float32,
749
+ device='cpu',
750
+ **kwargs):
751
+ # init model
752
+ if pretrained and pretrained_name:
753
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
754
+
755
+ # init a meta model
756
+ with torch.device('meta'):
757
+ model = model_cls(**kwargs)
758
+
759
+ # checkpoint path
760
+ checkpoint = f'models/clip/{pretrained_name}'
761
+ if dtype in (torch.float16, torch.bfloat16):
762
+ suffix = '-' + {
763
+ torch.float16: 'fp16',
764
+ torch.bfloat16: 'bf16'
765
+ }[dtype]
766
+ if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
767
+ checkpoint = f'{checkpoint}{suffix}'
768
+ checkpoint += '.pth'
769
+
770
+ # load
771
+ model.load_state_dict(
772
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
773
+ assign=True,
774
+ strict=False)
775
+ else:
776
+ # init a model on device
777
+ with torch.device(device):
778
+ model = model_cls(**kwargs)
779
+
780
+ # set device
781
+ output = (model,)
782
+
783
+ # init transforms
784
+ if return_transforms:
785
+ # mean and std
786
+ if 'siglip' in pretrained_name.lower():
787
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
788
+ else:
789
+ mean = [0.48145466, 0.4578275, 0.40821073]
790
+ std = [0.26862954, 0.26130258, 0.27577711]
791
+
792
+ # transforms
793
+ transforms = T.Compose([
794
+ T.Resize((model.image_size, model.image_size),
795
+ interpolation=T.InterpolationMode.BICUBIC),
796
+ T.ToTensor(),
797
+ T.Normalize(mean=mean, std=std)
798
+ ])
799
+ output += (transforms,)
800
+
801
+ # init tokenizer
802
+ if return_tokenizer:
803
+ from sora import data
804
+ if 'siglip' in pretrained_name.lower():
805
+ tokenizer = data.HuggingfaceTokenizer(
806
+ name=f'timm/{pretrained_name}',
807
+ seq_len=model.text_len,
808
+ clean='canonicalize')
809
+ elif 'xlm' in pretrained_name.lower():
810
+ tokenizer = data.HuggingfaceTokenizer(
811
+ name='xlm-roberta-large',
812
+ seq_len=model.max_text_len - 2,
813
+ clean='whitespace')
814
+ elif 'mba' in pretrained_name.lower():
815
+ tokenizer = data.HuggingfaceTokenizer(
816
+ name='facebook/xlm-roberta-xl',
817
+ seq_len=model.max_text_len - 2,
818
+ clean='whitespace')
819
+ else:
820
+ tokenizer = data.CLIPTokenizer(
821
+ seq_len=model.text_len, padding=tokenizer_padding)
822
+ output += (tokenizer,)
823
+ return output[0] if len(output) == 1 else output
824
+
825
+
826
+ def clip_xlm_roberta_vit_h_14(
827
+ pretrained=False,
828
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
829
+ **kwargs):
830
+ cfg = dict(
831
+ embed_dim=1024,
832
+ image_size=224,
833
+ patch_size=14,
834
+ vision_dim=1280,
835
+ vision_mlp_ratio=4,
836
+ vision_heads=16,
837
+ vision_layers=32,
838
+ vision_pool='token',
839
+ activation='gelu',
840
+ vocab_size=250002,
841
+ max_text_len=514,
842
+ type_size=1,
843
+ pad_id=1,
844
+ text_dim=1024,
845
+ text_heads=16,
846
+ text_layers=24,
847
+ text_post_norm=True,
848
+ text_dropout=0.1,
849
+ attn_dropout=0.0,
850
+ proj_dropout=0.0,
851
+ embedding_dropout=0.0)
852
+ cfg.update(**kwargs)
853
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
854
+
855
+
856
+ class WanImageEncoder(torch.nn.Module):
857
+
858
+ def __init__(self):
859
+ super().__init__()
860
+ # init model
861
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
862
+ pretrained=False,
863
+ return_transforms=True,
864
+ return_tokenizer=False,
865
+ dtype=torch.float32,
866
+ device="cpu")
867
+
868
+ def encode_image(self, videos):
869
+ # preprocess
870
+ size = (self.model.image_size,) * 2
871
+ videos = torch.cat([
872
+ F.interpolate(
873
+ u,
874
+ size=size,
875
+ mode='bicubic',
876
+ align_corners=False) for u in videos
877
+ ])
878
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
879
+
880
+ # forward
881
+ out = self.model.visual(videos, use_31_block=True)
882
+ return out
883
+
884
+ @staticmethod
885
+ def state_dict_converter():
886
+ return WanImageEncoderStateDictConverter()
887
+
888
+
889
+ class WanImageEncoderStateDictConverter:
890
+ def __init__(self):
891
+ pass
892
+
893
+ def from_diffusers(self, state_dict):
894
+ return state_dict
895
+
896
+ def from_civitai(self, state_dict):
897
+ state_dict_ = {}
898
+ for name, param in state_dict.items():
899
+ if name.startswith("textual."):
900
+ continue
901
+ name = "model." + name
902
+ state_dict_[name] = param
903
+ return state_dict_
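`WanImageEncoder` above keeps only the vision tower of the XLM-Roberta CLIP (the `textual.*` weights are skipped in `from_civitai`) and returns the features of the first 31 ViT blocks as the image-conditioning context. A minimal sketch, assuming the conditioning frame has already been converted to a `[1, 3, H, W]` tensor in `[-1, 1]`; the random input and printed shape are for illustration only, and the attention path goes through `flash_attention` from `model/dit.py`, which may require a GPU or a flash-attn install depending on its fallback behavior.

```python
import torch
from model.image_encoder import WanImageEncoder

encoder = WanImageEncoder().eval()  # load converted CLIP weights in practice

# encode_image resizes to the CLIP resolution, rescales [-1, 1] -> [0, 1],
# applies the CLIP mean/std normalization, and runs the ViT without its
# final block (use_31_block=True).
image = torch.rand(1, 3, 480, 832) * 2 - 1
with torch.no_grad():
    clip_context = encoder.encode_image([image])
print(clip_context.shape)  # expected: [1, 257, 1280] for the ViT-H/14 tower
```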
model/prompter.py ADDED
@@ -0,0 +1,107 @@
1
+ from diffsynth.prompters.base_prompter import BasePrompter
2
+ from model.text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import ftfy
5
+ import html
6
+ import string
7
+ import regex as re
8
+
9
+
10
+ def basic_clean(text):
11
+ text = ftfy.fix_text(text)
12
+ text = html.unescape(html.unescape(text))
13
+ return text.strip()
14
+
15
+
16
+ def whitespace_clean(text):
17
+ text = re.sub(r'\s+', ' ', text)
18
+ text = text.strip()
19
+ return text
20
+
21
+
22
+ def canonicalize(text, keep_punctuation_exact_string=None):
23
+ text = text.replace('_', ' ')
24
+ if keep_punctuation_exact_string:
25
+ text = keep_punctuation_exact_string.join(
26
+ part.translate(str.maketrans('', '', string.punctuation))
27
+ for part in text.split(keep_punctuation_exact_string))
28
+ else:
29
+ text = text.translate(str.maketrans('', '', string.punctuation))
30
+ text = text.lower()
31
+ text = re.sub(r'\s+', ' ', text)
32
+ return text.strip()
33
+
34
+
35
+ class HuggingfaceTokenizer:
36
+
37
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
38
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
39
+ self.name = name
40
+ self.seq_len = seq_len
41
+ self.clean = clean
42
+
43
+ # init tokenizer
44
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
45
+ self.vocab_size = self.tokenizer.vocab_size
46
+
47
+ def __call__(self, sequence, **kwargs):
48
+ return_mask = kwargs.pop('return_mask', False)
49
+
50
+ # arguments
51
+ _kwargs = {'return_tensors': 'pt'}
52
+ if self.seq_len is not None:
53
+ _kwargs.update({
54
+ 'padding': 'max_length',
55
+ 'truncation': True,
56
+ 'max_length': self.seq_len
57
+ })
58
+ _kwargs.update(**kwargs)
59
+
60
+ # tokenization
61
+ if isinstance(sequence, str):
62
+ sequence = [sequence]
63
+ if self.clean:
64
+ sequence = [self._clean(u) for u in sequence]
65
+ ids = self.tokenizer(sequence, **_kwargs)
66
+
67
+ # output
68
+ if return_mask:
69
+ return ids.input_ids, ids.attention_mask
70
+ else:
71
+ return ids.input_ids
72
+
73
+ def _clean(self, text):
74
+ if self.clean == 'whitespace':
75
+ text = whitespace_clean(basic_clean(text))
76
+ elif self.clean == 'lower':
77
+ text = whitespace_clean(basic_clean(text)).lower()
78
+ elif self.clean == 'canonicalize':
79
+ text = canonicalize(basic_clean(text))
80
+ return text
81
+
82
+
83
+ class WanPrompter(BasePrompter):
84
+
85
+ def __init__(self, tokenizer_path=None, text_len=512):
86
+ super().__init__()
87
+ self.text_len = text_len
88
+ self.text_encoder = None
89
+ self.fetch_tokenizer(tokenizer_path)
90
+
91
+ def fetch_tokenizer(self, tokenizer_path=None):
92
+ if tokenizer_path is not None:
93
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
94
+
95
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
96
+ self.text_encoder = text_encoder
97
+
98
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
99
+ prompt = self.process_prompt(prompt, positive=positive)
100
+
101
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
102
+ ids = ids.to(device)
103
+ mask = mask.to(device)
104
+ seq_lens = mask.gt(0).sum(dim=1).long()
105
+ prompt_emb = self.text_encoder(ids, mask)
106
+ prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
107
+ return prompt_emb
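The `HuggingfaceTokenizer` wrapper above pads and truncates to a fixed `seq_len` and optionally cleans the text before tokenization; `WanPrompter.encode_prompt` then runs `WanTextEncoder` on the padded ids and trims each embedding back to its true length using the attention mask. Below is a minimal sketch of the tokenizer side; the model name `google/umt5-xxl` is an assumption for illustration (ToonComposer may point `tokenizer_path` at a local copy instead).

```python
from model.prompter import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer(
    name="google/umt5-xxl",  # assumed tokenizer name; replace with the repo's actual tokenizer path
    seq_len=512,
    clean="whitespace",
)

ids, mask = tokenizer(
    "A hand-drawn character turns toward the camera.",
    return_mask=True,
    add_special_tokens=True,
)
print(ids.shape)          # [1, 512] after padding to seq_len
print(mask.sum().item())  # number of real (non-padding) tokens
```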
model/text_encoder.py ADDED
@@ -0,0 +1,269 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
+ class T5SelfAttention(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ dim_attn,
118
+ dim_ffn,
119
+ num_heads,
120
+ num_buckets,
121
+ shared_pos=True,
122
+ dropout=0.1):
123
+ super(T5SelfAttention, self).__init__()
124
+ self.dim = dim
125
+ self.dim_attn = dim_attn
126
+ self.dim_ffn = dim_ffn
127
+ self.num_heads = num_heads
128
+ self.num_buckets = num_buckets
129
+ self.shared_pos = shared_pos
130
+
131
+ # layers
132
+ self.norm1 = T5LayerNorm(dim)
133
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
134
+ self.norm2 = T5LayerNorm(dim)
135
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
136
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
137
+ num_buckets, num_heads, bidirectional=True)
138
+
139
+ def forward(self, x, mask=None, pos_bias=None):
140
+ e = pos_bias if self.shared_pos else self.pos_embedding(
141
+ x.size(1), x.size(1))
142
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
143
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
144
+ return x
145
+
146
+
147
+ class T5RelativeEmbedding(nn.Module):
148
+
149
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
150
+ super(T5RelativeEmbedding, self).__init__()
151
+ self.num_buckets = num_buckets
152
+ self.num_heads = num_heads
153
+ self.bidirectional = bidirectional
154
+ self.max_dist = max_dist
155
+
156
+ # layers
157
+ self.embedding = nn.Embedding(num_buckets, num_heads)
158
+
159
+ def forward(self, lq, lk):
160
+ device = self.embedding.weight.device
161
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
162
+ # torch.arange(lq).unsqueeze(1).to(device)
163
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
164
+ torch.arange(lq, device=device).unsqueeze(1)
165
+ rel_pos = self._relative_position_bucket(rel_pos)
166
+ rel_pos_embeds = self.embedding(rel_pos)
167
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
168
+ 0) # [1, N, Lq, Lk]
169
+ return rel_pos_embeds.contiguous()
170
+
171
+ def _relative_position_bucket(self, rel_pos):
172
+ # preprocess
173
+ if self.bidirectional:
174
+ num_buckets = self.num_buckets // 2
175
+ rel_buckets = (rel_pos > 0).long() * num_buckets
176
+ rel_pos = torch.abs(rel_pos)
177
+ else:
178
+ num_buckets = self.num_buckets
179
+ rel_buckets = 0
180
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
181
+
182
+ # embeddings for small and large positions
183
+ max_exact = num_buckets // 2
184
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
185
+ math.log(self.max_dist / max_exact) *
186
+ (num_buckets - max_exact)).long()
187
+ rel_pos_large = torch.min(
188
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
189
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
190
+ return rel_buckets
191
+
192
+ def init_weights(m):
193
+ if isinstance(m, T5LayerNorm):
194
+ nn.init.ones_(m.weight)
195
+ elif isinstance(m, T5FeedForward):
196
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
197
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
198
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
199
+ elif isinstance(m, T5Attention):
200
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
201
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
202
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
203
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
204
+ elif isinstance(m, T5RelativeEmbedding):
205
+ nn.init.normal_(
206
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
+ class WanTextEncoder(torch.nn.Module):
210
+
211
+ def __init__(self,
212
+ vocab=256384,
213
+ dim=4096,
214
+ dim_attn=4096,
215
+ dim_ffn=10240,
216
+ num_heads=64,
217
+ num_layers=24,
218
+ num_buckets=32,
219
+ shared_pos=False,
220
+ dropout=0.1):
221
+ super(WanTextEncoder, self).__init__()
222
+ self.dim = dim
223
+ self.dim_attn = dim_attn
224
+ self.dim_ffn = dim_ffn
225
+ self.num_heads = num_heads
226
+ self.num_layers = num_layers
227
+ self.num_buckets = num_buckets
228
+ self.shared_pos = shared_pos
229
+
230
+ # layers
231
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
232
+ else nn.Embedding(vocab, dim)
233
+ self.pos_embedding = T5RelativeEmbedding(
234
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
235
+ self.dropout = nn.Dropout(dropout)
236
+ self.blocks = nn.ModuleList([
237
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
238
+ shared_pos, dropout) for _ in range(num_layers)
239
+ ])
240
+ self.norm = T5LayerNorm(dim)
241
+
242
+ # initialize weights
243
+ self.apply(init_weights)
244
+
245
+ def forward(self, ids, mask=None):
246
+ x = self.token_embedding(ids)
247
+ x = self.dropout(x)
248
+ e = self.pos_embedding(x.size(1),
249
+ x.size(1)) if self.shared_pos else None
250
+ for block in self.blocks:
251
+ x = block(x, mask, pos_bias=e)
252
+ x = self.norm(x)
253
+ x = self.dropout(x)
254
+ return x
255
+
256
+ @staticmethod
257
+ def state_dict_converter():
258
+ return WanTextEncoderStateDictConverter()
259
+
260
+
261
+ class WanTextEncoderStateDictConverter:
262
+ def __init__(self):
263
+ pass
264
+
265
+ def from_diffusers(self, state_dict):
266
+ return state_dict
267
+
268
+ def from_civitai(self, state_dict):
269
+ return state_dict
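The relative position bias used by `T5SelfAttention` follows the T5 bucketing scheme implemented above: the sign of the offset selects one half of the buckets, small offsets get exact buckets, and larger offsets are binned logarithmically up to `max_dist`. A small illustrative check of the embedding module on its own:

```python
from model.text_encoder import T5RelativeEmbedding

# 32 buckets shared across a toy 4-head attention; bidirectional buckets as in
# the encoder-only setting above.
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4, bidirectional=True)

bias = rel_emb(lq=8, lk=8)  # additive attention bias, shape [1, 4, 8, 8]
print(bias.shape)
```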
model/vae.py ADDED
@@ -0,0 +1,809 @@
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :,
29
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
30
+ return mask
31
+
32
+
33
+ class CausalConv3d(nn.Conv3d):
34
+ """
35
+ Causal 3d convolution.
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
41
+ self.padding[1], 2 * self.padding[0], 0)
42
+ self.padding = (0, 0, 0)
43
+
44
+ def forward(self, x, cache_x=None):
45
+ padding = list(self._padding)
46
+ if cache_x is not None and self._padding[4] > 0:
47
+ cache_x = cache_x.to(x.device)
48
+ x = torch.cat([cache_x, x], dim=2)
49
+ padding[4] -= cache_x.shape[2]
50
+ x = F.pad(x, padding)
51
+
52
+ return super().forward(x)
53
+
54
+
55
+ class RMS_norm(nn.Module):
56
+
57
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
58
+ super().__init__()
59
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
60
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
61
+
62
+ self.channel_first = channel_first
63
+ self.scale = dim**0.5
64
+ self.gamma = nn.Parameter(torch.ones(shape))
65
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
66
+
67
+ def forward(self, x):
68
+ return F.normalize(
69
+ x, dim=(1 if self.channel_first else
70
+ -1)) * self.scale * self.gamma + self.bias
71
+
72
+
73
+ class Upsample(nn.Upsample):
74
+
75
+ def forward(self, x):
76
+ """
77
+ Fix bfloat16 support for nearest neighbor interpolation.
78
+ """
79
+ return super().forward(x.float()).type_as(x)
80
+
81
+
82
+ class Resample(nn.Module):
83
+
84
+ def __init__(self, dim, mode):
85
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
86
+ 'downsample3d')
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.mode = mode
90
+
91
+ # layers
92
+ if mode == 'upsample2d':
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
96
+ elif mode == 'upsample3d':
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
99
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
100
+ self.time_conv = CausalConv3d(dim,
101
+ dim * 2, (3, 1, 1),
102
+ padding=(1, 0, 0))
103
+
104
+ elif mode == 'downsample2d':
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == 'downsample3d':
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(dim,
113
+ dim, (3, 1, 1),
114
+ stride=(2, 1, 1),
115
+ padding=(0, 0, 0))
116
+
117
+ else:
118
+ self.resample = nn.Identity()
119
+
120
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
121
+ b, c, t, h, w = x.size()
122
+ if self.mode == 'upsample3d':
123
+ if feat_cache is not None:
124
+ idx = feat_idx[0]
125
+ if feat_cache[idx] is None:
126
+ feat_cache[idx] = 'Rep'
127
+ feat_idx[0] += 1
128
+ else:
129
+
130
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
131
+ if cache_x.shape[2] < 2 and feat_cache[
132
+ idx] is not None and feat_cache[idx] != 'Rep':
133
+ # cache last frame of the last two chunks
134
+ cache_x = torch.cat([
135
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
136
+ cache_x.device), cache_x
137
+ ],
138
+ dim=2)
139
+ if cache_x.shape[2] < 2 and feat_cache[
140
+ idx] is not None and feat_cache[idx] == 'Rep':
141
+ cache_x = torch.cat([
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2)
146
+ if feat_cache[idx] == 'Rep':
147
+ x = self.time_conv(x)
148
+ else:
149
+ x = self.time_conv(x, feat_cache[idx])
150
+ feat_cache[idx] = cache_x
151
+ feat_idx[0] += 1
152
+
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
159
+ x = self.resample(x)
160
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
161
+
162
+ if self.mode == 'downsample3d':
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
184
+ conv.weight.data.copy_(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight.data.copy_(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False), nn.SiLU(),
208
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
209
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
210
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
211
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
212
+ if in_dim != out_dim else nn.Identity()
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # cache last frame of the last two chunks
222
+ cache_x = torch.cat([
223
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
224
+ cache_x.device), cache_x
225
+ ],
226
+ dim=2)
227
+ x = layer(x, feat_cache[idx])
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ else:
231
+ x = layer(x)
232
+ return x + h
233
+
234
+
235
+ class AttentionBlock(nn.Module):
236
+ """
237
+ Causal self-attention with a single head.
238
+ """
239
+
240
+ def __init__(self, dim):
241
+ super().__init__()
242
+ self.dim = dim
243
+
244
+ # layers
245
+ self.norm = RMS_norm(dim)
246
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
247
+ self.proj = nn.Conv2d(dim, dim, 1)
248
+
249
+ # zero out the last layer params
250
+ nn.init.zeros_(self.proj.weight)
251
+
252
+ def forward(self, x):
253
+ identity = x
254
+ b, c, t, h, w = x.size()
255
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
256
+ x = self.norm(x)
257
+ # compute query, key, value
258
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
259
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
260
+
261
+ # apply attention
262
+ x = F.scaled_dot_product_attention(
263
+ q,
264
+ k,
265
+ v,
266
+ #attn_mask=block_causal_mask(q, block_size=h * w)
267
+ )
268
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
269
+
270
+ # output
271
+ x = self.proj(x)
272
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
273
+ return x + identity
274
+
275
+
276
+ class Encoder3d(nn.Module):
277
+
278
+ def __init__(self,
279
+ dim=128,
280
+ z_dim=4,
281
+ dim_mult=[1, 2, 4, 4],
282
+ num_res_blocks=2,
283
+ attn_scales=[],
284
+ temperal_downsample=[True, True, False],
285
+ dropout=0.0):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.z_dim = z_dim
289
+ self.dim_mult = dim_mult
290
+ self.num_res_blocks = num_res_blocks
291
+ self.attn_scales = attn_scales
292
+ self.temperal_downsample = temperal_downsample
293
+
294
+ # dimensions
295
+ dims = [dim * u for u in [1] + dim_mult]
296
+ scale = 1.0
297
+
298
+ # init block
299
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
300
+
301
+ # downsample blocks
302
+ downsamples = []
303
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
304
+ # residual (+attention) blocks
305
+ for _ in range(num_res_blocks):
306
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
307
+ if scale in attn_scales:
308
+ downsamples.append(AttentionBlock(out_dim))
309
+ in_dim = out_dim
310
+
311
+ # downsample block
312
+ if i != len(dim_mult) - 1:
313
+ mode = 'downsample3d' if temperal_downsample[
314
+ i] else 'downsample2d'
315
+ downsamples.append(Resample(out_dim, mode=mode))
316
+ scale /= 2.0
317
+ self.downsamples = nn.Sequential(*downsamples)
318
+
319
+ # middle blocks
320
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
321
+ AttentionBlock(out_dim),
322
+ ResidualBlock(out_dim, out_dim, dropout))
323
+
324
+ # output blocks
325
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
326
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
327
+
328
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
329
+ if feat_cache is not None:
330
+ idx = feat_idx[0]
331
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
332
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
333
+ # cache last frame of the last two chunks
334
+ cache_x = torch.cat([
335
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
336
+ cache_x.device), cache_x
337
+ ],
338
+ dim=2)
339
+ x = self.conv1(x, feat_cache[idx])
340
+ feat_cache[idx] = cache_x
341
+ feat_idx[0] += 1
342
+ else:
343
+ x = self.conv1(x)
344
+
345
+ ## downsamples
346
+ for layer in self.downsamples:
347
+ if feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## middle
353
+ for layer in self.middle:
354
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
355
+ x = layer(x, feat_cache, feat_idx)
356
+ else:
357
+ x = layer(x)
358
+
359
+ ## head
360
+ for layer in self.head:
361
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
362
+ idx = feat_idx[0]
363
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
364
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
365
+ # cache last frame of the last two chunks
366
+ cache_x = torch.cat([
367
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
368
+ cache_x.device), cache_x
369
+ ],
370
+ dim=2)
371
+ x = layer(x, feat_cache[idx])
372
+ feat_cache[idx] = cache_x
373
+ feat_idx[0] += 1
374
+ else:
375
+ x = layer(x)
376
+ return x
377
+
378
+
379
+ class Decoder3d(nn.Module):
380
+
381
+ def __init__(self,
382
+ dim=128,
383
+ z_dim=4,
384
+ dim_mult=[1, 2, 4, 4],
385
+ num_res_blocks=2,
386
+ attn_scales=[],
387
+ temperal_upsample=[False, True, True],
388
+ dropout=0.0):
389
+ super().__init__()
390
+ self.dim = dim
391
+ self.z_dim = z_dim
392
+ self.dim_mult = dim_mult
393
+ self.num_res_blocks = num_res_blocks
394
+ self.attn_scales = attn_scales
395
+ self.temperal_upsample = temperal_upsample
396
+
397
+ # dimensions
398
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
399
+ scale = 1.0 / 2**(len(dim_mult) - 2)
400
+
401
+ # init block
402
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
403
+
404
+ # middle blocks
405
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
406
+ AttentionBlock(dims[0]),
407
+ ResidualBlock(dims[0], dims[0], dropout))
408
+
409
+ # upsample blocks
410
+ upsamples = []
411
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
412
+ # residual (+attention) blocks
413
+ if i == 1 or i == 2 or i == 3:
414
+ in_dim = in_dim // 2
415
+ for _ in range(num_res_blocks + 1):
416
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
417
+ if scale in attn_scales:
418
+ upsamples.append(AttentionBlock(out_dim))
419
+ in_dim = out_dim
420
+
421
+ # upsample block
422
+ if i != len(dim_mult) - 1:
423
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
424
+ upsamples.append(Resample(out_dim, mode=mode))
425
+ scale *= 2.0
426
+ self.upsamples = nn.Sequential(*upsamples)
427
+
428
+ # output blocks
429
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
430
+ CausalConv3d(out_dim, 3, 3, padding=1))
431
+
432
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
433
+ ## conv1
434
+ if feat_cache is not None:
435
+ idx = feat_idx[0]
436
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
437
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
438
+ # cache last frame of the last two chunks
439
+ cache_x = torch.cat([
440
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
441
+ cache_x.device), cache_x
442
+ ],
443
+ dim=2)
444
+ x = self.conv1(x, feat_cache[idx])
445
+ feat_cache[idx] = cache_x
446
+ feat_idx[0] += 1
447
+ else:
448
+ x = self.conv1(x)
449
+
450
+ ## middle
451
+ for layer in self.middle:
452
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
453
+ x = layer(x, feat_cache, feat_idx)
454
+ else:
455
+ x = layer(x)
456
+
457
+ ## upsamples
458
+ for layer in self.upsamples:
459
+ if feat_cache is not None:
460
+ x = layer(x, feat_cache, feat_idx)
461
+ else:
462
+ x = layer(x)
463
+
464
+ ## head
465
+ for layer in self.head:
466
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
467
+ idx = feat_idx[0]
468
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
469
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
470
+ # cache last frame of the last two chunks
471
+ cache_x = torch.cat([
472
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
473
+ cache_x.device), cache_x
474
+ ],
475
+ dim=2)
476
+ x = layer(x, feat_cache[idx])
477
+ feat_cache[idx] = cache_x
478
+ feat_idx[0] += 1
479
+ else:
480
+ x = layer(x)
481
+ return x
482
+
483
+
484
+ def count_conv3d(model):
485
+ count = 0
486
+ for m in model.modules():
487
+ if check_is_instance(m, CausalConv3d):
488
+ count += 1
489
+ return count
490
+
491
+
492
+ class VideoVAE_(nn.Module):
493
+
494
+ def __init__(self,
495
+ dim=96,
496
+ z_dim=16,
497
+ dim_mult=[1, 2, 4, 4],
498
+ num_res_blocks=2,
499
+ attn_scales=[],
500
+ temperal_downsample=[False, True, True],
501
+ dropout=0.0):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.z_dim = z_dim
505
+ self.dim_mult = dim_mult
506
+ self.num_res_blocks = num_res_blocks
507
+ self.attn_scales = attn_scales
508
+ self.temperal_downsample = temperal_downsample
509
+ self.temperal_upsample = temperal_downsample[::-1]
510
+
511
+ # modules
512
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
513
+ attn_scales, self.temperal_downsample, dropout)
514
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
515
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
516
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
517
+ attn_scales, self.temperal_upsample, dropout)
518
+
519
+ def forward(self, x):
520
+ mu, log_var = self.encode(x)
521
+ z = self.reparameterize(mu, log_var)
522
+ x_recon = self.decode(z)
523
+ return x_recon, mu, log_var
524
+
525
+ def encode(self, x, scale):
526
+ self.clear_cache()
527
+ ## cache
528
+ t = x.shape[2]
529
+ iter_ = 1 + (t - 1) // 4
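+ # frame 0 is encoded on its own; the remaining frames are processed causally in chunks of 4, matching the 4x temporal compression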
530
+
531
+ for i in range(iter_):
532
+ self._enc_conv_idx = [0]
533
+ if i == 0:
534
+ out = self.encoder(x[:, :, :1, :, :],
535
+ feat_cache=self._enc_feat_map,
536
+ feat_idx=self._enc_conv_idx)
537
+ else:
538
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
539
+ feat_cache=self._enc_feat_map,
540
+ feat_idx=self._enc_conv_idx)
541
+ out = torch.cat([out, out_], 2)
542
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
543
+ if isinstance(scale[0], torch.Tensor):
544
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
545
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
546
+ 1, self.z_dim, 1, 1, 1)
547
+ else:
548
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
549
+ mu = (mu - scale[0]) * scale[1]
550
+ return mu
551
+
552
+ def decode(self, z, scale):
553
+ self.clear_cache()
554
+ # z: [b,c,t,h,w]
555
+ if isinstance(scale[0], torch.Tensor):
556
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
557
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
558
+ 1, self.z_dim, 1, 1, 1)
559
+ else:
560
+ scale = scale.to(dtype=z.dtype, device=z.device)
561
+ z = z / scale[1] + scale[0]
562
+ iter_ = z.shape[2]
563
+ x = self.conv2(z)
564
+ for i in range(iter_):
565
+ self._conv_idx = [0]
566
+ if i == 0:
567
+ out = self.decoder(x[:, :, i:i + 1, :, :],
568
+ feat_cache=self._feat_map,
569
+ feat_idx=self._conv_idx)
570
+ else:
571
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
572
+ feat_cache=self._feat_map,
573
+ feat_idx=self._conv_idx)
574
+ out = torch.cat([out, out_], 2) # may add tensor offload
575
+ return out
576
+
577
+ def reparameterize(self, mu, log_var):
578
+ std = torch.exp(0.5 * log_var)
579
+ eps = torch.randn_like(std)
580
+ return eps * std + mu
581
+
582
+ def sample(self, imgs, deterministic=False):
583
+ mu, log_var = self.encode(imgs)
584
+ if deterministic:
585
+ return mu
586
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
587
+ return mu + std * torch.randn_like(std)
588
+
589
+ def clear_cache(self):
590
+ self._conv_num = count_conv3d(self.decoder)
591
+ self._conv_idx = [0]
592
+ self._feat_map = [None] * self._conv_num
593
+ # cache encode
594
+ self._enc_conv_num = count_conv3d(self.encoder)
595
+ self._enc_conv_idx = [0]
596
+ self._enc_feat_map = [None] * self._enc_conv_num
597
+
598
+
599
+ class WanVideoVAE(nn.Module):
600
+
601
+ def __init__(self, z_dim=16):
602
+ super().__init__()
603
+
604
+ mean = [
605
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
606
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
607
+ ]
608
+ std = [
609
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
610
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
611
+ ]
612
+ self.mean = torch.tensor(mean)
613
+ self.std = torch.tensor(std)
614
+ self.scale = [self.mean, 1.0 / self.std]
615
+
616
+ # init model
617
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
618
+ self.upsampling_factor = 8
619
+
620
+
621
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
622
+ x = torch.ones((length,))
623
+ if not left_bound:
624
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
625
+ if not right_bound:
626
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
627
+ return x
628
+
629
+
630
+ def build_mask(self, data, is_bound, border_width):
631
+ _, _, _, H, W = data.shape
632
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
633
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
634
+
635
+ h = repeat(h, "H -> H W", H=H, W=W)
636
+ w = repeat(w, "W -> H W", H=H, W=W)
637
+
638
+ mask = torch.stack([h, w]).min(dim=0).values
639
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
640
+ return mask
641
+
642
+
643
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
644
+ _, _, T, H, W = hidden_states.shape
645
+ size_h, size_w = tile_size
646
+ stride_h, stride_w = tile_stride
647
+
648
+ # Split tasks
649
+ tasks = []
650
+ for h in range(0, H, stride_h):
651
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
652
+ for w in range(0, W, stride_w):
653
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
654
+ h_, w_ = h + size_h, w + size_w
655
+ tasks.append((h, h_, w, w_))
656
+
657
+ data_device = "cpu"
658
+ computation_device = device
659
+
660
+ out_T = T * 4 - 3
661
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
662
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
663
+
664
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
665
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
666
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
667
+
668
+ mask = self.build_mask(
669
+ hidden_states_batch,
670
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
671
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
672
+ ).to(dtype=hidden_states.dtype, device=data_device)
673
+
674
+ target_h = h * self.upsampling_factor
675
+ target_w = w * self.upsampling_factor
676
+ values[
677
+ :,
678
+ :,
679
+ :,
680
+ target_h:target_h + hidden_states_batch.shape[3],
681
+ target_w:target_w + hidden_states_batch.shape[4],
682
+ ] += hidden_states_batch * mask
683
+ weight[
684
+ :,
685
+ :,
686
+ :,
687
+ target_h: target_h + hidden_states_batch.shape[3],
688
+ target_w: target_w + hidden_states_batch.shape[4],
689
+ ] += mask
690
+ values = values / weight
691
+ values = values.float().clamp_(-1, 1)
692
+ return values
693
+
694
+
695
+ def tiled_encode(self, video, device, tile_size, tile_stride):
696
+ _, _, T, H, W = video.shape
697
+ size_h, size_w = tile_size
698
+ stride_h, stride_w = tile_stride
699
+
700
+ # Split tasks
701
+ tasks = []
702
+ for h in range(0, H, stride_h):
703
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
704
+ for w in range(0, W, stride_w):
705
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
706
+ h_, w_ = h + size_h, w + size_w
707
+ tasks.append((h, h_, w, w_))
708
+
709
+ data_device = "cpu"
710
+ computation_device = device
711
+
712
+ out_T = (T + 3) // 4
713
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
714
+ values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
715
+
716
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
717
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
718
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
719
+
720
+ mask = self.build_mask(
721
+ hidden_states_batch,
722
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
723
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
724
+ ).to(dtype=video.dtype, device=data_device)
725
+
726
+ target_h = h // self.upsampling_factor
727
+ target_w = w // self.upsampling_factor
728
+ values[
729
+ :,
730
+ :,
731
+ :,
732
+ target_h:target_h + hidden_states_batch.shape[3],
733
+ target_w:target_w + hidden_states_batch.shape[4],
734
+ ] += hidden_states_batch * mask
735
+ weight[
736
+ :,
737
+ :,
738
+ :,
739
+ target_h: target_h + hidden_states_batch.shape[3],
740
+ target_w: target_w + hidden_states_batch.shape[4],
741
+ ] += mask
742
+ values = values / weight
743
+ values = values.float()
744
+ return values
745
+
746
+
747
+ def single_encode(self, video, device):
748
+ video = video.to(device)
749
+ x = self.model.encode(video, self.scale)
750
+ return x.float()
751
+
752
+
753
+ def single_decode(self, hidden_state, device):
754
+ hidden_state = hidden_state.to(device)
755
+ video = self.model.decode(hidden_state, self.scale)
756
+ return video.float().clamp_(-1, 1)
757
+
758
+
759
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
760
+
761
+ videos = [video.to("cpu") for video in videos]
762
+ hidden_states = []
763
+ for video in videos:
764
+ video = video.unsqueeze(0)
765
+ if tiled:
766
+ tile_size = (tile_size[0] * 8, tile_size[1] * 8)
767
+ tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
768
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
769
+ else:
770
+ hidden_state = self.single_encode(video, device)
771
+ hidden_state = hidden_state.squeeze(0)
772
+ hidden_states.append(hidden_state)
773
+ hidden_states = torch.stack(hidden_states)
774
+ return hidden_states
775
+
776
+
777
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
778
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
779
+ videos = []
780
+ for hidden_state in hidden_states:
781
+ hidden_state = hidden_state.unsqueeze(0)
782
+ if tiled:
783
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
784
+ else:
785
+ video = self.single_decode(hidden_state, device)
786
+ video = video.squeeze(0)
787
+ videos.append(video)
788
+ videos = torch.stack(videos)
789
+ return videos
790
+
791
+
792
+ @staticmethod
793
+ def state_dict_converter():
794
+ return WanVideoVAEStateDictConverter()
795
+
796
+
797
+ class WanVideoVAEStateDictConverter:
798
+
799
+ def __init__(self):
800
+ pass
801
+
802
+ def from_civitai(self, state_dict):
803
+ state_dict_ = {}
804
+ if 'model_state' in state_dict:
805
+ state_dict = state_dict['model_state']
806
+ for name in state_dict:
807
+ state_dict_['model.' + name] = state_dict[name]
808
+ return state_dict_
809
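Usage sketch for the VAE above (an assumption-laden example, not code from the commit): it presumes the class is importable as `model.vae.WanVideoVAE` (the import path used by the pipeline below), that the `Wan2.1_VAE.pth` checkpoint sits under `./weights` as configured in `tooncomposer.py`, and that the released checkpoint's key layout matches the converter, hence `strict=False`.

import torch
from model.vae import WanVideoVAE

device = "cuda" if torch.cuda.is_available() else "cpu"

vae = WanVideoVAE(z_dim=16)
raw = torch.load("weights/Wan2.1_VAE.pth", map_location="cpu")  # assumed checkpoint location
vae.load_state_dict(WanVideoVAE.state_dict_converter().from_civitai(raw), strict=False)
vae.to(device)

# One video in [-1, 1], shape [C, T, H, W]; T should satisfy T % 4 == 1.
video = torch.rand(3, 81, 480, 832) * 2 - 1

with torch.no_grad():
    # Latents: [1, 16, (T + 3) // 4, H // 8, W // 8] = [1, 16, 21, 60, 104]
    latents = vae.encode([video], device=device, tiled=True)
    # Reconstruction: [1, 3, 4 * 21 - 3, 480, 832] = [1, 3, 81, 480, 832]
    recon = vae.decode(latents, device=device, tiled=True)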
+
pipeline/__init__.py ADDED
File without changes
pipeline/i2v_pipeline.py ADDED
@@ -0,0 +1,511 @@
1
+ from diffsynth import ModelManager
2
+ from diffsynth.pipelines.base import BasePipeline
3
+ from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
4
+
5
+ from model.dit import WanModel
6
+ from model.text_encoder import WanTextEncoder
7
+ from model.vae import WanVideoVAE
8
+ from model.image_encoder import WanImageEncoder
9
+ from model.prompter import WanPrompter
10
+ from scheduler.flow_match import FlowMatchScheduler
11
+
12
+ import torch, os
13
+ from einops import rearrange, repeat
14
+ import numpy as np
15
+ import PIL.Image
16
+ from tqdm import tqdm
17
+ from safetensors import safe_open
18
+
19
+ from model.text_encoder import T5RelativeEmbedding, T5LayerNorm
20
+ from model.dit import WanLayerNorm, WanRMSNorm, WanSelfAttention
21
+ from model.vae import RMS_norm, CausalConv3d, Upsample
22
+
23
+
24
+ def binary_tensor_to_indices(tensor):
25
+ assert tensor.dim() == 2, "Input tensor must be in [b, t]"
26
+ indices = [(row == 1).nonzero(as_tuple=True)[0] for row in tensor]
27
+ return indices
28
+
29
+ def propagate_visualize_attention_arg(model, visualize_attention=False):
30
+ """
31
+ Set the `visualize_attention` flag on selected WanSelfAttention modules (blocks 0, 19 and 39).
33
+ Intended for inference/test mode only.
33
+ """
34
+ for name, module in model.named_modules():
35
+ if isinstance(module, WanSelfAttention):
36
+ if "blocks.0.self_attn" in name or "blocks.19.self_attn" in name or "blocks.39.self_attn" in name:
37
+ print(f"Set `visualize_attention` to {visualize_attention} for {name}")
38
+ module.visualize_attention = visualize_attention
39
+
40
+ class WanVideoPipeline(BasePipeline):
41
+
42
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
43
+ super().__init__(device=device, torch_dtype=torch_dtype)
44
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
45
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
46
+ self.text_encoder: WanTextEncoder = None
47
+ self.image_encoder: WanImageEncoder = None
48
+ self.dit: WanModel = None
49
+ self.vae: WanVideoVAE = None
50
+ self.model_names = ['text_encoder', 'dit', 'vae']
51
+ self.height_division_factor = 16
52
+ self.width_division_factor = 16
53
+
54
+
55
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
56
+ dtype = next(iter(self.text_encoder.parameters())).dtype
57
+ enable_vram_management(
58
+ self.text_encoder,
59
+ module_map = {
60
+ torch.nn.Linear: AutoWrappedLinear,
61
+ torch.nn.Embedding: AutoWrappedModule,
62
+ T5RelativeEmbedding: AutoWrappedModule,
63
+ T5LayerNorm: AutoWrappedModule,
64
+ },
65
+ module_config = dict(
66
+ offload_dtype=dtype,
67
+ offload_device="cpu",
68
+ onload_dtype=dtype,
69
+ onload_device="cpu",
70
+ computation_dtype=self.torch_dtype,
71
+ computation_device=self.device,
72
+ ),
73
+ )
74
+ dtype = next(iter(self.dit.parameters())).dtype
75
+ enable_vram_management(
76
+ self.dit,
77
+ module_map = {
78
+ torch.nn.Linear: AutoWrappedLinear,
79
+ torch.nn.Conv3d: AutoWrappedModule,
80
+ torch.nn.LayerNorm: AutoWrappedModule,
81
+ WanLayerNorm: AutoWrappedModule,
82
+ WanRMSNorm: AutoWrappedModule,
83
+ },
84
+ module_config = dict(
85
+ offload_dtype=dtype,
86
+ offload_device="cpu",
87
+ onload_dtype=dtype,
88
+ onload_device=self.device,
89
+ computation_dtype=self.torch_dtype,
90
+ computation_device=self.device,
91
+ ),
92
+ max_num_param=num_persistent_param_in_dit,
93
+ overflow_module_config = dict(
94
+ offload_dtype=dtype,
95
+ offload_device="cpu",
96
+ onload_dtype=dtype,
97
+ onload_device="cpu",
98
+ computation_dtype=self.torch_dtype,
99
+ computation_device=self.device,
100
+ ),
101
+ )
102
+ dtype = next(iter(self.vae.parameters())).dtype
103
+ enable_vram_management(
104
+ self.vae,
105
+ module_map = {
106
+ torch.nn.Linear: AutoWrappedLinear,
107
+ torch.nn.Conv2d: AutoWrappedModule,
108
+ RMS_norm: AutoWrappedModule,
109
+ CausalConv3d: AutoWrappedModule,
110
+ Upsample: AutoWrappedModule,
111
+ torch.nn.SiLU: AutoWrappedModule,
112
+ torch.nn.Dropout: AutoWrappedModule,
113
+ },
114
+ module_config = dict(
115
+ offload_dtype=dtype,
116
+ offload_device="cpu",
117
+ onload_dtype=dtype,
118
+ onload_device=self.device,
119
+ computation_dtype=self.torch_dtype,
120
+ computation_device=self.device,
121
+ ),
122
+ )
123
+ if self.image_encoder is not None:
124
+ dtype = next(iter(self.image_encoder.parameters())).dtype
125
+ enable_vram_management(
126
+ self.image_encoder,
127
+ module_map = {
128
+ torch.nn.Linear: AutoWrappedLinear,
129
+ torch.nn.Conv2d: AutoWrappedModule,
130
+ torch.nn.LayerNorm: AutoWrappedModule,
131
+ },
132
+ module_config = dict(
133
+ offload_dtype=dtype,
134
+ offload_device="cpu",
135
+ onload_dtype=dtype,
136
+ onload_device="cpu",
137
+ computation_dtype=self.torch_dtype,
138
+ computation_device=self.device,
139
+ ),
140
+ )
141
+ self.enable_cpu_offload()
142
+
143
+ def fetch_models_from_model_manager(self, model_manager: ModelManager):
144
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
145
+ if text_encoder_model_and_path is not None:
146
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
147
+ self.prompter.fetch_models(self.text_encoder)
148
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
149
+ self.dit = model_manager.fetch_model("wan_video_dit")
150
+ self.vae = model_manager.fetch_model("wan_video_vae")
151
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
152
+
153
+ def _init_component_from_checkpoint_path(self, model_cls, state_dict_path, strict=True, config_dict=None):
154
+ config = {}
155
+ state_dict = self._load_state_dict(state_dict_path)
156
+ if hasattr(model_cls, "state_dict_converter"):
157
+ state_dict_converter = model_cls.state_dict_converter()
158
+ state_dict = state_dict_converter.from_civitai(state_dict)
159
+ if isinstance(state_dict, tuple):
160
+ state_dict, config = state_dict
161
+ config.update(config_dict or {})
162
+ model = model_cls(**config)
163
+ if config_dict and ("use_local_lora" in config_dict or "use_dera" in config_dict):
164
+ strict = False
165
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
166
+ print(f"Missing keys: {missing_keys}")
167
+ print(f"Unexpected keys: {unexpected_keys}")
168
+ return model
169
+
170
+ def _load_state_dict(self, state_dict_paths):
171
+ if isinstance(state_dict_paths, str):
172
+ state_dict_paths = [state_dict_paths]
173
+ state_dict = {}
174
+ for state_dict_path in tqdm(state_dict_paths, desc="Reading file(s) from disk"):
175
+ state_dict.update(self._load_single_file(state_dict_path))
176
+ return state_dict
177
+
178
+ def _load_single_file(self, file_path):
179
+ if file_path.endswith(".safetensors"):
180
+ return self._load_state_dict_from_safetensors(file_path)
181
+ else:
182
+ return torch.load(file_path, map_location='cpu')
183
+
184
+ def _load_state_dict_from_safetensors(self, file_path, torch_dtype=None):
185
+ state_dict = {}
186
+ with safe_open(file_path, framework="pt", device="cpu") as f:
187
+ for k in f.keys():
188
+ state_dict[k] = f.get_tensor(k)
189
+ if torch_dtype is not None:
190
+ state_dict[k] = state_dict[k].to(torch_dtype)
191
+ return state_dict
192
+
193
+ def initialize_dummy_dit(self, config):
194
+ print("Initializing a dummy DIT model.")
195
+ self.dit = WanModel(**config)
196
+ print("Dummy DIT model is initialized.")
197
+
198
+ def fetch_models_from_checkpoints(self, path_dict, config_dict=None):
199
+ default_config = {"text_encoder": {}, "dit": {}, "vae": {}, "image_encoder": {}}
200
+ config_dict = {**default_config, **(config_dict or {})}
201
+ components = {
202
+ "text_encoder": WanTextEncoder,
203
+ "dit": WanModel,
204
+ "vae": WanVideoVAE,
205
+ "image_encoder": WanImageEncoder
206
+ }
207
+ for name, model_cls in components.items():
208
+ if name not in path_dict:
209
+ print(f"Component {name} is not found in the checkpoint path dict. Skipping.")
210
+ continue
211
+ path = path_dict[name]
212
+ config = config_dict.get(name, {})
213
+ print(f"Loading {name} from {path} with config {config}.")
214
+ setattr(self, name, self._init_component_from_checkpoint_path(model_cls, path, config_dict=config))
215
+ print(f"Initialized {name} from checkpoint.")
216
+ if "text_encoder" in path_dict:
217
+ self.prompter.fetch_models(self.text_encoder)
218
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(path_dict["text_encoder"]), "google/umt5-xxl"))
219
+ print("Initialized prompter from checkpoint.")
220
+ print("All components are initialized from checkpoints.")
221
+
222
+ @staticmethod
223
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
224
+ if device is None: device = model_manager.device
225
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
226
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
227
+ pipe.fetch_models_from_model_manager(model_manager)
228
+ return pipe
229
+
230
+ def denoising_model(self):
231
+ return self.dit
232
+
233
+ def encode_prompt(self, prompt, positive=True):
234
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
235
+ return {"context": prompt_emb}
236
+
237
+ def encode_image(self, image, num_frames, height, width):
238
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
239
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
240
+ clip_context = self.image_encoder.encode_image([image])
241
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
242
+ msk[:, 1:] = 0
243
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
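+ # the first-frame mask is repeated 4x so the mask length becomes divisible by 4 and lines up with the VAE's temporally compressed latents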
244
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
245
+ msk = msk.transpose(1, 2)[0]
246
+ y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
247
+ y = torch.concat([msk, y])
248
+ return {"clip_fea": clip_context, "y": [y]}
249
+
250
+ def check_and_fix_image_or_video_tensor_input(self, _tensor):
251
+ assert isinstance(_tensor, torch.Tensor), "Input must be a tensor."
252
+ if _tensor.max() <= 255 and _tensor.max() > 1.0:
253
+ _tensor = _tensor.to(self.device) / 127.5 - 1
254
+ print("Input tensor is converted from [0, 255] to [-1, 1].")
255
+ elif _tensor.min() >= 0 and _tensor.max() <= 1:
256
+ _tensor = _tensor.to(self.device) * 2 - 1
257
+ print("Input tensor is converted from [0, 1] to [-1, 1].")
258
+ return _tensor
259
+
260
+ def encode_video_with_mask(self, video, num_frames, height, width, condition_preserved_mask):
261
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
262
+ video = video.to(self.device)
263
+ y = self.vae.encode(video, device=self.device)
264
+ msk = condition_preserved_mask
265
+ assert msk is not None, "The mask must be provided for the masked video input."
266
+ assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
267
+ assert msk.shape[0] == video.shape[0], "The batch size of the mask must be the same as the input video."
268
+ assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
269
+ msk = msk.to(self.device)
270
+ msk = msk.unsqueeze(-1).unsqueeze(-1)
271
+ msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
272
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
273
+ msk = msk.view(video.shape[0], msk.shape[1] // 4, 4, height//8, width//8) # b, t, c, h, w
274
+ msk = msk.transpose(1, 2) # b, c, t, h, w
275
+ y = torch.concat([msk, y], dim=1)
276
+ return y
277
+
278
+ def encode_video_with_mask_sparse(self, video, height, width, condition_preserved_mask, sketch_local_mask=None):
279
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
280
+ batch_size = video.shape[0]
281
+ cond_indices = binary_tensor_to_indices(condition_preserved_mask)
282
+ sequence_cond_compressed_indices = [(cond_index + 3) // 4 for cond_index in cond_indices]
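+ # map conditioned frame indices to latent-time indices under the VAE's 4x temporal compression (frame 0 -> latent 0, frame f > 0 -> 1 + (f - 1) // 4)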
283
+ video = video.to(self.device)
284
+ video_latent = self.vae.encode(video, device=self.device)
285
+ video_latent = video_latent[:, :, sequence_cond_compressed_indices[0], :, :]
286
+ msk = condition_preserved_mask.to(self.device)
287
+ msk = msk.unsqueeze(-1).unsqueeze(-1) # b, t, 1, 1
288
+ msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
289
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
290
+ msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8) # b, t, 4, h//8, w//8
291
+ msk = msk.transpose(1, 2) # b, 4, t, h//8, w//8
292
+ msk = msk[:, :, sequence_cond_compressed_indices[0], :, :]
293
+
294
+ if sketch_local_mask is not None:
295
+ sketch_local_mask = sketch_local_mask.to(self.device)
296
+ if sketch_local_mask.shape[-2:] != (height//8, width//8):
297
+ sk_batch_t = sketch_local_mask.shape[0] * sketch_local_mask.shape[2]
298
+ sketch_local_mask_reshaped = sketch_local_mask.reshape(sk_batch_t, 1, sketch_local_mask.shape[3], sketch_local_mask.shape[4])
299
+ sketch_local_mask_resized = torch.nn.functional.interpolate(
300
+ sketch_local_mask_reshaped,
301
+ size=(height//8, width//8),
302
+ mode='nearest'
303
+ )
304
+ sketch_local_mask_resized = sketch_local_mask_resized.reshape(
305
+ sketch_local_mask.shape[0],
306
+ sketch_local_mask.shape[1],
307
+ sketch_local_mask.shape[2],
308
+ height//8, width//8
309
+ )
310
+ else:
311
+ sketch_local_mask_resized = sketch_local_mask
312
+
313
+ sketch_mask = sketch_local_mask_resized
314
+ sketch_mask = torch.concat([torch.repeat_interleave(sketch_mask[:, :, 0:1], repeats=4, dim=2), sketch_mask[:, :, 1:]], dim=2)
315
+ sketch_mask = sketch_mask.view(batch_size, sketch_mask.shape[1], sketch_mask.shape[2] // 4, 4, height//8, width//8)
316
+ sketch_mask = sketch_mask.permute(0, 1, 3, 2, 4, 5) # [b, 1, 4, t//4, h//8, w//8]
317
+ sketch_mask = sketch_mask.view(batch_size, 4, sketch_mask.shape[3], height//8, width//8) # [b, 4, t//4, h//8, w//8]
318
+ sketch_mask = sketch_mask[:, :, sequence_cond_compressed_indices[0], :, :] # [b, 4, len(indices), h//8, w//8]
319
+
320
+ combined_latent = torch.cat([msk, video_latent, sketch_mask], dim=1)
321
+ else:
322
+ combined_latent = torch.concat([msk, video_latent], dim=1)
323
+
324
+ return combined_latent, sequence_cond_compressed_indices # b, c=(4+16+4=24), t, h, w when sketch_local_mask is provided
325
+
326
+ def encode_image_or_masked_video(self, image_or_masked_video, num_frames, height, width, condition_preserved_mask=None):
327
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
328
+ batch_size = image_or_masked_video.shape[0]
329
+ if isinstance(image_or_masked_video, PIL.Image.Image) or (isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() <= 4):
330
+ if isinstance(image_or_masked_video, PIL.Image.Image):
331
+ image_or_masked_video = self.preprocess_image(image_or_masked_video.resize((width, height))).to(self.device)
332
+ else:
333
+ if image_or_masked_video.dim() == 3:
334
+ image_or_masked_video = image_or_masked_video.unsqueeze(0) # b=1, c, h, w
335
+ image_or_masked_video = image_or_masked_video.to(self.device)
336
+ y = self.vae.encode([torch.concat([image_or_masked_video.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_or_masked_video.device)], dim=1)], device=self.device)
337
+ msk_idx_to_be_zero = range(1, num_frames)
338
+ clip_context = self.image_encoder.encode_image(image_or_masked_video.unsqueeze(1)) # need to be [b, 1, c, h, w]
339
+ msk = torch.ones(batch_size, num_frames, height//8, width//8, device=self.device)
340
+ msk[:, msk_idx_to_be_zero] = 0
341
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
342
+ msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)
343
+ msk = msk.transpose(1, 2)
344
+ elif isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() == 5:
345
+ image_or_masked_video = image_or_masked_video.to(self.device)
346
+ first_image = image_or_masked_video[:, :, 0, :, :].unsqueeze(1)
347
+ clip_context = self.image_encoder.encode_image(first_image)
348
+ y = self.vae.encode(image_or_masked_video, device=self.device)
349
+ msk = condition_preserved_mask # b, t
350
+ assert msk is not None, "The mask must be provided for the masked video input."
351
+ assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
352
+ assert msk.shape[0] == batch_size, "The batch size of the mask must be the same as the input video."
353
+ assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
354
+ msk = msk.to(self.device)
355
+ msk = msk.unsqueeze(-1).unsqueeze(-1) # b, t, 1, 1
356
+ msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
357
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
358
+ msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8) # b, t, 4, h//8, w//8
359
+ msk = msk.transpose(1, 2) # b, 4, t, h//8, w//8
360
+ else:
361
+ raise ValueError("Input must be an image (PIL/Tensor in [b, c, h, w]) or a masked video (Tensor in [b, c, t, h, w]).")
362
+
363
+ y = torch.concat([msk, y], dim=1)
364
+ return {"clip_fea": clip_context, "y": y}
365
+
366
+ def tensor2video(self, frames):
367
+ frames = rearrange(frames, "C T H W -> T H W C")
368
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
369
+ frames = [PIL.Image.fromarray(frame) for frame in frames]
370
+ return frames
371
+
372
+ def prepare_extra_input(self, latents=None):
373
+ return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
374
+
375
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
376
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
377
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
378
+ return latents
379
+
380
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
381
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
382
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
383
+ return frames
384
+
385
+ @torch.no_grad()
386
+ def __call__(
387
+ self,
388
+ prompt,
389
+ negative_prompt="",
390
+ input_image=None,
391
+ input_video=None,
392
+ denoising_strength=1.0,
393
+ seed=None,
394
+ rand_device="cpu",
395
+ height=480,
396
+ width=832,
397
+ num_frames=81,
398
+ cfg_scale=5.0,
399
+ num_inference_steps=50,
400
+ sigma_shift=5.0,
401
+ tiled=True,
402
+ tile_size=(30, 52),
403
+ tile_stride=(15, 26),
404
+ progress_bar_cmd=tqdm,
405
+ # progress_bar_st=None,
406
+ input_condition_video=None,
407
+ input_condition_preserved_mask=None,
408
+ input_condition_video_sketch=None,
409
+ input_condition_preserved_mask_sketch=None,
410
+ sketch_local_mask=None,
411
+ visualize_attention=False,
412
+ output_path=None,
413
+ batch_idx=None,
414
+ sequence_cond_residual_scale=1.0,
415
+ ):
416
+ height, width = self.check_resize_height_width(height, width)
417
+ if num_frames % 4 != 1:
418
+ num_frames = (num_frames + 2) // 4 * 4 + 1
419
+ print(f"Only `num_frames % 4 == 1` is acceptable. Rounding up to {num_frames}.")
420
+
421
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
422
+
423
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
424
+
425
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
426
+ if input_video is not None:
427
+ self.load_models_to_device(['vae'])
428
+ input_video = self.preprocess_images(input_video)
429
+ input_video = torch.stack(input_video, dim=2)
430
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
431
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
432
+ else:
433
+ latents = noise
434
+
435
+ self.load_models_to_device(["text_encoder"])
436
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
437
+ if cfg_scale != 1.0:
438
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
439
+
440
+ self.load_models_to_device(["image_encoder", "vae"])
441
+ if input_image is not None and self.image_encoder is not None:
442
+ image_emb = self.encode_image(input_image, num_frames, height, width)
443
+ elif input_condition_video is not None and self.image_encoder is not None:
444
+ assert input_condition_preserved_mask is not None, "`input_condition_preserved_mask` must not be None when `input_condition_video` is given."
445
+ image_emb = self.encode_image_or_masked_video(input_condition_video, num_frames, height, width, input_condition_preserved_mask)
446
+ else:
447
+ image_emb = {}
448
+
449
+ # Extra input
450
+ extra_input = self.prepare_extra_input(latents)
451
+ if self.dit.use_sequence_cond:
452
+ assert input_condition_video_sketch is not None, "`input_condition_video_sketch` must not be None when `use_sequence_cond` is True."
453
+ assert input_condition_preserved_mask_sketch is not None, "`input_condition_preserved_mask_sketch` must not be None when `input_condition_video_sketch` is given."
454
+
455
+ if self.dit.sequence_cond_mode == "sparse":
456
+ sequence_cond, sequence_cond_compressed_indices = self.encode_video_with_mask_sparse(input_condition_video_sketch, height, width, input_condition_preserved_mask_sketch, sketch_local_mask)
457
+ extra_input.update({"sequence_cond": sequence_cond,
458
+ "sequence_cond_compressed_indices": sequence_cond_compressed_indices})
459
+ elif self.dit.sequence_cond_mode == "full":
460
+ sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
461
+ extra_input.update({"sequence_cond": sequence_cond})
462
+ else:
463
+ raise ValueError(f"Invalid `sequence_cond_mode`={self.dit.sequence_cond_mode} in the DIT model.")
464
+
465
+ elif self.dit.use_channel_cond:
466
+ sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
467
+ extra_input.update({"channel_cond": sequence_cond})
468
+
469
+ self.load_models_to_device([])
470
+
471
+ if sequence_cond_residual_scale != 1.0:
472
+ extra_input.update({"sequence_cond_residual_scale": sequence_cond_residual_scale})
473
+
474
+ # Denoise
475
+ self.load_models_to_device(["dit"])
476
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
477
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
478
+ timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
479
+ _should_visualize_attention = visualize_attention and (progress_id == len(self.scheduler.timesteps) - 1)
480
+ if _should_visualize_attention:
481
+ print(f"Visualizing attention maps (Step {progress_id + 1}/{len(self.scheduler.timesteps)}).")
482
+ propagate_visualize_attention_arg(self.dit, True)
483
+
484
+ # Inference
485
+ noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
486
+ if isinstance(noise_pred_posi, tuple):
487
+ noise_pred_posi = noise_pred_posi[0]
488
+ if cfg_scale != 1.0:
489
+ noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
490
+ if isinstance(noise_pred_nega, tuple):
491
+ noise_pred_nega = noise_pred_nega[0]
492
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
493
+ else:
494
+ noise_pred = noise_pred_posi
495
+
496
+ # Scheduler
497
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
498
+
499
+ # If visualization is enabled, save the attention maps
500
+ if _should_visualize_attention:
501
+ print("Saving attention maps...")
502
+ from util.model_util import save_attention_maps
503
+ save_attention_maps(self.dit, output_path, batch_idx, timestep.squeeze().cpu().numpy().item())
504
+ propagate_visualize_attention_arg(self.dit, False)
505
+
506
+ # Decode
507
+ self.load_models_to_device(['vae'])
508
+ frames = self.decode_video(latents, **tiler_kwargs)
509
+ self.load_models_to_device([])
510
+
511
+ return frames
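A hedged end-to-end sketch of the pipeline above. The checkpoint filenames mirror `get_base_model_paths` in `tooncomposer.py` further down and are assumed to live under `./weights` (with the umt5-xxl tokenizer under `weights/google/umt5-xxl`); the prompt and keyframe are placeholders, and the stock Wan2.1 DiT is assumed to load with its default flags (no sequence or channel conditioning).

import torch
from PIL import Image
from pipeline.i2v_pipeline import WanVideoPipeline

paths = {
    "dit": [f"weights/diffusion_pytorch_model-0000{i}-of-00007.safetensors" for i in range(1, 8)],
    "image_encoder": "weights/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
    "text_encoder": "weights/models_t5_umt5-xxl-enc-bf16.pth",
    "vae": "weights/Wan2.1_VAE.pth",
}

pipe = WanVideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
pipe.fetch_models_from_checkpoints(paths)
pipe.enable_vram_management()  # offloads idle components to CPU

frames = pipe(
    prompt="Anime. High quality.",
    input_image=Image.open("samples/1_image1.png"),
    height=480, width=832, num_frames=81,
    cfg_scale=5.0, num_inference_steps=50, seed=0,
)
video = pipe.tensor2video(frames[0])  # list of PIL frames, ready for save_video in tooncomposer.py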
requirements.txt ADDED
@@ -0,0 +1,148 @@
1
+ absl-py==2.2.2
2
+ accelerate==1.6.0
3
+ beartype==0.20.2
4
+ beautifulsoup4==4.13.4
5
+ braceexpand==0.1.7
6
+ cached-property==2.0.1
7
+ certifi==2025.1.31
8
+ charset-normalizer==3.4.1
9
+ click==8.1.8
10
+ clip==0.2.0
11
+ comm==0.2.3
12
+ contourpy==1.3.2
13
+ controlnet_aux==0.0.7
14
+ crcmod==1.7
15
+ cycler==0.12.1
16
+ datasets==3.5.0
17
+ debugpy==1.8.15
18
+ decorator==5.2.1
19
+ decord==0.6.0
20
+ deepspeed==0.16.7
21
+ diffsynth==1.1.7
22
+ diffusers==0.33.1
23
+ dill==0.3.8
24
+ docker-pycreds==0.4.0
25
+ dulwich==0.22.8
26
+ easydict==1.13
27
+ einops==0.8.1
28
+ exceptiongroup==1.2.2
29
+ executing==2.2.0
30
+ fairscale==0.4.13
31
+ fastapi==0.115.12
32
+ fastrlock==0.8.3
33
+ ffmpy==0.5.0
34
+ filelock==3.13.1
35
+ flash_attn==2.8.0.post2 --global-option="--no-build-isolation"
36
+ fonttools==4.57.0
37
+ frozenlist==1.6.0
38
+ fsspec==2024.12.0
39
+ ftfy==6.3.1
40
+ func_timeout==4.3.5
41
+ fuzzywuzzy==0.18.0
42
+ gitdb==4.0.12
43
+ GitPython==3.1.44
44
+ gradio==5.25.2
45
+ gradio_client==1.8.0
46
+ groovy==0.1.2
47
+ grpcio==1.71.0
48
+ h11==0.14.0
49
+ hjson==3.1.0
50
+ httpcore==1.0.8
51
+ httpx==0.28.1
52
+ huggingface-hub==0.30.2
53
+ idna==3.10
54
+ imageio==2.37.0
55
+ imageio-ffmpeg==0.6.0
56
+ importlib_metadata==8.6.1
57
+ ipykernel==6.30.0
58
+ ipython==8.37.0
59
+ jedi==0.19.2
60
+ Jinja2==3.1.4
61
+ joblib==1.4.2
62
+ kiwisolver==1.4.8
63
+ kornia==0.8.0
64
+ kornia_rs==0.1.8
65
+ lazy_loader==0.4
66
+ lightning==2.5.1
67
+ lightning-utilities==0.14.3
68
+ lpips==0.1.4
69
+ matplotlib==3.10.1
70
+ matplotlib-inline==0.1.7
71
+ mdurl==0.1.2
72
+ modelscope==1.25.0
73
+ moviepy==2.1.2
74
+ mpmath==1.3.0
75
+ msgpack==1.1.0
76
+ multidict==6.4.3
77
+ multiprocess==0.70.16
78
+ ninja==1.11.1.4
79
+ numpy==2.2.5
80
+ omegaconf==2.3.0
81
+ opencv-python==4.11.0.86
82
+ orjson==3.10.16
83
+ packaging==24.2
84
+ pandas==2.2.3
85
+ parso==0.8.4
86
+ peft==0.15.2
87
+ pexpect==4.9.0
88
+ pillow==10.4.0
89
+ platformdirs==4.3.7
90
+ proglog==0.1.11
91
+ prompt_toolkit==3.0.51
92
+ propcache==0.3.1
93
+ protobuf==5.29.4
94
+ psutil==7.0.0
95
+ ptyprocess==0.7.0
96
+ pure_eval==0.2.3
97
+ py-cpuinfo==9.0.0
98
+ pyarrow==19.0.1
99
+ pycryptodome==3.22.0
100
+ pydantic==2.11.3
101
+ pydantic_core==2.33.1
102
+ pydub==0.25.1
103
+ Pygments==2.19.1
104
+ pynvml==12.0.0
105
+ pyparsing==3.2.3
106
+ python-dateutil==2.9.0.post0
107
+ python-dotenv==1.1.0
108
+ python-multipart==0.0.20
109
+ pytorch-fid==0.3.0
110
+ pytorch-lightning==2.5.1
111
+ pytz==2025.2
112
+ PyYAML==6.0.2
113
+ pyzmq==27.0.0
114
+ regex==2024.11.6
115
+ requests==2.32.3
116
+ rich==14.0.0
117
+ ruff==0.11.6
118
+ safehttpx==0.1.6
119
+ safetensors==0.5.3
120
+ scikit-image==0.25.2
121
+ scikit-learn==1.6.1
122
+ scipy==1.15.2
123
+ semantic-version==2.10.0
124
+ sentencepiece==0.2.0
125
+ sentry-sdk==2.26.1
126
+ setproctitle==1.3.5
127
+ shellingham==1.5.4
128
+ simplejson==3.20.1
129
+ six==1.17.0
130
+ smmap==5.0.2
131
+ sniffio==1.3.1
132
+ soupsieve==2.7
133
+ stack-data==0.6.3
134
+ starlette==0.46.2
135
+ sympy==1.13.1
136
+ taming-transformers==0.0.1
137
+ tensorboard==2.19.0
138
+ tokenizers==0.20.3
139
+ torch==2.6.0
140
+ torchaudio==2.6.0
141
+ torchdiffeq==0.2.5
142
+ torchmetrics==1.7.1
143
+ torchsde==0.2.6
144
+ torchvision==0.21.0
145
+ tqdm==4.67.1
146
+ transformers==4.46.2
147
+ triton==3.2.0
148
+ xformers==0.0.29.post2
samples/1_image1.png ADDED

Git LFS Details

  • SHA256: 3a02da307776afbf196bcf3001b6e6b334154cc016ab23cf7863156dd1e80dd3
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
samples/1_out.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa51ac0653a18dc20b9a6946aaa1a7923d58fe291e926908703c300a4d13c4a2
3
+ size 356550
samples/1_prompt.txt ADDED
@@ -0,0 +1 @@
1
+ ['在海底,一个上身赤裸的的男子和一个螺旋游动的蓝鱼嬉戏。鲸鱼跟着男人手里拿的袋子绕圈,男子拿着袋子引诱着蓝鱼向前游动。Anime. High quality.']
samples/1_sketch1.jpg ADDED

Git LFS Details

  • SHA256: a26bfa93807f5aff098ed3147a4e8e543d4cbe7a8d184c82c3e0e161eb8556db
  • Pointer size: 130 Bytes
  • Size of remote file: 61.3 kB
samples/1_sketch2.jpg ADDED

Git LFS Details

  • SHA256: 9327a6acc26a54a4f45132ceccacfb0d014f85f11965aa28a4ad5dab7a3b7114
  • Pointer size: 130 Bytes
  • Size of remote file: 57.2 kB
samples/1_sketch3.jpg ADDED

Git LFS Details

  • SHA256: d1866a5cf3e392525f25422824e5bb9b28838ea96e2f2d7f99ac428d14ed6053
  • Pointer size: 130 Bytes
  • Size of remote file: 58 kB
samples/2_image1.jpg ADDED

Git LFS Details

  • SHA256: a527a886764611d46d9921d2e11bda0b40f4b08c266d46c3bb8e42a179723537
  • Pointer size: 130 Bytes
  • Size of remote file: 62.9 kB
samples/2_out.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9f28ab63b4fc5b07c0ed01f715ec671f8d839b8783dc8a432c7764bd35605f5
3
+ size 151565
samples/2_prompt.txt ADDED
@@ -0,0 +1 @@
1
+ ['一个女孩和一个银发男孩种下了一颗巨大的花,随着镜头缓慢向上移动,这个巨大的花不断生长变大并开放。Anime. High quality.']
samples/2_sketch1.jpg ADDED

Git LFS Details

  • SHA256: 65cb81959a7e1e9e779c04ea5bd33630e1d46d4f434f26b85514a8ed833a6b65
  • Pointer size: 130 Bytes
  • Size of remote file: 51.2 kB
samples/2_sketch2.jpg ADDED

Git LFS Details

  • SHA256: 62b6c56e32e29c1588df29f229d3743263e8aac56b041a4f0d627159cdc492ef
  • Pointer size: 130 Bytes
  • Size of remote file: 57.1 kB
samples/3_image1.png ADDED

Git LFS Details

  • SHA256: 968fd96a40d945afdd70f485200c1e2ee17750290493ce8d6f79e5c337da0f91
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
samples/3_out.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdb131043289d831f7c4e0d3dd4a21ecc3c4eecca1bf3ae539bb14414c439cde
3
+ size 87909
samples/3_prompt.txt ADDED
@@ -0,0 +1 @@
1
+ ['一个古代中国男孩拿着苹果,笑眯眯地送给旁边的老人。Anime. High quality.']
samples/3_sketch1.jpg ADDED

Git LFS Details

  • SHA256: 346d166a4f69d18b666a4d550b63154e05369994e23403068086b59893803873
  • Pointer size: 130 Bytes
  • Size of remote file: 73.9 kB
samples/ToonComposer-Icon.png ADDED

Git LFS Details

  • SHA256: 79e20f3daac212a6bd7e31646e12cb3bed399798f119393241543daea035f2dd
  • Pointer size: 131 Bytes
  • Size of remote file: 704 kB
samples/ToonComposer-Method.jpg ADDED

Git LFS Details

  • SHA256: 9052fdd65dfc6d26f2f52e3b23e93dc7e5c69ac65f1e801292f6341e1a50609c
  • Pointer size: 131 Bytes
  • Size of remote file: 421 kB
samples/ToonComposer-TLDR.jpg ADDED

Git LFS Details

  • SHA256: 962b56e9858c45305f7d5df711ba6551b052238980443dd1b1d76f1465dd6f31
  • Pointer size: 131 Bytes
  • Size of remote file: 277 kB
scheduler/__init__.py ADDED
File without changes
scheduler/flow_match.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+
3
+
4
+ class FlowMatchScheduler():
5
+
6
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
7
+ self.num_train_timesteps = num_train_timesteps
8
+ self.shift = shift
9
+ self.sigma_max = sigma_max
10
+ self.sigma_min = sigma_min
11
+ self.inverse_timesteps = inverse_timesteps
12
+ self.extra_one_step = extra_one_step
13
+ self.reverse_sigmas = reverse_sigmas
14
+ self.set_timesteps(num_inference_steps)
15
+
16
+
17
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
18
+ if shift is not None:
19
+ self.shift = shift
20
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
21
+ if self.extra_one_step:
22
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
23
+ else:
24
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
25
+ if self.inverse_timesteps:
26
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
27
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
28
+ if self.reverse_sigmas:
29
+ self.sigmas = 1 - self.sigmas
30
+ self.timesteps = self.sigmas * self.num_train_timesteps
31
+ if training:
32
+ x = self.timesteps
33
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
34
+ y_shifted = y - y.min()
35
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
36
+ self.linear_timesteps_weights = bsmntw_weighing
37
+
38
+
39
+ def step(self, model_output, timestep, sample, to_final=False):
40
+ if isinstance(timestep, torch.Tensor):
41
+ timestep = timestep.cpu()
42
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
43
+ sigma = self.sigmas[timestep_id]
44
+ if to_final or timestep_id + 1 >= len(self.timesteps):
45
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
46
+ else:
47
+ sigma_ = self.sigmas[timestep_id + 1]
48
+ prev_sample = sample + model_output * (sigma_ - sigma)
49
+ return prev_sample
50
+
51
+
52
+ def return_to_timestep(self, timestep, sample, sample_stablized):
53
+ if isinstance(timestep, torch.Tensor):
54
+ timestep = timestep.cpu()
55
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
56
+ sigma = self.sigmas[timestep_id]
57
+ model_output = (sample - sample_stablized) / sigma
58
+ return model_output
59
+
60
+
61
+ def add_noise(self, original_samples, noise, timestep):
62
+ if isinstance(timestep, torch.Tensor):
63
+ timestep = timestep.cpu()
64
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
65
+ sigma = self.sigmas[timestep_id]
66
+ sample = (1 - sigma) * original_samples + sigma * noise
67
+ return sample
68
+
69
+
70
+ def training_target(self, sample, noise, timestep):
71
+ target = noise - sample
72
+ return target
73
+
74
+
75
+ def training_weight(self, timestep):
76
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
77
+ weights = self.linear_timesteps_weights[timestep_id]
78
+ return weights
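A small numeric sketch of the scheduler above, using the same settings the pipeline passes in (shift=5, extra_one_step=True). It shows that `add_noise` interpolates linearly between sample and noise, and that one `step` with the exact flow-matching target `noise - sample` moves the latent from sigma_0 to sigma_1; the latent shape is just an illustrative example.

import torch
from scheduler.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=15, denoising_strength=1.0, shift=5.0)

sample = torch.randn(1, 16, 21, 60, 104)   # clean latent
noise = torch.randn_like(sample)

t0 = scheduler.timesteps[0]                                  # sigma_0 = 1.0, timestep 1000
noisy = scheduler.add_noise(sample, noise, timestep=t0)      # (1 - sigma_0) * sample + sigma_0 * noise
v = scheduler.training_target(sample, noise, timestep=t0)    # flow-matching target: noise - sample
prev = scheduler.step(v, t0, noisy)                          # one Euler step toward sigma_1

sigma1 = scheduler.sigmas[1]
assert torch.allclose(prev, (1 - sigma1) * sample + sigma1 * noise, atol=1e-5)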
tooncomposer.py ADDED
@@ -0,0 +1,234 @@
1
+ import os, torch, lightning, imageio
2
+ from peft import LoraConfig, inject_adapter_in_model
3
+ import numpy as np
4
+
5
+ from pipeline.i2v_pipeline import WanVideoPipeline
6
+
7
+
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
+ torch.set_float32_matmul_precision('medium')
10
+
11
+
12
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
13
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
14
+ for frame in frames:
15
+ frame = np.array(frame)
16
+ writer.append_data(frame)
17
+ writer.close()
18
+
19
+
20
+ def get_base_model_paths(base_model_name, format='dict', model_root="./weights"):
21
+ if base_model_name == "Wan2.1-I2V-14B-480P":
22
+ if format == 'list':
23
+ return [
24
+ [os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
25
+ os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
26
+ os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
27
+ os.path.join(model_root, "Wan2.1_VAE.pth")
28
+ ]
29
+ elif format == 'dict':
30
+ return {
31
+ "dit": [os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
32
+ "image_encoder": os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
33
+ "text_encoder": os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
34
+ "vae": os.path.join(model_root, "Wan2.1_VAE.pth")
35
+ }
36
+ else:
37
+ raise ValueError(f"Unsupported format: {format}")
38
+ else:
39
+ raise ValueError(f"Unsupported base model name: {base_model_name}")
40
+
41
+
42
+ class ToonComposer(lightning.LightningModule):
43
+ def __init__(self, base_model_name="Wan2.1-I2V-14B-480P", model_root=None, learning_rate=1e-5, lora_rank=4, lora_alpha=4,
44
+ train_architecture=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2",
45
+ init_lora_weights="kaiming", use_gradient_checkpointing=True,
46
+ checkpoint_path=None, video_condition_preservation_mode="first_and_last",
47
+ tiled=False, tile_size=(34, 34), tile_stride=(18, 16), output_path=None,
48
+ use_local_lora=False, use_dera=False, dera_rank=None, use_dera_spatial=True, use_dera_temporal=True, use_sequence_cond=False, sequence_cond_mode="sparse",
49
+ use_channel_cond=False,
50
+ use_sequence_cond_position_aware_residual=False,
51
+ use_sequence_cond_loss=False, fast_dev=False,
52
+ max_num_cond_images=1, max_num_cond_sketches=2, visualize_attention=False,
53
+ random_spaced_cond_frames=False, use_sketch_mask=False, sketch_mask_ratio=0.2, no_first_sketch=False,
54
+ test_sampling_steps=15, test_sequence_cond_residual_scale=0.5, height=480, width=832):
55
+ super().__init__()
56
+
57
+ self.pipe = WanVideoPipeline(device="cpu", torch_dtype=torch.bfloat16)
58
+ self.use_local_lora = use_local_lora
59
+ self.use_dera = use_dera
60
+ self.use_dera_spatial = use_dera_spatial
61
+ self.use_dera_temporal = use_dera_temporal
62
+ self.use_sequence_cond = use_sequence_cond
63
+ self.sequence_cond_mode = sequence_cond_mode
64
+ self.use_channel_cond = use_channel_cond
65
+ self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
66
+ assert not (use_sequence_cond and use_channel_cond), "Cannot use both sequence condition and channel condition."
67
+ self.use_sequence_cond_loss = use_sequence_cond_loss
68
+
69
+ self.max_num_cond_images = max_num_cond_images
70
+ self.max_num_cond_sketches = max_num_cond_sketches
71
+
72
+ self.visualize_attention = visualize_attention
73
+ self.random_spaced_cond_frames = random_spaced_cond_frames
74
+ self.use_sketch_mask = use_sketch_mask
75
+ self.sketch_mask_ratio = sketch_mask_ratio
76
+ self.no_first_sketch = no_first_sketch
77
+ self.test_sampling_steps = test_sampling_steps
78
+ self.test_sequence_cond_residual_scale = test_sequence_cond_residual_scale
79
+
80
+ self.height = height
81
+ self.width = width
82
+
83
+ self.current_checkpoint_path = None
84
+
85
+ paths = get_base_model_paths(base_model_name, format='dict', model_root=model_root)
86
+ if use_sequence_cond:
87
+ assert sequence_cond_mode in ["sparse", "full"], f"Unsupported sequence condition mode: {sequence_cond_mode}"
88
+ if sequence_cond_mode == "sparse":
89
+ if use_sketch_mask:
90
+ sequence_cond_in_dim = 24
91
+ else:
92
+ sequence_cond_in_dim = 20
93
+ else:
94
+ sequence_cond_in_dim = 20
95
+ use_channel_cond = False
96
+ channel_cond_in_dim = None
97
+ elif use_channel_cond:
98
+ channel_cond_in_dim = 20
99
+ sequence_cond_in_dim = None
100
+ use_sequence_cond = False
101
+
102
+ dit_config = {
103
+ "use_local_lora": use_local_lora,
104
+ "use_dera": use_dera,
105
+ "dera_rank": dera_rank,
106
+ "use_dera_spatial": use_dera_spatial,
107
+ "use_dera_temporal": use_dera_temporal,
108
+ "use_sequence_cond": use_sequence_cond,
109
+ "sequence_cond_mode": sequence_cond_mode,
110
+ "sequence_cond_in_dim": sequence_cond_in_dim,
111
+ "use_channel_cond": use_channel_cond,
112
+ "channel_cond_in_dim": channel_cond_in_dim,
113
+ "use_sequence_cond_position_aware_residual": use_sequence_cond_position_aware_residual,
114
+ "use_sequence_cond_loss": use_sequence_cond_loss
115
+ }
116
+ if fast_dev:
117
+ del paths["dit"]
118
+ dit_config.update({
119
+ "model_type": "i2v",
120
+ "patch_size": (1, 2, 2),
121
+ "text_len": 512,
122
+ "in_dim": 36,
123
+ "dim": 512,
124
+ "ffn_dim": 512,
125
+ "freq_dim": 256,
126
+ "text_dim": 4096,
127
+ "out_dim": 16,
128
+ "num_heads": 2, # 40
129
+ "num_layers": 40,
130
+ "window_size": (-1, -1),
131
+ "qk_norm": True,
132
+ "cross_attn_norm": True,
133
+ "eps": 1e-6,
134
+ })
135
+ self.pipe.initialize_dummy_dit(dit_config)
136
+
137
+ self.pipe.fetch_models_from_checkpoints(
138
+ paths,
139
+ config_dict={
140
+ "dit": dit_config
141
+ })
142
+
143
+ if use_sequence_cond:
144
+ self.pipe.denoising_model().copy_sequence_cond_patch_embedding_weights()
145
+ elif use_channel_cond:
146
+ self.pipe.denoising_model().copy_patch_embedding_weights_for_channel_cond()
147
+
148
+ self.freeze_parameters()
149
+ if train_architecture == "lora":
150
+ self.add_lora_to_model(
151
+ self.pipe.denoising_model(),
152
+ lora_rank=lora_rank,
153
+ lora_alpha=lora_alpha,
154
+ lora_target_modules=lora_target_modules,
155
+ init_lora_weights=init_lora_weights
156
+ )
157
+ elif train_architecture == "full":
158
+ self.pipe.denoising_model().requires_grad_(True)
159
+
160
+ if checkpoint_path is not None:
161
+ self.load_tooncomposer_checkpoint(checkpoint_path)
162
+
163
+ self.learning_rate = learning_rate
164
+ self.use_gradient_checkpointing = use_gradient_checkpointing
165
+
166
+ self.pipe.scheduler.set_timesteps(1000, training=True)
167
+ self.vae_tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
168
+ self.video_condition_preservation_mode = video_condition_preservation_mode
169
+ self.negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
170
+
171
+ if output_path is None:
172
+ output_path = "./"
173
+ self.output_path = output_path
174
+
175
+ def load_tooncomposer_checkpoint(self, checkpoint_path):
176
+ if checkpoint_path == self.current_checkpoint_path:
177
+ print(f"Skipping loading checkpoint {checkpoint_path} because it is the same as the current checkpoint.")
178
+ return
179
+ self.current_checkpoint_path = checkpoint_path
180
+ self.load_patch_to_model(
181
+ self.pipe.denoising_model(),
182
+ checkpoint_path
183
+ )
184
+
185
+ def update_height_width(self, height, width):
186
+ self.height = height
187
+ self.width = width
188
+
189
+ def freeze_parameters(self):
190
+ self.pipe.requires_grad_(False)
191
+ self.pipe.eval()
192
+ self.pipe.denoising_model().train()
193
+
194
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
195
+ self.lora_alpha = lora_alpha
196
+ if init_lora_weights == "kaiming":
197
+ init_lora_weights = True
198
+
199
+ lora_config = LoraConfig(
200
+ r=lora_rank,
201
+ lora_alpha=lora_alpha,
202
+ init_lora_weights=init_lora_weights,
203
+ target_modules=lora_target_modules.split(","),
204
+ )
205
+ model = inject_adapter_in_model(lora_config, model)
206
+ for param in model.parameters():
207
+ if param.requires_grad:
208
+ param.data = param.to(torch.float32)
209
+
210
+ def load_patch_to_model(self, model, pretrained_path, state_dict_converter=None):
211
+ if pretrained_path is not None:
212
+ state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
213
+ self.loaded_global_step = 0
214
+ self.loaded_current_epoch = 0
215
+ if self.use_sketch_mask:
216
+ seq_cond_embed_weight = state_dict['sequence_cond_patch_embedding.weight']
217
+ current_in_channels = self.pipe.denoising_model().sequence_cond_patch_embedding.in_channels
218
+ if current_in_channels == 24 and seq_cond_embed_weight.shape[1] == 20:
219
+ new_weight = torch.zeros(
220
+ seq_cond_embed_weight.shape[0],
221
+ 4,
222
+ *seq_cond_embed_weight.shape[2:],
223
+ dtype=seq_cond_embed_weight.dtype
224
+ )
225
+ state_dict['sequence_cond_patch_embedding.weight'] = torch.cat([
226
+ seq_cond_embed_weight, new_weight], dim=1)
227
+
228
+ if state_dict_converter is not None:
229
+ state_dict = state_dict_converter(state_dict)
230
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
231
+ all_keys = [i for i, _ in model.named_parameters()]
232
+ num_updated_keys = len(all_keys) - len(missing_keys)
233
+ num_unexpected_keys = len(unexpected_keys)
234
+ print(f"[Checkpoint] {num_updated_keys} parameters are loaded from {pretrained_path}. {num_unexpected_keys} parameters are unexpected.")
util/model_util.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch, os
2
+ from safetensors import safe_open
3
+ from contextlib import contextmanager
4
+ import hashlib
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.colors import LinearSegmentedColormap
7
+ import numpy as np
8
+
9
+ @contextmanager
10
+ def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
11
+
12
+ old_register_parameter = torch.nn.Module.register_parameter
13
+ if include_buffers:
14
+ old_register_buffer = torch.nn.Module.register_buffer
15
+
16
+ def register_empty_parameter(module, name, param):
17
+ old_register_parameter(module, name, param)
18
+ if param is not None:
19
+ param_cls = type(module._parameters[name])
20
+ kwargs = module._parameters[name].__dict__
21
+ kwargs["requires_grad"] = param.requires_grad
22
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
23
+
24
+ def register_empty_buffer(module, name, buffer, persistent=True):
25
+ old_register_buffer(module, name, buffer, persistent=persistent)
26
+ if buffer is not None:
27
+ module._buffers[name] = module._buffers[name].to(device)
28
+
29
+ def patch_tensor_constructor(fn):
30
+ def wrapper(*args, **kwargs):
31
+ kwargs["device"] = device
32
+ return fn(*args, **kwargs)
33
+
34
+ return wrapper
35
+
36
+ if include_buffers:
37
+ tensor_constructors_to_patch = {
38
+ torch_function_name: getattr(torch, torch_function_name)
39
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
40
+ }
41
+ else:
42
+ tensor_constructors_to_patch = {}
43
+
44
+ try:
45
+ torch.nn.Module.register_parameter = register_empty_parameter
46
+ if include_buffers:
47
+ torch.nn.Module.register_buffer = register_empty_buffer
48
+ for torch_function_name in tensor_constructors_to_patch.keys():
49
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
50
+ yield
51
+ finally:
52
+ torch.nn.Module.register_parameter = old_register_parameter
53
+ if include_buffers:
54
+ torch.nn.Module.register_buffer = old_register_buffer
55
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
56
+ setattr(torch, torch_function_name, old_torch_function)
57
+
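
A short usage sketch for the `init_weights_on_device` context manager above (assuming `util.model_util` is importable): modules built inside the block get meta-device parameters, so no real memory is allocated until actual weights are loaded.

import torch
from util.model_util import init_weights_on_device  # assumed import path

with init_weights_on_device():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta
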
58
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
59
+ state_dict = {}
60
+ for file_name in os.listdir(file_path):
61
+ if "." in file_name and file_name.split(".")[-1] in [
62
+ "safetensors", "bin", "ckpt", "pth", "pt"
63
+ ]:
64
+ state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
65
+ return state_dict
66
+
67
+
68
+ def load_state_dict(file_path, torch_dtype=None):
69
+ if file_path.endswith(".safetensors"):
70
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
71
+ else:
72
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
73
+
74
+
75
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
76
+ state_dict = {}
77
+ with safe_open(file_path, framework="pt", device="cpu") as f:
78
+ for k in f.keys():
79
+ state_dict[k] = f.get_tensor(k)
80
+ if torch_dtype is not None:
81
+ state_dict[k] = state_dict[k].to(torch_dtype)
82
+ return state_dict
83
+
84
+
85
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
86
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
87
+ if torch_dtype is not None:
88
+ for i in state_dict:
89
+ if isinstance(state_dict[i], torch.Tensor):
90
+ state_dict[i] = state_dict[i].to(torch_dtype)
91
+ return state_dict
92
+
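
Likewise, a hedged usage sketch for the loaders above; the checkpoint path is a placeholder:

import torch
from util.model_util import load_state_dict  # assumed import path

state_dict = load_state_dict("models/example.safetensors", torch_dtype=torch.bfloat16)  # placeholder path
print(len(state_dict), next(iter(state_dict)))
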
93
+
94
+ def search_for_embeddings(state_dict):
95
+ embeddings = []
96
+ for k in state_dict:
97
+ if isinstance(state_dict[k], torch.Tensor):
98
+ embeddings.append(state_dict[k])
99
+ elif isinstance(state_dict[k], dict):
100
+ embeddings += search_for_embeddings(state_dict[k])
101
+ return embeddings
102
+
103
+
104
+ def search_parameter(param, state_dict):
105
+ for name, param_ in state_dict.items():
106
+ if param.numel() == param_.numel():
107
+ if param.shape == param_.shape:
108
+ if torch.dist(param, param_) < 1e-3:
109
+ return name
110
+ else:
111
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
112
+ return name
113
+ return None
114
+
115
+
116
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
117
+ matched_keys = set()
118
+ with torch.no_grad():
119
+ for name in source_state_dict:
120
+ rename = search_parameter(source_state_dict[name], target_state_dict)
121
+ if rename is not None:
122
+ print(f'"{name}": "{rename}",')
123
+ matched_keys.add(rename)
124
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
125
+ length = source_state_dict[name].shape[0] // 3
126
+ rename = []
127
+ for i in range(3):
128
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
129
+ if None not in rename:
130
+ print(f'"{name}": {rename},')
131
+ for rename_ in rename:
132
+ matched_keys.add(rename_)
133
+ for name in target_state_dict:
134
+ if name not in matched_keys:
135
+ print("Cannot find", name, target_state_dict[name].shape)
136
+
137
+
138
+ def search_for_files(folder, extensions):
139
+ files = []
140
+ if os.path.isdir(folder):
141
+ for file in sorted(os.listdir(folder)):
142
+ files += search_for_files(os.path.join(folder, file), extensions)
143
+ elif os.path.isfile(folder):
144
+ for extension in extensions:
145
+ if folder.endswith(extension):
146
+ files.append(folder)
147
+ break
148
+ return files
149
+
150
+
151
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
152
+ keys = []
153
+ for key, value in state_dict.items():
154
+ if isinstance(key, str):
155
+ if isinstance(value, torch.Tensor):
156
+ if with_shape:
157
+ shape = "_".join(map(str, list(value.shape)))
158
+ keys.append(key + ":" + shape)
159
+ keys.append(key)
160
+ elif isinstance(value, dict):
161
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
162
+ keys.sort()
163
+ keys_str = ",".join(keys)
164
+ return keys_str
165
+
166
+
167
+ def split_state_dict_with_prefix(state_dict):
168
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
169
+ prefix_dict = {}
170
+ for key in keys:
171
+ prefix = key if "." not in key else key.split(".")[0]
172
+ if prefix not in prefix_dict:
173
+ prefix_dict[prefix] = []
174
+ prefix_dict[prefix].append(key)
175
+ state_dicts = []
176
+ for prefix, keys in prefix_dict.items():
177
+ sub_state_dict = {key: state_dict[key] for key in keys}
178
+ state_dicts.append(sub_state_dict)
179
+ return state_dicts
180
+
181
+
182
+ def hash_state_dict_keys(state_dict, with_shape=True):
183
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
184
+ keys_str = keys_str.encode(encoding="UTF-8")
185
+ return hashlib.md5(keys_str).hexdigest()
186
+
187
+
188
+ def save_attention_maps(model, output_path, batch_idx, timestep, layer_indices=None):
189
+ """
190
+ Visualize and save the attention maps from selected layers of the model
191
+
192
+ Args:
193
+ model: The DiT model with attention maps stored
194
+ output_path: Directory to save visualizations
195
+ batch_idx: Current batch index for file naming
196
+ timestep: Denoising timestep, used in the output file names
+ layer_indices: List of layer indices to visualize (if None, visualize all)
197
+ """
198
+ timestep = int(float(str(timestep)))
199
+ os.makedirs(os.path.join(output_path, "attention_maps"), exist_ok=True)
200
+
201
+ # If layer indices not specified, visualize all layers
202
+ if layer_indices is None:
203
+ layer_indices = range(len(model.blocks))
204
+
205
+ # Create a custom colormap (similar to the ones used in attention visualization papers)
206
+ colors = [(0, 0, 0.5), (0, 0, 1), (0, 0.5, 1), (0, 1, 1),
207
+ (0.5, 1, 0.5), (1, 1, 0), (1, 0.5, 0), (1, 0, 0), (0.5, 0, 0)]
208
+ attention_cmap = LinearSegmentedColormap.from_list('attention_cmap', colors)
209
+
210
+ for i in layer_indices:
211
+ if not hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
212
+ continue
213
+
214
+ attn_map = model.blocks[i].self_attn._last_attn_maps
215
+ grid_size = model.blocks[i].self_attn._last_grid_sizes
216
+ seq_len = model.blocks[i].self_attn._last_seq_lens
217
+ # attn_maps.shape=[s, s]
218
+ np.savez_compressed(os.path.join(output_path,
219
+ "attention_maps",
220
+ f"attn_maps_layer{i}_batch{batch_idx}_t{timestep}.npz"),
221
+ attn_map=attn_map, grid_size=grid_size, seq_len=seq_len)
222
+
223
+ print(f"Saving Layer {i}, Batch {batch_idx} attention maps")
224
+ attn_map -= attn_map.min()
225
+ attn_map /= attn_map.max()
226
+ plt.figure(figsize=(10, 8))
227
+ plt.imshow(attn_map ** 0.25, cmap=attention_cmap)
228
+ plt.colorbar(label='Attention Weight')
229
+ plt.title(f'Layer {i}, Batch {batch_idx} (Average)')
230
+ save_path = os.path.join(
231
+ output_path,
232
+ "attention_maps",
233
+ f"attn_map_layer{i}_average_batch{batch_idx}_t{timestep}.png"
234
+ )
235
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
236
+ plt.close()
237
+
238
+ # Clean up the stored attention maps to free memory
239
+ for i in layer_indices:
240
+ if hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
241
+ del model.blocks[i].self_attn._last_attn_maps
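
The .npz files written by `save_attention_maps` can be read back with numpy. A minimal sketch (the file name below just follows the pattern used in that function and is otherwise hypothetical):

import numpy as np

data = np.load("attention_maps/attn_maps_layer0_batch0_t500.npz")  # hypothetical file
attn_map, grid_size, seq_len = data["attn_map"], data["grid_size"], data["seq_len"]
print(attn_map.shape, grid_size, seq_len)
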
util/optical_flow.py ADDED
@@ -0,0 +1,140 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
6
+ from typing import List, Tuple, Dict
7
+ import argparse
8
+ from pathlib import Path
9
+ from sklearn.cluster import KMeans
10
+ from tqdm import tqdm
11
+ import os
12
+
13
+ os.environ['OPENBLAS_NUM_THREADS'] = '64'
14
+
15
+ class OpticalFlowAnalyzer:
16
+ def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
17
+ self.device = device
18
+ self.model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
19
+ self.model.eval()
20
+
21
+ def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
22
+ """Preprocess a frame for RAFT model."""
23
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
24
+ frame = torch.from_numpy(frame).permute(2, 0, 1).float()
25
+ frame = frame.unsqueeze(0) / 255.0
26
+ return frame.to(self.device)
27
+
28
+ def compute_optical_flow(self, frame1: np.ndarray, frame2: np.ndarray) -> np.ndarray:
29
+ """Compute optical flow between two consecutive frames."""
30
+ with torch.no_grad():
31
+ frame1_tensor = self.preprocess_frame(frame1)
32
+ frame2_tensor = self.preprocess_frame(frame2)
33
+
34
+ flow = self.model(frame1_tensor, frame2_tensor)[-1]
35
+ flow = flow[0].permute(1, 2, 0).cpu().numpy()
36
+
37
+ return flow
38
+
39
+ def analyze_motion_regions(self, flow: np.ndarray, num_clusters: int = 3) -> Tuple[np.ndarray, Dict]:
40
+ """Cluster motion regions based on optical flow magnitude and direction."""
41
+ h, w = flow.shape[:2]
42
+ magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)
43
+ direction = np.arctan2(flow[..., 1], flow[..., 0])
44
+
45
+ # Create feature matrix for clustering
46
+ features = np.zeros((h * w, 3))
47
+ features[:, 0] = magnitude.ravel()
48
+ features[:, 1] = np.cos(direction).ravel()
49
+ features[:, 2] = np.sin(direction).ravel()
50
+
51
+ # Normalize features
52
+ features = (features - features.mean(axis=0)) / features.std(axis=0)
53
+
54
+ # Perform clustering
55
+ kmeans = KMeans(n_clusters=num_clusters, random_state=42)
56
+ labels = kmeans.fit_predict(features)
57
+ labels = labels.reshape(h, w)
58
+
59
+ # Analyze clusters
60
+ cluster_stats = {}
61
+ for i in range(num_clusters):
62
+ cluster_mask = (labels == i)
63
+ cluster_magnitude = magnitude[cluster_mask]
64
+ cluster_stats[i] = {
65
+ 'mean_magnitude': np.mean(cluster_magnitude),
66
+ 'std_magnitude': np.std(cluster_magnitude),
67
+ 'pixel_count': np.sum(cluster_mask),
68
+ 'is_static': np.mean(cluster_magnitude) < 0.1 # Threshold for static regions
69
+ }
70
+
71
+ return labels, cluster_stats
72
+
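
A standalone sketch of the feature construction used in `analyze_motion_regions`: each pixel of the flow field becomes a (magnitude, cos direction, sin direction) vector before k-means. The toy flow below is random:

import numpy as np

flow = np.random.randn(4, 4, 2).astype(np.float32)  # toy flow field
magnitude = np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2)
direction = np.arctan2(flow[..., 1], flow[..., 0])
features = np.stack(
    [magnitude.ravel(), np.cos(direction).ravel(), np.sin(direction).ravel()], axis=1
)
print(features.shape)  # (16, 3)
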
73
+ def process_video(self, video_path: str, output_path: str = None) -> List[Tuple[np.ndarray, Dict]]:
74
+ """Process a video and return motion analysis results for each frame pair."""
75
+ cap = cv2.VideoCapture(video_path)
76
+ if not cap.isOpened():
77
+ raise ValueError(f"Could not open video: {video_path}")
78
+
79
+ results = []
80
+ ret, prev_frame = cap.read()
81
+ if not ret:
82
+ raise ValueError("Could not read first frame")
83
+
84
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
85
+ pbar = tqdm(total=total_frames-1, desc="Processing video")
86
+
87
+ while True:
88
+ ret, curr_frame = cap.read()
89
+ if not ret:
90
+ break
91
+
92
+ flow = self.compute_optical_flow(prev_frame, curr_frame)
93
+ labels, stats = self.analyze_motion_regions(flow)
94
+
95
+ if output_path:
96
+ # Visualize results
97
+ vis_frame = curr_frame.copy()
98
+ for i, stat in stats.items():
99
+ if not stat['is_static']:
100
+ mask = (labels == i).astype(np.uint8) * 255
101
+ print("mask:",mask.shape)
102
+ print("vis_frame:",vis_frame.shape)
103
+ mask = np.expand_dims(mask, axis=-1).repeat(3, axis=-1)
104
+ print("mask:",mask.shape)
105
+
106
+ vis_frame[mask > 0] = cv2.addWeighted(vis_frame[mask > 0], 0.7, 255, 0.3, 0)
107
+
108
+ cv2.imwrite(f"{output_path}/frame_{len(results):04d}.jpg", vis_frame)
109
+
110
+ results.append((labels, stats))
111
+ prev_frame = curr_frame
112
+ pbar.update(1)
113
+
114
+ cap.release()
115
+ pbar.close()
116
+ return results
117
+
118
+ def main():
119
+ parser = argparse.ArgumentParser(description='Analyze motion regions in a video using RAFT optical flow')
120
+ parser.add_argument('--video', type=str, required=True, help='Path to input video')
121
+ parser.add_argument('--output', type=str, help='Path to output directory for visualization')
122
+ parser.add_argument('--clusters', type=int, default=3, help='Number of motion clusters')
123
+ args = parser.parse_args()
124
+
125
+ analyzer = OpticalFlowAnalyzer()
126
+ results = analyzer.process_video(args.video, args.output)
127
+
128
+ # Print summary statistics
129
+ print("\nMotion Analysis Summary:")
130
+ for i, (_, stats) in enumerate(results):
131
+ print(f"\nFrame {i+1}:")
132
+ for cluster_id, stat in stats.items():
133
+ motion_type = "Static" if stat['is_static'] else "Moving"
134
+ print(f" Cluster {cluster_id} ({motion_type}):")
135
+ print(f" Mean magnitude: {stat['mean_magnitude']:.4f}")
136
+ print(f" Pixel count: {stat['pixel_count']}")
137
+
138
+ if __name__ == "__main__":
139
+ main()
140
+
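
Besides the CLI entry point, the analyzer can be used programmatically. A hedged sketch (the video path is a placeholder, and RAFT weights are downloaded on first use):

from util.optical_flow import OpticalFlowAnalyzer  # assumed import path

analyzer = OpticalFlowAnalyzer()
results = analyzer.process_video("path/to/clip.mp4")  # placeholder path
labels, stats = results[0]
moving = [cid for cid, s in stats.items() if not s["is_static"]]
print(f"{len(results)} frame pairs analysed; moving clusters in first pair: {moving}")
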
util/stylesheets.py ADDED
The diff for this file is too large to render. See raw diff
 
util/training_util.py ADDED
@@ -0,0 +1,317 @@
1
+ from typing import Union
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import cv2
6
+ import os
7
+
8
+
9
+ def create_random_mask(batch_size, num_frames, height, width, device, dtype, shape_type=None):
10
+ """
11
+ Create random masks for sketch frames.
12
+
13
+ Args:
14
+ batch_size: Batch size
15
+ num_frames: Number of frames to mask
16
+ height, width: Image dimensions
17
+ device: Device for tensor
18
+ dtype: Data type for tensor
20
+ shape_type: Type of shape for masking ('square', 'circle', 'random'). If None, one is randomly selected.
21
+
22
+ Returns:
23
+ Mask tensor in [b, 1, num_frames, height, width] where 0 indicates areas to mask (inverse of previous implementation)
24
+ """
25
+ # Initialize with ones (unmasked)
26
+ masks = torch.ones(batch_size, 1, num_frames, height, width, device=device, dtype=dtype)
27
+
28
+ for b in range(batch_size):
29
+ for f in range(num_frames):
30
+ # Randomly select shape type if not specified
31
+ if shape_type is None:
32
+ shape_type = random.choice(['square', 'circle', 'random'])
33
+
34
+ # Create numpy mask for easier shape drawing
35
+ mask = np.zeros((height, width), dtype=np.float32)
36
+
37
+ if shape_type == 'square':
38
+ # Random squares
39
+ num_squares = random.randint(1, 5)
40
+ for _ in range(num_squares):
41
+ # Random square size (proportional to image dimensions)
42
+ max_size = min(height, width)
43
+ size = random.randint(max_size // 4, max_size)
44
+
45
+ # Random position
46
+ x = random.randint(0, width - size)
47
+ y = random.randint(0, height - size)
48
+
49
+ # Draw square
50
+ mask[y:y+size, x:x+size] = 1.0
51
+
52
+ elif shape_type == 'circle':
53
+ # Random circles
54
+ num_circles = random.randint(1, 5)
55
+ for _ in range(num_circles):
56
+ # Random radius (proportional to image dimensions)
57
+ max_radius = min(height, width) // 2
58
+ radius = random.randint(max_radius // 4, max_radius)
59
+
60
+ # Random center
61
+ center_x = random.randint(radius, width - radius)
62
+ center_y = random.randint(radius, height - radius)
63
+
64
+ # Draw circle
65
+ cv2.circle(mask, (center_x, center_y), radius, 1.0, -1)
66
+
67
+ elif shape_type == 'random':
68
+ # Create connected random shape with cv2
69
+ num_points = random.randint(5, 16)
70
+ points = []
71
+
72
+ # Generate random points
73
+ for _ in range(num_points):
74
+ x = random.randint(0, width - 1)
75
+ y = random.randint(0, height - 1)
76
+ points.append([x, y])
77
+
78
+ # Convert to numpy array for cv2
79
+ points = np.array(points, dtype=np.int32)
80
+
81
+ # Draw filled polygon
82
+ cv2.fillPoly(mask, [points], 1.0)
83
+
84
+ # Convert numpy mask to tensor and subtract from ones (inverse the mask)
85
+ masks[b, 0, f] = 1.0 - torch.from_numpy(mask).to(device=device, dtype=dtype)
86
+
87
+ return masks
88
+
89
+
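
A quick usage sketch for `create_random_mask` above: the returned tensor keeps 1 for visible pixels and 0 for masked pixels.

import torch
from util.training_util import create_random_mask  # assumed import path

masks = create_random_mask(
    batch_size=1, num_frames=3, height=64, width=64,
    device=torch.device("cpu"), dtype=torch.float32, shape_type="circle",
)
print(masks.shape, masks.min().item(), masks.max().item())  # torch.Size([1, 1, 3, 64, 64]) 0.0 1.0
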
90
+ @torch.no_grad()
91
+ def extract_img_to_sketch(_sketch_model, _img, model_name="random"):
92
+ """
93
+ Return sketch: [-1, 1]
94
+ """
95
+ orig_shape = (_img.shape[-2], _img.shape[-1])
96
+ with torch.amp.autocast(dtype=torch.float32, device_type="cuda"):
97
+ reshaped_img = torch.nn.functional.interpolate(_img, (2048, 2048))
98
+ sketch = _sketch_model(reshaped_img, model_name=model_name)
99
+ sketch = torch.nn.functional.interpolate(sketch, orig_shape)
100
+ if sketch.shape[1] == 1:
101
+ sketch = sketch.repeat(1, 3, 1, 1)
102
+ return sketch
103
+
104
+
105
+ def video_to_frame_and_sketch(
106
+ sketch_model,
107
+ original_video,
108
+ max_num_preserved_sketch_frames=2,
109
+ max_num_preserved_image_frames=1,
110
+ min_num_preserved_sketch_frames=2,
111
+ min_num_preserved_image_frames=1,
112
+ model_name=None,
113
+ detach_image_and_sketch=False,
114
+ equally_spaced_preserve_sketch=False,
115
+ apply_sketch_mask=False,
116
+ sketch_mask_ratio=0.2,
117
+ sketch_mask_shape=None,
118
+ no_first_sketch: Union[bool, float] = False,
119
+ video_clip_names=None,
120
+ is_flux_sketch_available=None,
121
+ is_evaluation=False,
122
+ ):
123
+ """
124
+ Args:
125
+ sketch_model: torch.nn.Module, a sketch pool for extracting sketches from images
126
+ original_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
127
+ max_num_preserved_sketch_frames: int, maximum number of preserved sketch frames
128
+ max_num_preserved_image_frames: int, maximum number of preserved image frames
129
+ min_num_preserved_sketch_frames: int, minimum number of preserved sketch frames
130
+ min_num_preserved_image_frames: int, minimum number of preserved image frames
131
+ model_name: str, name of the sketch model. If None, randomly select from ["lineart", "lineart_anime", "anime2sketch"]. Default: None.
132
+ equally_spaced_preserve_sketch: bool, whether to preserve sketches at equally spaced intervals. Default: False.
133
+ apply_sketch_mask: bool, whether to apply random masking to sketch frames. Default: False.
134
+ sketch_mask_ratio: float, ratio of frames to mask (0-1). Default: 0.2.
135
+ sketch_mask_shape: str, shape type for masking ('square', 'circle', 'random'). If None, randomly selected. Default: None.
136
+ Returns:
137
+ masked_condition_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width), image conditions placed at preserved frames
+ preserved_condition_mask: torch.Tensor, shape=(batch_size, num_frames), marks preserved condition frames
+ masked_condition_video_sketch: torch.Tensor or None, sketch conditions (only when detach_image_and_sketch=True)
+ preserved_condition_mask_sketch: torch.Tensor or None, sketch frame mask (only when detach_image_and_sketch=True)
+ full_sketch_frames: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
+ sketch_local_mask: torch.Tensor, shape=(batch_size, 1, num_frames, height, width) or None if apply_sketch_mask=False
+ cur_model_name: str, name of the sketch extractor actually used
141
+ """
142
+ video_shape = original_video.shape
143
+ video_dtype = original_video.dtype
144
+ video_device = original_video.device
145
+
146
+ if min_num_preserved_sketch_frames is None or min_num_preserved_sketch_frames < 2:
147
+ min_num_preserved_sketch_frames = 2 # Minimum num: 2 (the first and the last)
148
+ num_preserved_sketch_frames = random.randint(min_num_preserved_sketch_frames, max_num_preserved_sketch_frames)
149
+ num_preserved_sketch_frames = min(num_preserved_sketch_frames, video_shape[2])
150
+
151
+ # Always include first and last frames
152
+ if video_clip_names is not None and is_flux_sketch_available is not None:
153
+ if is_flux_sketch_available[0]:
154
+ num_preserved_sketch_frames = 2
155
+
156
+ if isinstance(no_first_sketch, float):
157
+ no_first_sketch = random.random() < no_first_sketch
158
+
159
+ if equally_spaced_preserve_sketch:
160
+ preserved_sketch_indices = torch.linspace(0, video_shape[2] - 1, num_preserved_sketch_frames).long().tolist()
161
+ if no_first_sketch:
162
+ preserved_sketch_indices = preserved_sketch_indices[1:]
163
+ else:
164
+ if no_first_sketch:
165
+ preserved_sketch_indices = [video_shape[2] - 1]
166
+ else:
167
+ preserved_sketch_indices = [0, video_shape[2] - 1]
168
+ # If we need more frames than just first and last
169
+ if num_preserved_sketch_frames > 2 and video_shape[2] > 4:
170
+ # Create set of all valid candidates (excluding first, last and their adjacent frames)
171
+ # Exclude indices adjacent to first and last
172
+ candidates = set(range(2, video_shape[2] - 2))
173
+
174
+ # Determine how many additional frames to select
175
+ additional_frames_needed = min(num_preserved_sketch_frames - 2, len(candidates))
176
+
177
+ # Keep selecting frames until we have enough or run out of candidates
178
+ additional_indices = []
179
+ while len(additional_indices) < additional_frames_needed and candidates:
180
+ # Convert set to list for random selection
181
+ candidate_list = list(candidates)
182
+ # Select a random candidate
183
+ idx = random.choice(candidate_list)
184
+ additional_indices.append(idx)
185
+
186
+ # Remove selected index and adjacent indices from candidates
187
+ candidates.remove(idx)
188
+ if idx - 1 in candidates:
189
+ candidates.remove(idx - 1)
190
+ if idx + 1 in candidates:
191
+ candidates.remove(idx + 1)
192
+
193
+ preserved_sketch_indices.extend(additional_indices)
194
+ preserved_sketch_indices.sort()
195
+
196
+ # Indices to preserve has been determined.
197
+ # Later code will not care the number of preserved frames but rely on the indices only.
198
+ preserved_image_indices = [0]
199
+ if max_num_preserved_image_frames is not None and max_num_preserved_image_frames > 1:
200
+ max_num_preserved_image_frames -= 1
201
+ if min_num_preserved_image_frames is None or min_num_preserved_image_frames < 1:
202
+ min_num_preserved_image_frames = 1
203
+ min_num_preserved_image_frames -= 1
204
+ other_indices = torch.tensor([i for i in range(video_shape[2]) if i not in preserved_sketch_indices])
205
+ max_num_preserved_image_frames = min(max_num_preserved_image_frames, len(other_indices))
206
+ min_num_preserved_image_frames = min(min_num_preserved_image_frames, max_num_preserved_image_frames)
207
+ num_preserved_image_frames = random.randint(min_num_preserved_image_frames, max_num_preserved_image_frames)
208
+ other_indices = other_indices[torch.randperm(len(other_indices))]
209
+ if num_preserved_image_frames > 0:
210
+ preserved_image_indices.extend(other_indices[:num_preserved_image_frames])
211
+
212
+ preserved_condition_mask = torch.zeros(size=(video_shape[0], video_shape[2]), dtype=video_dtype, device=video_device) # [b, t]
213
+ masked_condition_video = torch.zeros_like(original_video) # [b, c, t, h, w]
214
+ full_sketch_frames = torch.zeros_like(original_video) # [b, c, t, h, w]
215
+
216
+ if detach_image_and_sketch:
217
+ preserved_condition_mask_sketch = torch.zeros_like(preserved_condition_mask)
218
+ masked_condition_video_sketch = torch.zeros_like(masked_condition_video)
219
+ if 0 not in preserved_sketch_indices and not no_first_sketch:
220
+ preserved_sketch_indices.append(0)
221
+ else:
222
+ preserved_condition_mask_sketch = None
223
+ masked_condition_video_sketch = None
224
+
225
+ for _idx in preserved_image_indices:
226
+ preserved_condition_mask[:, _idx] = 1.0
227
+ masked_condition_video[:, :, _idx, :, :] = original_video[:, :, _idx, :, :]
228
+
229
+ # Set up sketch_local_mask if masking is applied
230
+ sketch_local_mask = None
231
+
232
+ if apply_sketch_mask:
233
+ # Create a full-sized mask initialized to all ones (unmasked)
234
+ sketch_local_mask = torch.ones(
235
+ video_shape[0], video_shape[2], video_shape[3], video_shape[4],
236
+ device=video_device,
237
+ dtype=video_dtype
238
+ ).unsqueeze(1) # Add channel dimension to get [b, 1, t, h, w]
239
+
240
+ if not is_evaluation and random.random() < sketch_mask_ratio:
241
+ # For preserved frames, apply random masking
242
+ for i, frame_idx in enumerate(preserved_sketch_indices):
243
+ if i == 0:
244
+ # First frame is not masked
245
+ continue
246
+ # Create masks only for preserved frames
247
+ frame_masks = create_random_mask(
248
+ batch_size=video_shape[0],
249
+ num_frames=1, # Just one frame at a time
250
+ height=video_shape[3],
251
+ width=video_shape[4],
252
+ device=video_device,
253
+ dtype=video_dtype,
254
+ # mask_area_ratio=0.4 * random.random() + 0.1,
255
+ shape_type=sketch_mask_shape
256
+ )
257
+
258
+ # Set the mask for this preserved frame
259
+ sketch_local_mask[:, :, frame_idx:frame_idx+1, :, :] = frame_masks
260
+
261
+ # Produce sketches for preserved frames
262
+ # Sketches can either be 1) calculated from sketch pool or 2) loaded from the flux sketch directory
263
+ if is_flux_sketch_available is not None and is_flux_sketch_available[0]:
264
+ should_use_flux_sketch = random.random() < 0.75 if not is_evaluation else True
265
+ else:
266
+ should_use_flux_sketch = False
267
+
268
+ cur_model_name = "flux" if should_use_flux_sketch else random.choice(["lineart", "lineart_anime", "anime2sketch"]) if model_name is None else model_name # "anime2sketch"
269
+ # cur_model_name = "anyline"
270
+ for _idx in preserved_sketch_indices:
271
+ sketch_frame = None
272
+ if should_use_flux_sketch:
273
+ # Load flux sketch
274
+ sketech_path = f"/group/40005/gzhiwang/iclora/linearts/{video_clip_names[0]}/{_idx}.lineart.png"
275
+ print(f"Loading flux sketch from {sketech_path}...")
276
+ if os.path.exists(sketch_path):
277
+ sketch_frame = cv2.imread(sketch_path)
278
+ sketch_frame = cv2.cvtColor(sketch_frame, cv2.COLOR_BGR2RGB)
279
+ # resize to 480p
280
+ sketch_frame = cv2.resize(sketch_frame, (video_shape[4], video_shape[3]))
281
+ sketch_frame = torch.from_numpy(sketch_frame).to(video_device, dtype=video_dtype)
282
+ # Normalize to [-1, 1]
283
+ sketch_frame = sketch_frame / 255.0 * 2.0 - 1.0
284
+ sketch_frame = sketch_frame.permute(2, 0, 1)
285
+ sketch_frame = sketch_frame.unsqueeze(0)
286
+ else:
287
+ print(f"FLUX Sketch path {sketech_path} does not exist. Falling back to sketch pool.")
288
+ # raise ValueError(f"FLUX Sketch path {sketech_path} does not exist.")
289
+ if sketch_frame is None:
290
+ # Calculate sketch from sketch pool
291
+ sketch_frame = extract_img_to_sketch(
292
+ sketch_model, original_video[:, :, _idx, :, :].float(),
293
+ model_name=cur_model_name).to(video_device, dtype=video_dtype)
294
+ # Convert white BG (from sketch pool or loaded from flux sketch files) to black BG (for training)
295
+ sketch_frame = -torch.clip(sketch_frame, -1, 1)
296
+ full_sketch_frames[:, :, _idx, :, :] = sketch_frame
297
+
298
+ if len(preserved_sketch_indices) > 0:
299
+ _mask_to_add = preserved_condition_mask_sketch if detach_image_and_sketch else preserved_condition_mask
300
+ _video_to_add = masked_condition_video_sketch if detach_image_and_sketch else masked_condition_video
301
+ if not detach_image_and_sketch:
302
+ preserved_sketch_indices = preserved_sketch_indices[1:]
303
+
304
+ # Apply masking to sketch frames if required
305
+ if apply_sketch_mask and sketch_local_mask is not None:
306
+ # sketch_local_mask: [b, 1, t, h, w]
307
+ for _idx in preserved_sketch_indices:
308
+ _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
309
+ _video_to_add[:, :, _idx, :, :] = torch.where(sketch_local_mask[:, 0:1, _idx, :, :] == 0, -1.0, full_sketch_frames[:, :, _idx, :, :])
310
+ else:
311
+ for _idx in preserved_sketch_indices:
312
+ _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
313
+ _video_to_add[:, :, _idx, :, :] = full_sketch_frames[:, :, _idx, :, :]
314
+
315
+ return masked_condition_video, preserved_condition_mask, masked_condition_video_sketch, preserved_condition_mask_sketch, full_sketch_frames, sketch_local_mask, cur_model_name
316
+
317
+
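
To summarise the keyframe selection in `video_to_frame_and_sketch`: the first and last frames are always kept as sketch conditions, and any additional sketch frames are drawn at random from the middle while avoiding frames adjacent to ones already chosen. A standalone sketch of that scheme (not the function's exact code):

import random

def pick_sketch_indices(num_frames: int, num_keep: int) -> list:
    """Always keep first and last; add random, non-adjacent middle frames."""
    indices = [0, num_frames - 1]
    candidates = set(range(2, num_frames - 2))
    while len(indices) < num_keep and candidates:
        idx = random.choice(sorted(candidates))
        indices.append(idx)
        candidates -= {idx - 1, idx, idx + 1}
    return sorted(indices)

print(pick_sketch_indices(num_frames=16, num_keep=4))  # e.g. [0, 5, 9, 15]
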