Migrated from GitHub
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .dockerignore +6 -0
- Dockerfile +28 -0
- LICENSE +201 -0
- ORIGINAL_README.md +83 -0
- diffusers_helper/bucket_tools.py +97 -0
- diffusers_helper/clip_vision.py +12 -0
- diffusers_helper/dit_common.py +53 -0
- diffusers_helper/gradio/progress_bar.py +86 -0
- diffusers_helper/hf_login.py +21 -0
- diffusers_helper/hunyuan.py +163 -0
- diffusers_helper/k_diffusion/uni_pc_fm.py +144 -0
- diffusers_helper/k_diffusion/wrapper.py +51 -0
- diffusers_helper/lora_utils.py +194 -0
- diffusers_helper/memory.py +134 -0
- diffusers_helper/models/hunyuan_video_packed.py +1062 -0
- diffusers_helper/models/mag_cache.py +219 -0
- diffusers_helper/models/mag_cache_ratios.py +71 -0
- diffusers_helper/pipelines/k_diffusion_hunyuan.py +120 -0
- diffusers_helper/thread_utils.py +76 -0
- diffusers_helper/utils.py +613 -0
- docker-compose.yml +25 -0
- install.bat +208 -0
- modules/__init__.py +4 -0
- modules/generators/__init__.py +32 -0
- modules/generators/base_generator.py +281 -0
- modules/generators/f1_generator.py +235 -0
- modules/generators/original_generator.py +213 -0
- modules/generators/original_with_endframe_generator.py +15 -0
- modules/generators/video_base_generator.py +613 -0
- modules/generators/video_f1_generator.py +189 -0
- modules/generators/video_generator.py +239 -0
- modules/grid_builder.py +78 -0
- modules/interface.py +0 -0
- modules/llm_captioner.py +66 -0
- modules/llm_enhancer.py +191 -0
- modules/pipelines/__init__.py +45 -0
- modules/pipelines/base_pipeline.py +85 -0
- modules/pipelines/f1_pipeline.py +140 -0
- modules/pipelines/metadata_utils.py +329 -0
- modules/pipelines/original_pipeline.py +138 -0
- modules/pipelines/original_with_endframe_pipeline.py +157 -0
- modules/pipelines/video_f1_pipeline.py +143 -0
- modules/pipelines/video_pipeline.py +143 -0
- modules/pipelines/video_tools.py +57 -0
- modules/pipelines/worker.py +1150 -0
- modules/prompt_handler.py +164 -0
- modules/settings.py +88 -0
- modules/toolbox/RIFE/IFNet_HDv3.py +136 -0
- modules/toolbox/RIFE/RIFE_HDv3.py +98 -0
- modules/toolbox/RIFE/__int__.py +0 -0
.dockerignore
ADDED
@@ -0,0 +1,6 @@
+hf_download/
+.framepack/
+loras/
+outputs/
+modules/toolbox/model_esrgan/
+modules/toolbox/model_rife/
Dockerfile
ADDED
@@ -0,0 +1,28 @@
+ARG CUDA_VERSION=12.4.1
+
+FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu22.04
+
+ARG CUDA_VERSION
+
+RUN apt-get update && apt-get install -y \
+    python3 python3-pip git ffmpeg wget curl && \
+    pip3 install --upgrade pip
+
+WORKDIR /app
+
+# This allows caching pip install if only code has changed
+COPY requirements.txt .
+
+# Install dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
+RUN export CUDA_SHORT_VERSION=$(echo "${CUDA_VERSION}" | sed 's/\.//g' | cut -c 1-3) && \
+    pip3 install --force-reinstall --no-cache-dir torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cu${CUDA_SHORT_VERSION}"
+
+# Copy the source code to /app
+COPY . .
+
+VOLUME [ "/app/.framepack", "/app/outputs", "/app/loras", "/app/hf_download", "/app/modules/toolbox/model_esrgan", "/app/modules/toolbox/model_rife" ]
+
+EXPOSE 7860
+
+CMD ["python3", "studio.py"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
ORIGINAL_README.md
ADDED
@@ -0,0 +1,83 @@
+<h1 align="center">FramePack Studio</h1>
+
+[Discord](https://discord.gg/MtuM7gFJ3V) [Patreon](https://www.patreon.com/ColinU)
+
+[DeepWiki](https://deepwiki.com/colinurbs/FramePack-Studio)
+
+FramePack Studio is an AI video generation application based on FramePack that strives to provide everything you need to create high-quality video projects.
+
+## Current Features
+
+- **F1, Original and Video Extension Generations**: Run all in a single queue
+- **End Frame Control for 'Original' Model**: Provides greater control over generations
+- **Upscaling and Post-processing**
+- **Timestamped Prompts**: Define different prompts for specific time segments in your video
+- **Prompt Blending**: Define the blending time between timestamped prompts
+- **LoRA Support**: Works with most (all?) Hunyuan Video LoRAs
+- **Queue System**: Process multiple generation jobs without blocking the interface. Import and export queues.
+- **Metadata Saving/Import**: Prompt and seed are encoded into the output PNG; all other generation metadata is saved in a JSON file that can be imported later for similar generations.
+- **Custom Presets**: Allow quick switching between named groups of parameters. A custom Startup Preset can also be set.
+- **I2V and T2V**: Works with or without an input image to allow for more flexibility when working with standard Hunyuan Video LoRAs
+- **Latent Image Options**: When using T2V you can generate based on a black, white, green screen, or pure noise image
+
+## Prerequisites
+
+- CUDA-compatible GPU with at least 8GB VRAM (16GB+ recommended)
+- 16GB System Memory (32GB+ strongly recommended)
+- 80GB+ of storage (including ~25GB for each model family: Original and F1)
+
+## Documentation
+
+For information on installation, configuration, and usage, please visit our [documentation site](https://docs.framepackstudio.com/).
+
+## Installation
+
+Please see [this guide](https://docs.framepackstudio.com/docs/get_started/) on our documentation site to get FP-Studio installed.
+
+## LoRAs
+
+Add LoRAs to the /loras/ folder at the root of the installation. Select the LoRAs you wish to load and set the weights for each generation. Most Hunyuan LoRAs were originally trained for T2V, so it's often helpful to run a T2V generation to ensure they're working before using input images.
+
+NOTE: Slow LoRA loading is a known issue.
+
+## Working with Timestamped Prompts
+
+You can create videos with changing prompts over time using the following syntax:
+
+```
+[0s: A serene forest with sunlight filtering through the trees ]
+[5s: A deer appears in the clearing ]
+[10s: The deer drinks from a small stream ]
+```
+
+Each timestamp defines when that prompt should start influencing the generation. The system will (hopefully) smoothly transition between prompts for a cohesive video.
+
+## Credits
+
+Many thanks to [Lvmin Zhang](https://github.com/lllyasviel) for the absolutely amazing work on the original [FramePack](https://github.com/lllyasviel/FramePack) code!
+
+Thanks to [Rickard Edén](https://github.com/neph1) for the LoRA code and their general contributions to this growing FramePack scene!
+
+Thanks to [Zehong Ma](https://github.com/Zehong-Ma) for [MagCache](https://github.com/Zehong-Ma/MagCache): Fast Video Generation with Magnitude-Aware Cache!
+
+Thanks to everyone who has joined the Discord, reported a bug, submitted a PR, or helped with testing!
+
+@article{zhang2025framepack,
+    title={Packing Input Frame Contexts in Next-Frame Prediction Models for Video Generation},
+    author={Lvmin Zhang and Maneesh Agrawala},
+    journal={Arxiv},
+    year={2025}
+}
+
+@misc{zhang2025packinginputframecontext,
+    title={Packing Input Frame Context in Next-Frame Prediction Models for Video Generation},
+    author={Lvmin Zhang and Maneesh Agrawala},
+    year={2025},
+    eprint={2504.12626},
+    archivePrefix={arXiv},
+    primaryClass={cs.CV},
+    url={https://arxiv.org/abs/2504.12626}
+}
diffusers_helper/bucket_tools.py
ADDED
@@ -0,0 +1,97 @@
+bucket_options = {
+    128: [
+        (96, 160),
+        (112, 144),
+        (128, 128),
+        (144, 112),
+        (160, 96),
+    ],
+    256: [
+        (192, 320),
+        (224, 288),
+        (256, 256),
+        (288, 224),
+        (320, 192),
+    ],
+    384: [
+        (256, 512),
+        (320, 448),
+        (384, 384),
+        (448, 320),
+        (512, 256),
+    ],
+    512: [
+        (352, 704),
+        (384, 640),
+        (448, 576),
+        (512, 512),
+        (576, 448),
+        (640, 384),
+        (704, 352),
+    ],
+    640: [
+        (416, 960),
+        (448, 864),
+        (480, 832),
+        (512, 768),
+        (544, 704),
+        (576, 672),
+        (608, 640),
+        (640, 640),
+        (640, 608),
+        (672, 576),
+        (704, 544),
+        (768, 512),
+        (832, 480),
+        (864, 448),
+        (960, 416),
+    ],
+    768: [
+        (512, 1024),
+        (576, 896),
+        (640, 832),
+        (704, 768),
+        (768, 768),
+        (768, 704),
+        (832, 640),
+        (896, 576),
+        (1024, 512),
+    ],
+}
+
+
+def find_nearest_bucket(h, w, resolution=640):
+    # Use the provided resolution or find the closest available bucket size
+    # print(f"find_nearest_bucket called with h={h}, w={w}, resolution={resolution}")
+
+    # Convert resolution to int if it's not already
+    resolution = int(resolution) if not isinstance(resolution, int) else resolution
+
+    if resolution not in bucket_options:
+        # Find the closest available resolution
+        available_resolutions = list(bucket_options.keys())
+        closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
+        # print(f"Resolution {resolution} not found in bucket options, using closest available: {closest_resolution}")
+        resolution = closest_resolution
+    # else:
+    #     print(f"Resolution {resolution} found in bucket options")
+
+    # Calculate the aspect ratio of the input image
+    input_aspect_ratio = w / h if h > 0 else 1.0
+    # print(f"Input aspect ratio: {input_aspect_ratio:.4f}")
+
+    min_diff = float('inf')
+    best_bucket = None
+
+    # Find the bucket size with the closest aspect ratio to the input image
+    for (bucket_h, bucket_w) in bucket_options[resolution]:
+        bucket_aspect_ratio = bucket_w / bucket_h if bucket_h > 0 else 1.0
+        # Calculate the difference in aspect ratios
+        diff = abs(bucket_aspect_ratio - input_aspect_ratio)
+        if diff < min_diff:
+            min_diff = diff
+            best_bucket = (bucket_h, bucket_w)
+        # print(f"  Checking bucket ({bucket_h}, {bucket_w}), aspect ratio={bucket_aspect_ratio:.4f}, diff={diff:.4f}, current best={best_bucket}")
+
+    # print(f"Using resolution {resolution}, selected bucket: {best_bucket}")
+    return best_bucket
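For orientation, `find_nearest_bucket` is self-contained and can be exercised directly. A minimal sketch, assuming only that `diffusers_helper/bucket_tools.py` above is importable:

```python
# Minimal sketch: map a 720x1280 input to the closest-aspect-ratio bucket in
# the default 640 family. No dependencies beyond the module above.
from diffusers_helper.bucket_tools import find_nearest_bucket

h, w = 720, 1280                                   # input height, width
bucket_h, bucket_w = find_nearest_bucket(h, w, resolution=640)
print(bucket_h, bucket_w)                          # -> 480 832 (aspect ratio 1.73 vs. input 1.78)
```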
diffusers_helper/clip_vision.py
ADDED
@@ -0,0 +1,12 @@
+import numpy as np
+
+
+def hf_clip_vision_encode(image, feature_extractor, image_encoder):
+    assert isinstance(image, np.ndarray)
+    assert image.ndim == 3 and image.shape[2] == 3
+    assert image.dtype == np.uint8
+
+    preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
+    image_encoder_output = image_encoder(**preprocessed)
+
+    return image_encoder_output
diffusers_helper/dit_common.py
ADDED
@@ -0,0 +1,53 @@
+import torch
+import accelerate.accelerator
+
+from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
+
+
+accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
+
+
+def LayerNorm_forward(self, x):
+    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
+
+
+LayerNorm.forward = LayerNorm_forward
+torch.nn.LayerNorm.forward = LayerNorm_forward
+
+
+def FP32LayerNorm_forward(self, x):
+    origin_dtype = x.dtype
+    return torch.nn.functional.layer_norm(
+        x.float(),
+        self.normalized_shape,
+        self.weight.float() if self.weight is not None else None,
+        self.bias.float() if self.bias is not None else None,
+        self.eps,
+    ).to(origin_dtype)
+
+
+FP32LayerNorm.forward = FP32LayerNorm_forward
+
+
+def RMSNorm_forward(self, hidden_states):
+    input_dtype = hidden_states.dtype
+    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+    if self.weight is None:
+        return hidden_states.to(input_dtype)
+
+    return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
+
+
+RMSNorm.forward = RMSNorm_forward
+
+
+def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
+    emb = self.linear(self.silu(conditioning_embedding))
+    scale, shift = emb.chunk(2, dim=1)
+    x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+    return x
+
+
+AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
diffusers_helper/gradio/progress_bar.py
ADDED
@@ -0,0 +1,86 @@
+progress_html = '''
+<div class="loader-container">
+    <div class="loader"></div>
+    <div class="progress-container">
+        <progress value="*number*" max="100"></progress>
+    </div>
+    <span>*text*</span>
+</div>
+'''
+
+css = '''
+.loader-container {
+    display: flex; /* Use flex to align items horizontally */
+    align-items: center; /* Center items vertically within the container */
+    white-space: nowrap; /* Prevent line breaks within the container */
+}
+
+.loader {
+    border: 8px solid #f3f3f3; /* Light grey */
+    border-top: 8px solid #3498db; /* Blue */
+    border-radius: 50%;
+    width: 30px;
+    height: 30px;
+    animation: spin 2s linear infinite;
+}
+
+@keyframes spin {
+    0% { transform: rotate(0deg); }
+    100% { transform: rotate(360deg); }
+}
+
+/* Style the progress bar */
+progress {
+    appearance: none; /* Remove default styling */
+    height: 20px; /* Set the height of the progress bar */
+    border-radius: 5px; /* Round the corners of the progress bar */
+    background-color: #f3f3f3; /* Light grey background */
+    width: 100%;
+    vertical-align: middle !important;
+}
+
+/* Style the progress bar container */
+.progress-container {
+    margin-left: 20px;
+    margin-right: 20px;
+    flex-grow: 1; /* Allow the progress container to take up remaining space */
+}
+
+/* Set the color of the progress bar fill */
+progress::-webkit-progress-value {
+    background-color: #3498db; /* Blue color for the fill */
+}
+
+progress::-moz-progress-bar {
+    background-color: #3498db; /* Blue color for the fill in Firefox */
+}
+
+/* Style the text on the progress bar */
+progress::after {
+    content: attr(value '%'); /* Display the progress value followed by '%' */
+    position: absolute;
+    top: 50%;
+    left: 50%;
+    transform: translate(-50%, -50%);
+    color: white; /* Set text color */
+    font-size: 14px; /* Set font size */
+}
+
+/* Style other texts */
+.loader-container > span {
+    margin-left: 5px; /* Add spacing between the progress bar and the text */
+}
+
+.no-generating-animation > .generating {
+    display: none !important;
+}
+
+'''
+
+
+def make_progress_bar_html(number, text):
+    return progress_html.replace('*number*', str(number)).replace('*text*', text)
+
+
+def make_progress_bar_css():
+    return css
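A hedged sketch of how these two helpers are typically wired into a Gradio UI (the layout below is illustrative, not the studio's actual interface in `modules/interface.py`): the CSS is passed once via `gr.Blocks(css=...)`, and the HTML snippet is re-rendered into a `gr.HTML` component as progress changes.

```python
# Illustrative only: a bare-bones Gradio app that re-renders the progress bar.
import gradio as gr

from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

with gr.Blocks(css=make_progress_bar_css()) as demo:
    bar = gr.HTML(make_progress_bar_html(0, 'Idle'))
    step = gr.Slider(0, 100, value=0, step=1, label='Progress')
    # Whenever the slider moves, rebuild the HTML snippet and push it to the bar.
    step.change(lambda v: make_progress_bar_html(int(v), f'{int(v)}%'), inputs=step, outputs=bar)

if __name__ == '__main__':
    demo.launch()
```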
diffusers_helper/hf_login.py
ADDED
@@ -0,0 +1,21 @@
+import os
+
+
+def login(token):
+    from huggingface_hub import login
+    import time
+
+    while True:
+        try:
+            login(token)
+            print('HF login ok.')
+            break
+        except Exception as e:
+            print(f'HF login failed: {e}. Retrying')
+            time.sleep(0.5)
+
+
+hf_token = os.environ.get('HF_TOKEN', None)
+
+if hf_token is not None:
+    login(hf_token)
diffusers_helper/hunyuan.py
ADDED
@@ -0,0 +1,163 @@
+import torch
+
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+from diffusers_helper.utils import crop_or_pad_yield_mask
+
+
+@torch.no_grad()
+def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
+    assert isinstance(prompt, str)
+
+    prompt = [prompt]
+
+    # LLAMA
+
+    # Check if there's a custom system prompt template in settings
+    custom_template = None
+    try:
+        from modules.settings import Settings
+        settings = Settings()
+        override_system_prompt = settings.get("override_system_prompt", False)
+        custom_template_str = settings.get("system_prompt_template")
+
+        if override_system_prompt and custom_template_str:
+            try:
+                # Convert the string representation to a dictionary
+                # Extract template and crop_start directly from the string using regex
+                import re
+
+                # Try to extract the template value
+                template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL)
+                crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)
+
+                if template_match and crop_start_match:
+                    template_value = template_match.group(1)
+                    crop_start_value = int(crop_start_match.group(1))
+
+                    # Unescape any escaped characters in the template
+                    template_value = template_value.replace("\\n", "\n").replace("\\\"", "\"").replace("\\'", "'")
+
+                    custom_template = {
+                        "template": template_value,
+                        "crop_start": crop_start_value
+                    }
+                    print(f"Using custom system prompt template from settings: {custom_template}")
+                else:
+                    print(f"Could not extract template or crop_start from system prompt template string")
+                    print(f"Falling back to default template")
+                    custom_template = None
+            except Exception as e:
+                print(f"Error parsing custom system prompt template: {e}")
+                print(f"Falling back to default template")
+                custom_template = None
+        else:
+            if not override_system_prompt:
+                print(f"Override system prompt is disabled, using default template")
+            elif not custom_template_str:
+                print(f"No custom system prompt template found in settings")
+            custom_template = None
+    except Exception as e:
+        print(f"Error loading settings: {e}")
+        print(f"Falling back to default template")
+        custom_template = None
+
+    # Use custom template if available, otherwise use default
+    template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE
+
+    prompt_llama = [template["template"].format(p) for p in prompt]
+    crop_start = template["crop_start"]
+
+    llama_inputs = tokenizer(
+        prompt_llama,
+        padding="max_length",
+        max_length=max_length + crop_start,
+        truncation=True,
+        return_tensors="pt",
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_attention_mask=True,
+    )
+
+    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
+    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
+    llama_attention_length = int(llama_attention_mask.sum())
+
+    llama_outputs = text_encoder(
+        input_ids=llama_input_ids,
+        attention_mask=llama_attention_mask,
+        output_hidden_states=True,
+    )
+
+    llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
+    # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
+    llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
+
+    assert torch.all(llama_attention_mask.bool())
+
+    # CLIP
+
+    clip_l_input_ids = tokenizer_2(
+        prompt,
+        padding="max_length",
+        max_length=77,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="pt",
+    ).input_ids
+    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+
+    return llama_vec, clip_l_pooler
+
+
+@torch.no_grad()
+def vae_decode_fake(latents):
+    latent_rgb_factors = [
+        [-0.0395, -0.0331, 0.0445],
+        [0.0696, 0.0795, 0.0518],
+        [0.0135, -0.0945, -0.0282],
+        [0.0108, -0.0250, -0.0765],
+        [-0.0209, 0.0032, 0.0224],
+        [-0.0804, -0.0254, -0.0639],
+        [-0.0991, 0.0271, -0.0669],
+        [-0.0646, -0.0422, -0.0400],
+        [-0.0696, -0.0595, -0.0894],
+        [-0.0799, -0.0208, -0.0375],
+        [0.1166, 0.1627, 0.0962],
+        [0.1165, 0.0432, 0.0407],
+        [-0.2315, -0.1920, -0.1355],
+        [-0.0270, 0.0401, -0.0821],
+        [-0.0616, -0.0997, -0.0727],
+        [0.0249, -0.0469, -0.1703]
+    ]  # From comfyui
+
+    latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
+
+    weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
+    bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+    images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+    images = images.clamp(0.0, 1.0)
+
+    return images
+
+
+@torch.no_grad()
+def vae_decode(latents, vae, image_mode=False):
+    latents = latents / vae.config.scaling_factor
+
+    if not image_mode:
+        image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
+    else:
+        latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
+        image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
+        image = torch.cat(image, dim=2)
+
+    return image
+
+
+@torch.no_grad()
+def vae_encode(image, vae):
+    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+    latents = latents * vae.config.scaling_factor
+    return latents
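As a rough orientation for how these helpers fit together, the sketch below assumes the HunyuanVideo text encoders, tokenizers and VAE have already been loaded elsewhere (the variable names are placeholders); it is not runnable on its own and only illustrates the call pattern and the 5D tensor layout.

```python
# Hedged sketch only: text_encoder, text_encoder_2, tokenizer, tokenizer_2 and
# vae are assumed to be already-loaded HunyuanVideo components (placeholders).
import torch

from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_decode_fake, vae_encode

prompt = 'A serene forest with sunlight filtering through the trees'
llama_vec, clip_l_pooler = encode_prompt_conds(
    prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
)

# Pixel frames are 5D (batch, channels, time, height, width), roughly in [-1, 1].
frames = torch.zeros(1, 3, 5, 480, 832, device=vae.device, dtype=vae.dtype)

latents = vae_encode(frames, vae)      # scaled latents for the sampler
preview = vae_decode_fake(latents)     # cheap linear RGB preview used for progress updates
video = vae_decode(latents, vae)       # full VAE decode back to pixel space
```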
diffusers_helper/k_diffusion/uni_pc_fm.py
ADDED
@@ -0,0 +1,144 @@
+# Better Flow Matching UniPC by Lvmin Zhang
+# (c) 2025
+# CC BY-SA 4.0
+# Attribution-ShareAlike 4.0 International Licence
+
+
+import torch
+
+from tqdm.auto import trange
+
+
+def expand_dims(v, dims):
+    return v[(...,) + (None,) * (dims - 1)]
+
+
+class FlowMatchUniPC:
+    def __init__(self, model, extra_args, variant='bh1'):
+        self.model = model
+        self.variant = variant
+        self.extra_args = extra_args
+
+    def model_fn(self, x, t):
+        return self.model(x, t, **self.extra_args)
+
+    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
+        assert order <= len(model_prev_list)
+        dims = x.dim()
+
+        t_prev_0 = t_prev_list[-1]
+        lambda_prev_0 = - torch.log(t_prev_0)
+        lambda_t = - torch.log(t)
+        model_prev_0 = model_prev_list[-1]
+
+        h = lambda_t - lambda_prev_0
+
+        rks = []
+        D1s = []
+        for i in range(1, order):
+            t_prev_i = t_prev_list[-(i + 1)]
+            model_prev_i = model_prev_list[-(i + 1)]
+            lambda_prev_i = - torch.log(t_prev_i)
+            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+            rks.append(rk)
+            D1s.append((model_prev_i - model_prev_0) / rk)
+
+        rks.append(1.)
+        rks = torch.tensor(rks, device=x.device)
+
+        R = []
+        b = []
+
+        hh = -h[0]
+        h_phi_1 = torch.expm1(hh)
+        h_phi_k = h_phi_1 / hh - 1
+
+        factorial_i = 1
+
+        if self.variant == 'bh1':
+            B_h = hh
+        elif self.variant == 'bh2':
+            B_h = torch.expm1(hh)
+        else:
+            raise NotImplementedError('Bad variant!')
+
+        for i in range(1, order + 1):
+            R.append(torch.pow(rks, i - 1))
+            b.append(h_phi_k * factorial_i / B_h)
+            factorial_i *= (i + 1)
+            h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+        R = torch.stack(R)
+        b = torch.tensor(b, device=x.device)
+
+        use_predictor = len(D1s) > 0
+
+        if use_predictor:
+            D1s = torch.stack(D1s, dim=1)
+            if order == 2:
+                rhos_p = torch.tensor([0.5], device=b.device)
+            else:
+                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+        else:
+            D1s = None
+            rhos_p = None
+
+        if order == 1:
+            rhos_c = torch.tensor([0.5], device=b.device)
+        else:
+            rhos_c = torch.linalg.solve(R, b)
+
+        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
+
+        if use_predictor:
+            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
+        else:
+            pred_res = 0
+
+        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
+        model_t = self.model_fn(x_t, t)
+
+        if D1s is not None:
+            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
+        else:
+            corr_res = 0
+
+        D1_t = (model_t - model_prev_0)
+        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+
+        return x_t, model_t
+
+    def sample(self, x, sigmas, callback=None, disable_pbar=False):
+        order = min(3, len(sigmas) - 2)
+        model_prev_list, t_prev_list = [], []
+        for i in trange(len(sigmas) - 1, disable=disable_pbar):
+            vec_t = sigmas[i].expand(x.shape[0])
+
+            if i == 0:
+                model_prev_list = [self.model_fn(x, vec_t)]
+                t_prev_list = [vec_t]
+            elif i < order:
+                init_order = i
+                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
+                model_prev_list.append(model_x)
+                t_prev_list.append(vec_t)
+            else:
+                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
+                model_prev_list.append(model_x)
+                t_prev_list.append(vec_t)
+
+            model_prev_list = model_prev_list[-order:]
+            t_prev_list = t_prev_list[-order:]
+
+            if callback is not None:
+                callback_result = callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
+                if callback_result == 'cancel':
+                    print("Cancellation signal received in sample_unipc, stopping generation")
+                    return model_prev_list[-1]  # Return current denoised result
+
+        return model_prev_list[-1]
+
+
+def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
+    assert variant in ['bh1', 'bh2']
+    return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
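`sample_unipc` only needs a callable `model(x, t, **extra_args)` that returns a prediction shaped like `x`, plus a decreasing sigma schedule whose last value is never evaluated. A toy, self-contained sketch of the calling convention (the zero "model" is purely illustrative); note that `extra_args` must be a dict, because it is unpacked on every model call.

```python
# Toy demonstration of the sample_unipc calling convention; the zero "model"
# has no physical meaning and only exercises the API on CPU.
import torch

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc


def toy_model(x, t, **kwargs):
    # A real model returns the predicted flow for latents x at time t.
    return torch.zeros_like(x)


noise = torch.randn(1, 16, 3, 60, 104)        # latent-shaped noise: (B, C, T, H, W)
sigmas = torch.linspace(1.0, 0.0, 26)         # 25 steps; sigmas[-1] is never passed to the model

denoised = sample_unipc(
    toy_model,
    noise,
    sigmas,
    extra_args={},                             # must be a dict: it is **-unpacked per call
    callback=lambda d: print('step', d['i']),  # return 'cancel' from here to stop early
    variant='bh1',
)
print(denoised.shape)
```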
diffusers_helper/k_diffusion/wrapper.py
ADDED
@@ -0,0 +1,51 @@
+import torch
+
+
+def append_dims(x, target_dims):
+    return x[(...,) + (None,) * (target_dims - x.ndim)]
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
+    if guidance_rescale == 0:
+        return noise_cfg
+
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def fm_wrapper(transformer, t_scale=1000.0):
+    def k_model(x, sigma, **extra_args):
+        dtype = extra_args['dtype']
+        cfg_scale = extra_args['cfg_scale']
+        cfg_rescale = extra_args['cfg_rescale']
+        concat_latent = extra_args['concat_latent']
+
+        original_dtype = x.dtype
+        sigma = sigma.float()
+
+        x = x.to(dtype)
+        timestep = (sigma * t_scale).to(dtype)
+
+        if concat_latent is None:
+            hidden_states = x
+        else:
+            hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
+
+        pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
+
+        if cfg_scale == 1.0:
+            pred_negative = torch.zeros_like(pred_positive)
+        else:
+            pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
+
+        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
+        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
+
+        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
+
+        return x0.to(dtype=original_dtype)
+
+    return k_model
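The keyword contract that `k_model` expects is worth spelling out: `dtype`, `cfg_scale`, `cfg_rescale`, `concat_latent`, plus `positive`/`negative` dicts of conditioning kwargs forwarded to the transformer. A hedged sketch with a stand-in transformer (the real one is the packed HunyuanVideo model):

```python
# Sketch of fm_wrapper's extra_args contract, run end to end with a stub.
import torch

from diffusers_helper.k_diffusion.wrapper import fm_wrapper


def stub_transformer(hidden_states, timestep, return_dict=False, **cond_kwargs):
    # Stand-in for the packed video transformer: returns a (prediction,) tuple.
    return (torch.zeros_like(hidden_states),)


k_model = fm_wrapper(stub_transformer)

x = torch.randn(1, 16, 3, 60, 104)
sigma = torch.full((1,), 0.8)

extra_args = {
    'dtype': torch.float32,   # compute dtype for the transformer call
    'cfg_scale': 1.0,         # 1.0 skips the unconditional pass entirely
    'cfg_rescale': 0.0,       # guidance-rescale strength (0 disables rescaling)
    'concat_latent': None,    # optional latent concatenated along the channel dim
    'positive': {},           # conditioning kwargs forwarded to the transformer
    'negative': {},           # unconditional kwargs (unused when cfg_scale == 1.0)
}

x0 = k_model(x, sigma, **extra_args)
print(x0.shape)
```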
diffusers_helper/lora_utils.py
ADDED
@@ -0,0 +1,194 @@
+from pathlib import Path, PurePath
+from typing import Dict, List, Optional, Union, Tuple
+from diffusers.loaders.lora_pipeline import _fetch_state_dict
+from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
+from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
+import torch
+
+FALLBACK_CLASS_ALIASES = {
+    "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel",
+}
+
+def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
+    """
+    Load LoRA weights into the transformer model.
+
+    Args:
+        transformer: The transformer model to which LoRA weights will be applied.
+        lora_path: Path to the folder containing the LoRA weights file.
+        weight_name: Filename of the weight to load.
+
+    Returns:
+        A tuple containing the modified transformer and the canonical adapter name.
+    """
+
+    state_dict = _fetch_state_dict(
+        lora_path,
+        weight_name,
+        True,
+        True,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None)
+
+    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+    # should weight_name even be Optional[str] or just str?
+    # For now, we assume it is never None
+    # The module name in the state_dict must not include a . in the name
+    # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
+    adapter_name = str(PurePath(weight_name).with_suffix('')).replace('.', '_DOT_')
+    if '_DOT_' in adapter_name:
+        print(
+            f"LoRA file '{weight_name}' contains a '.' in the name. " +
+            'This may cause issues. Consider renaming the file.' +
+            f" Using '{adapter_name}' as the adapter name to be safe."
+        )
+
+    # Check if adapter already exists and delete it if it does
+    if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
+        print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
+        # Use delete_adapters (plural) instead of delete_adapter
+        transformer.delete_adapters([adapter_name])
+
+    # Load the adapter with the original name
+    transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
+    print(f"LoRA weights '{adapter_name}' loaded successfully.")
+
+    return transformer, adapter_name
+
+def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
+    """
+    Completely unload all LoRA adapters from the transformer model.
+
+    Args:
+        transformer: The transformer model from which LoRA adapters will be removed.
+
+    Returns:
+        The transformer model after all LoRA adapters have been removed.
+    """
+    if hasattr(transformer, 'peft_config') and transformer.peft_config:
+        # Get all adapter names
+        adapter_names = list(transformer.peft_config.keys())
+
+        if adapter_names:
+            print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
+            # Delete all adapters
+            transformer.delete_adapters(adapter_names)
+
+            # Force cleanup of any remaining adapter references
+            if hasattr(transformer, 'active_adapter'):
+                transformer.active_adapter = None
+
+            # Clear any cached states
+            for module in transformer.modules():
+                if hasattr(module, 'lora_A'):
+                    if isinstance(module.lora_A, dict):
+                        module.lora_A.clear()
+                if hasattr(module, 'lora_B'):
+                    if isinstance(module.lora_B, dict):
+                        module.lora_B.clear()
+                if hasattr(module, 'scaling'):
+                    if isinstance(module.scaling, dict):
+                        module.scaling.clear()
+
+            print("All LoRA adapters have been completely removed.")
+        else:
+            print("No LoRA adapters found to remove.")
+    else:
+        print("Model doesn't have any LoRA adapters or peft_config.")
+
+    # Force garbage collection
+    import gc
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return transformer
+
+def resolve_expansion_class_name(
+    transformer: torch.nn.Module,
+    fallback_aliases: Dict[str, str],
+    fn_mapping: Dict[str, callable]
+) -> Optional[str]:
+    """
+    Resolves the canonical class name for adapter scale expansion functions,
+    considering potential fallback aliases.
+
+    Args:
+        transformer: The transformer model instance.
+        fallback_aliases: A dictionary mapping model class names to fallback class names.
+        fn_mapping: A dictionary mapping class names to their respective scale expansion functions.
+
+    Returns:
+        The resolved class name as a string if a matching scale function is found,
+        otherwise None.
+    """
+    class_name = transformer.__class__.__name__
+
+    if class_name in fn_mapping:
+        return class_name
+
+    fallback_class = fallback_aliases.get(class_name)
+    if fallback_class in fn_mapping:
+        print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'")
+        return fallback_class
+
+    return None
+
+def set_adapters(
+    transformer: torch.nn.Module,
+    adapter_names: Union[List[str], str],
+    weights: Optional[Union[float, List[float]]] = None,
+):
+    """
+    Activates and sets the weights for one or more LoRA adapters on the transformer model.
+
+    Args:
+        transformer: The transformer model to which LoRA adapters are applied.
+        adapter_names: A single adapter name (str) or a list of adapter names (List[str]) to activate.
+        weights: Optional. A single float weight or a list of float weights
+            corresponding to each adapter name. If None, defaults to 1.0 for each adapter.
+            If a single float, it will be applied to all adapters.
+    """
+    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+    # Expand a single weight to apply to all adapters if needed
+    if not isinstance(weights, list):
+        weights = [weights] * len(adapter_names)
+
+    if len(adapter_names) != len(weights):
+        raise ValueError(
+            f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
+        )
+
+    # Replace any None weights with a default value of 1.0
+    sanitized_weights = [w if w is not None else 1.0 for w in weights]
+
+    resolved_class_name = resolve_expansion_class_name(
+        transformer,
+        fallback_aliases=FALLBACK_CLASS_ALIASES,
+        fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING
+    )
+
+    transformer_class_name = transformer.__class__.__name__
+
+    if resolved_class_name:
+        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
+        print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
+        final_weights = [
+            scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights
+        ]
+    else:
+        print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
+        final_weights = sanitized_weights
+
+    set_weights_and_activate_adapters(transformer, adapter_names, final_weights)
+
+    print(f"Adapters {adapter_names} activated with weights {final_weights}.")
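A hedged sketch of the LoRA lifecycle these helpers implement. `transformer` is assumed to be an already-loaded, PEFT-capable diffusers transformer, and the folder and file name are hypothetical placeholders:

```python
# Hedged sketch: load, activate, and later unload a LoRA. `transformer` is an
# assumed pre-loaded diffusers transformer; the file name is a placeholder.
from pathlib import Path

from diffusers_helper.lora_utils import load_lora, set_adapters, unload_all_loras

lora_dir = Path('loras')                      # ./loras/ at the install root
weight_file = 'my_style_lora.safetensors'     # hypothetical file name

transformer, adapter_name = load_lora(transformer, lora_dir, weight_file)

# Activate one or more adapters with per-adapter strengths (None defaults to 1.0).
set_adapters(transformer, [adapter_name], weights=[0.8])

# ... run generations ...

# Remove every adapter and release cached GPU memory before switching models.
transformer = unload_all_loras(transformer)
```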
diffusers_helper/memory.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# By lllyasviel


import torch


cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu_complete_modules = []


class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, **kwargs):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
                    else:
                        return p.to(**kwargs)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(**kwargs)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })

        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, **kwargs):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, **kwargs)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return


def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
    if hasattr(model, 'scale_shift_table'):
        model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
        return

    for k, p in model.named_modules():
        if hasattr(p, 'weight'):
            p.to(target_device)
    return


def get_cuda_free_memory_gb(device=None):
    if device is None:
        device = gpu

    memory_stats = torch.cuda.memory_stats(device)
    bytes_active = memory_stats['active_bytes.all.current']
    bytes_reserved = memory_stats['reserved_bytes.all.current']
    bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
    bytes_inactive_reserved = bytes_reserved - bytes_active
    bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
    return bytes_total_available / (1024 ** 3)


def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=target_device)

    model.to(device=target_device)
    torch.cuda.empty_cache()
    return


def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=cpu)

    model.to(device=cpu)
    torch.cuda.empty_cache()
    return


def unload_complete_models(*args):
    for m in gpu_complete_modules + list(args):
        m.to(device=cpu)
        print(f'Unloaded {m.__class__.__name__} as complete.')

    gpu_complete_modules.clear()
    torch.cuda.empty_cache()
    return


def load_model_as_complete(model, target_device, unload=True):
    if unload:
        unload_complete_models()

    model.to(device=target_device)
    print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')

    gpu_complete_modules.append(model)
    return
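A short sketch of how these helpers are typically combined for low-VRAM inference. The model variables (transformer, text_encoder) and the 60 GB threshold are illustrative assumptions, not values defined in this module:

    free_gb = get_cuda_free_memory_gb(gpu)

    if free_gb > 60:
        # Plenty of VRAM: keep the large model fully resident on the GPU.
        transformer.to(device=gpu)
    else:
        # Low VRAM: patch every submodule so parameters are streamed to the GPU on access.
        DynamicSwapInstaller.install_model(transformer, device=gpu)

    # Small models can be loaded "as complete", evicting previously loaded complete models first.
    load_model_as_complete(text_encoder, target_device=gpu)

    # Before a memory-heavy step (e.g. VAE decode), push weights back to CPU while keeping a margin.
    offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)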
diffusers_helper/models/hunyuan_video_packed.py
ADDED
@@ -0,0 +1,1062 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import einops
import torch.nn as nn
import numpy as np

from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers_helper.dit_common import LayerNorm
from diffusers_helper.models.mag_cache import MagCache
from diffusers_helper.utils import zero_module


enabled_backends = []

if torch.backends.cuda.flash_sdp_enabled():
    enabled_backends.append("flash")
if torch.backends.cuda.math_sdp_enabled():
    enabled_backends.append("math")
if torch.backends.cuda.mem_efficient_sdp_enabled():
    enabled_backends.append("mem_efficient")
if torch.backends.cuda.cudnn_sdp_enabled():
    enabled_backends.append("cudnn")

print("Currently enabled native sdp backends:", enabled_backends)

xformers_attn_func = None
flash_attn_varlen_func = None
flash_attn_func = None
sageattn_varlen = None
sageattn = None

try:
    # raise NotImplementedError
    from xformers.ops import memory_efficient_attention as xformers_attn_func
except:
    pass

try:
    # raise NotImplementedError
    from flash_attn import flash_attn_varlen_func, flash_attn_func
except:
    pass

try:
    # raise NotImplementedError
    from sageattention import sageattn_varlen, sageattn
except:
    pass

# --- Attention Summary ---
print("\n--- Attention Configuration ---")
has_sage = sageattn is not None and sageattn_varlen is not None
has_flash = flash_attn_func is not None and flash_attn_varlen_func is not None
has_xformers = xformers_attn_func is not None

if has_sage:
    print("✅ Using SAGE Attention (highest performance).")
    ignored = []
    if has_flash:
        ignored.append("Flash Attention")
    if has_xformers:
        ignored.append("xFormers")
    if ignored:
        print(f"   - Ignoring other installed attention libraries: {', '.join(ignored)}")
elif has_flash:
    print("✅ Using Flash Attention (high performance).")
    if has_xformers:
        print("   - Consider installing SAGE Attention for highest performance.")
        print("   - Ignoring other installed attention library: xFormers")
elif has_xformers:
    print("✅ Using xFormers.")
    print("   - Consider installing SAGE Attention for highest performance.")
    print("   - or Consider installing Flash Attention for high performance.")
else:
    print("⚠️ No attention library found. Using native PyTorch Scaled Dot Product Attention.")
    print("   - For better performance, consider installing one of:")
    print("     SAGE Attention (highest performance), Flash Attention (high performance), or xFormers.")
print("-------------------------------\n")


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def pad_for_3d_conv(x, kernel_size):
    b, c, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - (t % pt)) % pt
    pad_h = (ph - (h % ph)) % ph
    pad_w = (pw - (w % pw)) % pw
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')


def center_down_sample_3d(x, kernel_size):
    # pt, ph, pw = kernel_size
    # cp = (pt * ph * pw) // 2
    # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
    # xc = xp[cp]
    # return xc
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)


def get_cu_seqlens(text_mask, img_len):
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens


def apply_rotary_emb_transposed(x, freqs_cis):
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    out = x.float() * cos + x_rotated.float() * sin
    out = out.to(x)
    return out


def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
    if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
        if sageattn is not None:
            x = sageattn(q, k, v, tensor_layout='NHD')
            return x

        if flash_attn_func is not None:
            x = flash_attn_func(q, k, v)
            return x

        if xformers_attn_func is not None:
            x = xformers_attn_func(q, k, v)
            return x

        x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
        return x

    batch_size = q.shape[0]
    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
    if sageattn_varlen is not None:
        x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
    elif flash_attn_varlen_func is not None:
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
    else:
        raise NotImplementedError('No Attn Installed!')
    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
    return x


class HunyuanAttnProcessorFlashAttnDouble:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        query = apply_rotary_emb_transposed(query, image_rotary_emb)
        key = apply_rotary_emb_transposed(key, image_rotary_emb)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        encoder_query = attn.norm_added_q(encoder_query)
        encoder_key = attn.norm_added_k(encoder_key)

        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
        value = torch.cat([value, encoder_value], dim=1)

        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        hidden_states = hidden_states.flatten(-2)

        txt_length = encoder_hidden_states.shape[1]
        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states


class HunyuanAttnProcessorFlashAttnSingle:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        txt_length = encoder_hidden_states.shape[1]

        query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
        key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)

        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        hidden_states = hidden_states.flatten(-2)

        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        return hidden_states, encoder_hidden_states


class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))

        guidance_proj = self.time_proj(guidance)
        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))

        time_guidance_emb = timesteps_emb + guidance_emb

        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = time_guidance_emb + pooled_projections

        return conditioning


class CombinedTimestepTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))

        pooled_projections = self.text_embedder(pooled_projection)

        conditioning = timesteps_emb + pooled_projections

        return conditioning


class HunyuanVideoAdaNorm(nn.Module):
    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()

        out_features = out_features or 2 * in_features
        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(
        self, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        temb = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = temb.chunk(2, dim=-1)
        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
        return gate_msa, gate_mlp


class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: str = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)

        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )

        gate_msa, gate_mlp = self.norm_out(temb)
        hidden_states = hidden_states + attn_output * gate_msa

        ff_output = self.ff(self.norm2(hidden_states))
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states


class HunyuanVideoIndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        self.refiner_blocks = nn.ModuleList(
            [
                HunyuanVideoIndividualTokenRefinerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    attention_bias=attention_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> None:
        self_attn_mask = None
        if attention_mask is not None:
            batch_size = attention_mask.shape[0]
            seq_len = attention_mask.shape[1]
            attention_mask = attention_mask.to(hidden_states.device).bool()
            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            self_attn_mask[:, :, :, 0] = True

        for block in self.refiner_blocks:
            hidden_states = block(hidden_states, temb, self_attn_mask)

        return hidden_states


class HunyuanVideoTokenRefiner(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=hidden_size, pooled_projection_dim=in_channels
        )
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_layers=num_layers,
            mlp_width_ratio=mlp_ratio,
            mlp_drop_rate=mlp_drop_rate,
            attention_bias=attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if attention_mask is None:
            pooled_projections = hidden_states.mean(dim=1)
        else:
            original_dtype = hidden_states.dtype
            mask_float = attention_mask.float().unsqueeze(-1)
            pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
            pooled_projections = pooled_projections.to(original_dtype)

        temb = self.time_text_embed(timestep, pooled_projections)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)

        return hidden_states


class HunyuanVideoRotaryPosEmbed(nn.Module):
    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

    @torch.no_grad()
    def get_frequency(self, dim, pos):
        T, H, W = pos.shape
        freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
        freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
        return freqs.cos(), freqs.sin()

    @torch.no_grad()
    def forward_inner(self, frame_indices, height, width, device):
        GT, GY, GX = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, height, device=device, dtype=torch.float32),
            torch.arange(0, width, device=device, dtype=torch.float32),
            indexing="ij"
        )

        FCT, FST = self.get_frequency(self.DT, GT)
        FCY, FSY = self.get_frequency(self.DY, GY)
        FCX, FSX = self.get_frequency(self.DX, GX)

        result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)

        return result.to(device)

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        frame_indices = frame_indices.unbind(0)
        results = [self.forward_inner(f, height, width, device) for f in frame_indices]
        results = torch.stack(results, dim=0)
        return results


class AdaLayerNormZero(nn.Module):
    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa


class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        scale, shift = emb.chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class HunyuanVideoSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnSingle(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,
        )

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

        # 3. Modulation and residual connection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states


class HunyuanVideoTransformerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            added_kv_proj_dim=hidden_size,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            context_pre_only=False,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnDouble(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

        self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)

        # 2. Joint attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
        )

        # 3. Modulation and residual connection
        hidden_states = hidden_states + attn_output * gate_msa
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)

        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)

        hidden_states = hidden_states + gate_mlp * ff_output
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output

        return hidden_states, encoder_hidden_states


class ClipVisionProjection(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Linear(in_channels, out_channels * 3)
        self.down = nn.Linear(out_channels * 3, out_channels)

    def forward(self, x):
        projected_x = self.down(nn.functional.silu(self.up(x)))
        return projected_x


class HunyuanVideoPatchEmbed(nn.Module):
    def __init__(self, patch_size, in_chans, embed_dim):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)


class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
    def __init__(self, inner_dim):
        super().__init__()
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    @torch.no_grad()
    def initialize_weight_from_another_conv3d(self, another_layer):
        weight = another_layer.weight.detach().clone()
        bias = another_layer.bias.detach().clone()

        sd = {
            'proj.weight': weight.clone(),
            'proj.bias': bias.clone(),
            'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
            'proj_2x.bias': bias.clone(),
            'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
            'proj_4x.bias': bias.clone(),
        }

        sd = {k: v.clone() for k, v in sd.items()}

        self.load_state_dict(sd)
        return


class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
        has_image_proj=False,
        image_proj_dim=1152,
        has_clean_x_embedder=False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)

        self.clean_x_embedder = None
        self.image_projection = None

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoSingleTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_single_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.inner_dim = inner_dim
        self.use_gradient_checkpointing = False
        self.enable_teacache = False
        self.magcache: MagCache = None

        if has_image_proj:
            self.install_image_projection(image_proj_dim)

        if has_clean_x_embedder:
            self.install_clean_x_embedder()

        self.high_quality_fp32_output_for_inference = False

    def install_image_projection(self, in_channels):
        self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
        self.config['has_image_proj'] = True
        self.config['image_proj_dim'] = in_channels

    def install_clean_x_embedder(self):
        self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
        self.config['has_clean_x_embedder'] = True

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True
        print('self.use_gradient_checkpointing = True')

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False
        print('self.use_gradient_checkpointing = False')

    def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
        self.enable_teacache = enable_teacache
        self.cnt = 0
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])

    def install_magcache(self, magcache: MagCache):
        self.magcache = magcache

    def uninstall_magcache(self):
        self.magcache = None

    def gradient_checkpointing_method(self, block, *args):
        if self.use_gradient_checkpointing:
            result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
        else:
            result = block(*args)
        return result

    def process_input_hidden_states(
            self,
            latents, latent_indices=None,
            clean_latents=None, clean_latent_indices=None,
            clean_latents_2x=None, clean_latent_2x_indices=None,
            clean_latents_4x=None, clean_latent_4x_indices=None
    ):
        hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
        B, C, T, H, W = hidden_states.shape

        if latent_indices is None:
            latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
        rope_freqs = rope_freqs.flatten(2).transpose(1, 2)

        if clean_latents is not None and clean_latent_indices is not None:
            clean_latents = clean_latents.to(hidden_states)
            clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
            clean_latents = clean_latents.flatten(2).transpose(1, 2)

            clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
            clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)

        if clean_latents_2x is not None and clean_latent_2x_indices is not None:
            clean_latents_2x = clean_latents_2x.to(hidden_states)
            clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
            clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
            clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)

            clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
            clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)

        if clean_latents_4x is not None and clean_latent_4x_indices is not None:
            clean_latents_4x = clean_latents_4x.to(hidden_states)
            clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
            clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
            clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)

            clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
            clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)

        return hidden_states, rope_freqs

    def forward(
            self,
            hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
            latent_indices=None,
            clean_latents=None, clean_latent_indices=None,
            clean_latents_2x=None, clean_latent_2x_indices=None,
            clean_latents_4x=None, clean_latent_4x_indices=None,
            image_embeddings=None,
            attention_kwargs=None, return_dict=True
    ):

        if attention_kwargs is None:
            attention_kwargs = {}

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config['patch_size'], self.config['patch_size_t']
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

        hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)

        temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
        encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)

        if self.image_projection is not None:
            assert image_embeddings is not None, 'You must use image embeddings!'
            extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
            extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)

            # must cat before (not after) encoder_hidden_states, due to attn masking
            encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
            encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)

        with torch.no_grad():
            if batch_size == 1:
                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
                # If they are not same, then their impls are wrong. Ours are always the correct one.
                text_len = encoder_attention_mask.sum().item()
                encoder_hidden_states = encoder_hidden_states[:, :text_len]
                attention_mask = None, None, None, None
            else:
                img_seq_len = hidden_states.shape[1]
                txt_seq_len = encoder_hidden_states.shape[1]

                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
                cu_seqlens_kv = cu_seqlens_q
                max_seqlen_q = img_seq_len + txt_seq_len
                max_seqlen_kv = max_seqlen_q

                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv

        if self.enable_teacache:
            modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

            if self.cnt == 0 or self.cnt == self.num_steps-1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
                should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

                if should_calc:
                    self.accumulated_rel_l1_distance = 0

            self.previous_modulated_input = modulated_inp
            self.cnt += 1

            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                hidden_states = hidden_states + self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()

                hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)

                self.previous_residual = hidden_states - ori_hidden_states

        elif self.magcache and self.magcache.is_enabled:
            if self.magcache.should_skip(hidden_states):
                hidden_states = self.magcache.estimate_predicted_hidden_states()
            else:
                hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)
                self.magcache.update_hidden_states(model_prediction_hidden_states=hidden_states)

        else:
            hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)

        hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

        hidden_states = hidden_states[:, -original_context_length:, :]

        if self.high_quality_fp32_output_for_inference:
            hidden_states = hidden_states.to(dtype=torch.float32)
            if self.proj_out.weight.dtype != torch.float32:
                self.proj_out.to(dtype=torch.float32)

        hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

        hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
                                         t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
                                         pt=p_t, ph=p, pw=p)

        if return_dict:
            return Transformer2DModelOutput(sample=hidden_states)

        return hidden_states,

    def _run_denoising_layers(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[Tuple],
        rope_freqs: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Applies the dual-stream and single-stream transformer blocks.
        """
        for block_id, block in enumerate(self.transformer_blocks):
            hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
            )

        for block_id, block in enumerate(self.single_transformer_blocks):
            hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
            )
        return hidden_states, encoder_hidden_states
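A worked example of the cu_seqlens layout consumed by attn_varlen_func above. The shapes are illustrative, not taken from the model config, and get_cu_seqlens allocates on "cuda", so this sketch assumes a CUDA device is available:

    text_mask = torch.tensor([[1, 1, 1, 0, 0],
                              [1, 1, 1, 1, 1]])   # batch of 2, padded text length 5
    img_len = 10                                   # video tokens per sample

    cu = get_cu_seqlens(text_mask, img_len)
    # max_len = 5 + 10 = 15 per sample. Each sample contributes two boundaries:
    # the end of its valid (img + text) tokens and the end of its padded slot.
    # sample 0: 10 + 3 = 13 valid tokens, slot ends at 15
    # sample 1: 15 + 10 + 5 = 30 valid tokens, slot ends at 30 (no padding segment)
    # cu -> tensor([0, 13, 15, 30, 30], dtype=torch.int32)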
diffusers_helper/models/mag_cache.py
ADDED
@@ -0,0 +1,219 @@
import numpy as np
import torch
import os

from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB


class MagCache:
    """
    Implements the MagCache algorithm for skipping transformer steps during video generation.
    MagCache: Fast Video Generation with Magnitude-Aware Cache
    Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian
    https://arxiv.org/abs/2506.09045
    https://github.com/Zehong-Ma/MagCache
    PR demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2.
    Changing defaults to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for a quality vs. speed tradeoff.
    """

    def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating=False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25):
        self.model_family = model_family
        self.height = height
        self.width = width
        self.num_steps = num_steps

        self.is_enabled = is_enabled
        self.is_calibrating = is_calibrating

        self.threshold = threshold
        self.max_consectutive_skips = max_consectutive_skips
        self.retention_ratio = retention_ratio

        # Total cache statistics for all sections in the entire generation
        self.total_cache_requests = 0
        self.total_cache_hits = 0

        self.mag_ratios = self._determine_mag_ratios()

        self._init_for_every_section()

    def _init_for_every_section(self):
        self.step_index = 0
        self.steps_skipped_list = []
        # Error accumulation state
        self.accumulated_ratio = 1.0
        self.accumulated_steps = 0
        self.accumulated_err = 0
        # Statistics for calibration
        self.norm_ratio, self.norm_std, self.cos_dis = [], [], []

        self.hidden_states = None
        self.previous_residual = None

        if self.is_calibrating and self.total_cache_requests > 0:
            print('WARNING: Resetting MagCache calibration stats for new section. Typically you only want one section per calibration job. Discarding calibration from previous section.')

    def should_skip(self, hidden_states):
        """
        Expected to be called once per step during the forward pass, for the number of initialized steps.
        Determines if the current step should be skipped based on the estimated accumulated error.
        If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states().

        Args:
            hidden_states: The current hidden states tensor from the transformer model.
        Returns:
            True if the step should be skipped, False otherwise.
        """
        if self.step_index == 0 or self.step_index >= self.num_steps:
            self._init_for_every_section()
        self.total_cache_requests += 1
        self.hidden_states = hidden_states.clone()  # Is clone needed?

        if self.is_calibrating:
            print('######################### Calibrating MagCache #########################')
            return False

        should_skip_forward = False
        if self.step_index >= int(self.retention_ratio * self.num_steps) and self.step_index >= 1:  # keep first retention_ratio steps
            cur_mag_ratio = self.mag_ratios[self.step_index]
            self.accumulated_ratio = self.accumulated_ratio * cur_mag_ratio
            cur_skip_err = np.abs(1 - self.accumulated_ratio)
            self.accumulated_err += cur_skip_err
            self.accumulated_steps += 1
            # RT_BORG: Per my conversation with Zehong Ma, this 0.06 could potentially be exposed as another tunable param.
            if self.accumulated_err <= self.threshold and self.accumulated_steps <= self.max_consectutive_skips and np.abs(1 - cur_mag_ratio) <= 0.06:
                should_skip_forward = True
            else:
                self.accumulated_ratio = 1.0
                self.accumulated_steps = 0
                self.accumulated_err = 0

        if should_skip_forward:
|
93 |
+
self.total_cache_hits += 1
|
94 |
+
self.steps_skipped_list.append(self.step_index)
|
95 |
+
# Increment for next step
|
96 |
+
self.step_index += 1
|
97 |
+
if self.step_index == self.num_steps:
|
98 |
+
self.step_index = 0
|
99 |
+
|
100 |
+
return should_skip_forward
|
101 |
+
|
102 |
+
def estimate_predicted_hidden_states(self):
|
103 |
+
"""
|
104 |
+
Should be called if and only if should_skip() returned True for the current step.
|
105 |
+
Estimates the hidden states for the current step based on the previous hidden states and residual.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
The estimated hidden states tensor.
|
109 |
+
"""
|
110 |
+
return self.hidden_states + self.previous_residual
|
111 |
+
|
112 |
+
def update_hidden_states(self, model_prediction_hidden_states):
|
113 |
+
"""
|
114 |
+
If and only if should_skip() returned False for the current step, the denoising layers should have been run,
|
115 |
+
and this function should be called to compute and store the residual for future steps.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
model_prediction_hidden_states: The hidden states tensor output from running the denoising layers.
|
119 |
+
"""
|
120 |
+
|
121 |
+
current_residual = model_prediction_hidden_states - self.hidden_states
|
122 |
+
if self.is_calibrating:
|
123 |
+
self._update_calibration_stats(current_residual)
|
124 |
+
|
125 |
+
self.previous_residual = current_residual
|
126 |
+
|
127 |
+
def _update_calibration_stats(self, current_residual):
|
128 |
+
if self.step_index >= 1:
|
129 |
+
norm_ratio = ((current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).mean()).item()
|
130 |
+
norm_std = (current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).std().item()
|
131 |
+
cos_dis = (1-torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item()
|
132 |
+
self.norm_ratio.append(round(norm_ratio, 5))
|
133 |
+
self.norm_std.append(round(norm_std, 5))
|
134 |
+
self.cos_dis.append(round(cos_dis, 5))
|
135 |
+
# print(f"time: {self.step_index}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}")
|
136 |
+
|
137 |
+
self.step_index += 1
|
138 |
+
if self.step_index == self.num_steps:
|
139 |
+
print("norm ratio")
|
140 |
+
print(self.norm_ratio)
|
141 |
+
print("norm std")
|
142 |
+
print(self.norm_std)
|
143 |
+
print("cos_dis")
|
144 |
+
print(self.cos_dis)
|
145 |
+
self.step_index = 0
|
146 |
+
|
147 |
+
def _determine_mag_ratios(self):
|
148 |
+
"""
|
149 |
+
Determines the magnitude ratios by finding the closest resolution and step count
|
150 |
+
in the pre-calibrated database.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
A numpy array of magnitude ratios for the specified configuration, or None if not found.
|
154 |
+
"""
|
155 |
+
if self.is_calibrating:
|
156 |
+
return None
|
157 |
+
try:
|
158 |
+
# Find the closest available resolution group for the given model family
|
159 |
+
resolution_groups = MAG_RATIOS_DB[self.model_family]
|
160 |
+
available_resolutions = list(resolution_groups.keys())
|
161 |
+
if not available_resolutions:
|
162 |
+
raise ValueError("No resolutions defined for this model family.")
|
163 |
+
|
164 |
+
avg_resolution = (self.height + self.width) / 2.0
|
165 |
+
closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution))
|
166 |
+
|
167 |
+
# Find the closest available step count for the given model/resolution
|
168 |
+
steps_group = resolution_groups[closest_resolution_key]
|
169 |
+
available_steps = list(steps_group.keys())
|
170 |
+
if not available_steps:
|
171 |
+
raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.")
|
172 |
+
closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps))
|
173 |
+
base_ratios = steps_group[closest_steps]
|
174 |
+
if closest_steps == self.num_steps:
|
175 |
+
print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.")
|
176 |
+
return base_ratios
|
177 |
+
print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.")
|
178 |
+
return self._nearest_step_interpolation(base_ratios, self.num_steps)
|
179 |
+
except KeyError:
|
180 |
+
# This will catch if model_family is not in MAG_RATIOS_DB
|
181 |
+
print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.")
|
182 |
+
self.is_enabled = False
|
183 |
+
except (ValueError, TypeError) as e:
|
184 |
+
# This will catch errors if resolution keys or step keys are not numbers, or if groups are empty.
|
185 |
+
print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.")
|
186 |
+
self.is_enabled = False
|
187 |
+
return None
|
188 |
+
|
189 |
+
# Nearest interpolation function for MagCache mag_ratios
|
190 |
+
@staticmethod
|
191 |
+
def _nearest_step_interpolation(src_array, target_length):
|
192 |
+
src_length = len(src_array)
|
193 |
+
if target_length == 1:
|
194 |
+
return np.array([src_array[-1]])
|
195 |
+
|
196 |
+
scale = (src_length - 1) / (target_length - 1)
|
197 |
+
mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
|
198 |
+
return src_array[mapped_indices]
|
199 |
+
|
200 |
+
def append_calibration_to_file(self, output_file):
|
201 |
+
"""
|
202 |
+
Appends tab delimited calibration data (model_family,width,height,norm_ratio) to output_file.
|
203 |
+
"""
|
204 |
+
if not self.is_calibrating or not self.norm_ratio:
|
205 |
+
print("Calibration data can only be appended after calibration.")
|
206 |
+
return False
|
207 |
+
try:
|
208 |
+
with open(output_file, "a") as f:
|
209 |
+
# Format the data as a string
|
210 |
+
calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}"
|
211 |
+
# data_string = f"{calibration_set}\t{self.norm_ratio}"
|
212 |
+
entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio}),"
|
213 |
+
# Append the data to the file
|
214 |
+
f.write(entry_string + "\n")
|
215 |
+
print(f"Calibration data appended to {output_file}")
|
216 |
+
return True
|
217 |
+
except Exception as e:
|
218 |
+
print(f"Error appending calibration data: {e}")
|
219 |
+
return False
|
diffusers_helper/models/mag_cache_ratios.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
# Pre-calibrated magnitude ratios for different model families, resolutions, and step counts
|
4 |
+
# Format: MAG_RATIOS_DB[model_family][resolution][step_count] = np.array([...])
|
5 |
+
# All calibrations performed with FramePackStudio v0.4 with default settings and seed 31337
|
6 |
+
MAG_RATIOS_DB = {
|
7 |
+
"Original": {
|
8 |
+
768: {
|
9 |
+
25: np.array([1.0] + [1.30469, 1.22656, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.99219, 1.00781, 1.00781, 0.98828, 0.94141, 0.93359, 0.78906]),
|
10 |
+
50: np.array([1.0] + [1.30469, 0.99609, 1.16406, 1.0625, 1.01562, 1.02344, 1.01562, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.78125]),
|
11 |
+
75: np.array([1.0] + [1.01562, 1.27344, 1.0, 1.15625, 1.0625, 1.00781, 1.02344, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.0, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94531, 0.91406, 0.90234, 0.85938, 0.76172]),
|
12 |
+
},
|
13 |
+
640: {
|
14 |
+
25: np.array([1.0] + [1.30469, 1.22656, 1.05469, 1.02344, 1.03906, 1.02344, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.03906, 0.98828, 1.0, 1.0, 0.98828, 0.94531, 0.93359, 0.78516]),
|
15 |
+
50: np.array([1.0] + [1.28906, 1.0, 1.17188, 1.0625, 1.02344, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.01562, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.97656, 0.91016, 0.85547, 0.76953]),
|
16 |
+
75: np.array([1.0] + [1.00781, 1.30469, 1.0, 1.15625, 1.05469, 1.01562, 1.01562, 1.0, 1.01562, 1.0, 1.02344, 0.99609, 1.0, 1.04688, 0.99219, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99219, 1.00781, 0.99219, 0.99609, 1.00781, 1.03906, 0.98828, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03906, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]),
|
17 |
+
},
|
18 |
+
512: {
|
19 |
+
25: np.array([1.0] + [1.32031, 1.21875, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.98828, 1.0, 1.00781, 0.98828, 0.94141, 0.9375, 0.78516]),
|
20 |
+
50: np.array([1.0] + [1.32031, 0.99609, 1.15625, 1.0625, 1.01562, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.98828, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.77734]),
|
21 |
+
75: np.array([1.0] + [1.02344, 1.28906, 1.0, 1.15625, 1.0625, 1.01562, 1.01562, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98828, 1.01562, 1.02344, 0.99219, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 1.0, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]),
|
22 |
+
},
|
23 |
+
384: {
|
24 |
+
25: np.array([1.0] + [1.58594, 1.0625, 1.03906, 1.02344, 1.0625, 1.04688, 1.04688, 1.03125, 1.02344, 1.01562, 1.00781, 1.01562, 1.01562, 0.99219, 1.07031, 0.96094, 0.96484, 1.03125, 0.96875, 0.94141, 0.97266, 0.92188, 0.88672, 0.75]),
|
25 |
+
50: np.array([1.0] + [1.29688, 1.21875, 1.02344, 1.03906, 1.0, 1.03906, 1.00781, 1.02344, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 1.00781, 0.98828, 1.03906, 0.99609, 1.03125, 1.03125, 0.98438, 1.01562, 0.99609, 1.01562, 1.0, 1.01562, 1.0, 1.01562, 0.98047, 1.02344, 1.04688, 0.97266, 0.98828, 0.97656, 0.98828, 1.05469, 0.97656, 0.98828, 0.98047, 0.98438, 0.95703, 1.00781, 0.96484, 0.97656, 0.94531, 0.94141, 0.94531, 0.875, 0.85547, 0.79688]),
|
26 |
+
75: np.array([1.0] + [1.29688, 1.14844, 1.07031, 1.01562, 1.02344, 1.02344, 0.99609, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.00781, 1.03125, 1.00781, 1.00781, 1.04688, 0.99219, 1.00781, 0.99219, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.01562, 1.01562, 1.0, 1.0, 1.00781, 0.99219, 1.01562, 1.0, 0.99609, 1.03125, 0.98828, 1.02344, 0.99609, 0.97656, 1.01562, 1.00781, 1.03906, 0.96484, 1.03125, 0.96484, 1.02344, 0.98438, 0.96094, 1.03125, 0.98828, 1.01562, 0.97266, 1.02344, 0.97656, 1.0, 0.98438, 0.95703, 1.02344, 0.96094, 0.99609, 0.99609, 0.9375, 0.98438, 0.94141, 0.97266, 0.96875, 0.89844, 0.95703, 0.87109, 0.86328, 0.85547]),
|
27 |
+
},
|
28 |
+
256: {
|
29 |
+
25: np.array([1.0] + [1.59375, 1.10156, 1.08594, 1.05469, 1.03906, 1.03125, 1.03125, 1.02344, 1.01562, 1.02344, 0.98438, 1.0625, 0.96875, 1.00781, 0.98438, 1.00781, 0.92969, 0.97656, 0.99609, 0.91406, 0.94922, 0.88672, 0.86328, 0.75391]),
|
30 |
+
50: np.array([1.0] + [1.46875, 1.10156, 1.04688, 1.03906, 1.02344, 1.0625, 1.03125, 1.02344, 1.03906, 1.0, 1.01562, 1.01562, 1.03125, 0.99609, 1.01562, 1.00781, 0.99609, 1.02344, 1.01562, 1.00781, 1.00781, 0.98047, 1.02344, 1.04688, 0.97266, 0.99609, 1.0, 1.00781, 0.98047, 1.00781, 0.98047, 1.02344, 0.96094, 0.96875, 1.03125, 0.94531, 0.98047, 1.01562, 0.96484, 0.94531, 0.99609, 0.95312, 0.96484, 0.91406, 0.92969, 0.92969, 0.88672, 0.85156, 0.89062]),
|
31 |
+
75: np.array([1.0] + [1.25781, 1.23438, 1.04688, 1.03906, 1.0, 1.03906, 1.02344, 1.00781, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 0.99219, 1.00781, 1.0, 1.03906, 0.99609, 1.03125, 0.98828, 1.00781, 1.00781, 1.02344, 0.99219, 1.00781, 1.00781, 1.0, 0.99609, 1.03125, 0.99609, 1.01562, 0.99609, 0.97656, 1.03906, 0.98438, 1.03906, 0.96484, 1.01562, 0.98438, 1.02344, 0.95312, 1.03906, 0.98047, 1.02344, 0.98047, 1.0, 0.97656, 1.03906, 0.94922, 1.01562, 0.97266, 1.01562, 0.98828, 0.97266, 1.01562, 0.97656, 1.00781, 0.9375, 1.0, 0.97266, 1.00781, 0.96875, 0.97266, 0.96875, 0.93359, 0.98047, 0.95703, 0.94922, 0.94922, 0.92578, 0.94531, 0.85938, 0.91797, 0.94531]),
|
32 |
+
},
|
33 |
+
128: {
|
34 |
+
25: np.array([1.0] + [1.63281, 1.0625, 1.14062, 1.04688, 1.03906, 1.03125, 1.03125, 1.02344, 0.99219, 1.03125, 0.96484, 1.02344, 0.97266, 0.98438, 0.97656, 0.96875, 0.95312, 0.95312, 0.97656, 0.92188, 0.9375, 0.87891, 0.85156, 0.82812]),
|
35 |
+
50: np.array([1.0] + [1.5625, 1.05469, 1.03906, 1.03125, 1.07031, 1.03906, 1.05469, 0.99609, 1.03906, 1.0, 1.05469, 0.97656, 1.01562, 1.01562, 1.0, 1.03125, 1.01562, 0.97656, 0.99609, 1.03906, 0.98828, 0.98047, 1.03125, 0.99609, 0.97266, 1.0, 0.99609, 0.98438, 0.97266, 1.00781, 0.98828, 0.98438, 0.97656, 0.98047, 0.99609, 0.95703, 1.00781, 0.96484, 0.99219, 0.92969, 0.98828, 0.94922, 0.94531, 0.92969, 0.92578, 0.92188, 0.91797, 0.90625, 1.00781]),
|
36 |
+
75: np.array([1.0] + [1.47656, 1.08594, 1.02344, 1.0625, 1.0, 1.02344, 1.0625, 1.03906, 1.01562, 1.0, 1.04688, 1.0, 1.01562, 1.00781, 1.01562, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 1.00781, 1.01562, 0.97656, 1.00781, 0.98438, 1.03906, 0.97656, 1.01562, 0.94531, 1.03125, 1.00781, 0.98438, 1.02344, 0.97656, 1.00781, 0.95703, 1.01562, 0.97656, 1.02344, 0.97266, 0.96484, 1.02344, 0.96094, 0.99609, 0.99609, 0.96094, 1.0, 1.00781, 0.97266, 0.98828, 0.96875, 0.96484, 0.98828, 0.95703, 0.99219, 0.97266, 0.89844, 1.0, 0.96094, 0.92578, 0.95703, 0.9375, 0.91016, 0.97266, 0.96875, 1.07812]),
|
37 |
+
},
|
38 |
+
},
|
39 |
+
"F1": {
|
40 |
+
768: {
|
41 |
+
25: np.array([1.0] + [1.27344, 1.08594, 1.03125, 1.00781, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97266, 0.92188, 0.95703, 0.77734]),
|
42 |
+
50: np.array([1.0] + [1.27344, 1.0, 1.07031, 1.01562, 1.0, 1.02344, 1.0, 1.01562, 1.02344, 0.98828, 1.00781, 1.00781, 1.0, 1.02344, 1.00781, 1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76562]),
|
43 |
+
75: np.array([1.0] + [1.0, 1.26562, 1.00781, 1.07812, 1.0, 1.00781, 1.01562, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.03125, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98828, 1.01562, 1.03125, 0.97266, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96875, 0.96484, 0.95312, 0.89844, 0.90234, 0.86719, 0.76562]),
|
44 |
+
},
|
45 |
+
640: {
|
46 |
+
25: np.array([1.0] + [1.27344, 1.07031, 1.03906, 1.00781, 1.00781, 1.00781, 1.03125, 1.04688, 1.00781, 1.03125, 0.99219, 1.01562, 1.01562, 1.01562, 1.00781, 0.98438, 1.05469, 0.98438, 0.96875, 1.03125, 0.97266, 0.92578, 0.95703, 0.77734]),
|
47 |
+
50: np.array([1.0] + [1.27344, 1.0, 1.07812, 1.01562, 1.0, 1.01562, 1.00781, 1.00781, 1.02344, 0.98828, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.98047, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.01562, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76953]),
|
48 |
+
75: np.array([1.0] + [1.0, 1.27344, 1.0, 1.07031, 1.01562, 0.99609, 1.00781, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 1.0, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96484, 0.96484, 0.95312, 0.89844, 0.90234, 0.87109, 0.76953]),
|
49 |
+
},
|
50 |
+
512: {
|
51 |
+
25: np.array([1.0] + [1.28125, 1.08594, 1.02344, 1.01562, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97656, 0.92188, 0.96094, 0.77734]),
|
52 |
+
50: np.array([1.0] + [1.28125, 1.00781, 1.08594, 1.0, 1.01562, 1.01562, 1.00781, 1.00781, 1.02344, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.94922, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.83984, 0.76953]),
|
53 |
+
75: np.array([1.0] + [1.00781, 1.27344, 1.00781, 1.07812, 1.0, 1.00781, 1.00781, 0.99609, 1.01562, 0.99609, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.98828, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.00781, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96484, 1.0625, 0.98828, 1.0, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.95703, 1.00781, 0.96484, 1.02344, 0.98828, 0.99609, 0.95703, 0.96875, 1.03125, 0.96875, 0.96484, 0.95703, 0.89844, 0.90234, 0.87109, 0.76953]),
|
54 |
+
},
|
55 |
+
384: {
|
56 |
+
25: np.array([1.0] + [1.36719, 1.03125, 1.02344, 1.01562, 1.04688, 1.03125, 1.04688, 1.02344, 1.00781, 0.99609, 1.01562, 0.99219, 1.00781, 0.97266, 1.07812, 0.95703, 0.9375, 1.04688, 0.98828, 0.89844, 1.00781, 0.92188, 0.89844, 0.72656]),
|
57 |
+
50: np.array([1.0] + [1.27344, 1.08594, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.01562, 0.98047, 1.04688, 0.98438, 1.02344, 1.02344, 0.97656, 1.03125, 0.98828, 1.00781, 0.98828, 1.00781, 1.0, 0.99609, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.97266, 0.96094, 1.09375, 0.95703, 0.98438, 1.0, 0.96094, 0.93359, 1.03125, 0.97266, 1.0, 0.92188, 0.93359, 0.96484, 0.85156, 0.84375, 0.79688]),
|
58 |
+
75: np.array([1.0] + [1.26562, 1.08594, 1.00781, 1.00781, 1.01562, 1.00781, 1.00781, 1.03125, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.00781, 1.02344, 0.98828, 1.01562, 1.03125, 1.0, 1.01562, 0.98047, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.0, 1.02344, 1.00781, 1.0, 1.00781, 0.97266, 1.01562, 1.00781, 0.97656, 1.0625, 0.97656, 1.02344, 0.98828, 0.96484, 1.02344, 1.01562, 1.04688, 0.95312, 1.03125, 0.98047, 1.01562, 0.97266, 0.94922, 1.0625, 0.96484, 1.02344, 0.98438, 1.02344, 0.98047, 0.97266, 0.99219, 0.92969, 1.07031, 0.96094, 0.98047, 0.98438, 0.94531, 0.98828, 0.9375, 0.97266, 0.98828, 0.86719, 0.98047, 0.84766, 0.86328, 0.86719]),
|
59 |
+
},
|
60 |
+
256: {
|
61 |
+
25: np.array([1.0] + [1.38281, 1.04688, 1.05469, 1.03906, 1.03906, 1.01562, 1.0, 1.03906, 1.00781, 1.03125, 0.96094, 1.08594, 0.96094, 1.00781, 0.98438, 1.02344, 0.91016, 0.99609, 1.0, 0.90234, 0.97266, 0.87109, 0.85547, 0.71875]),
|
62 |
+
50: np.array([1.0] + [1.375, 1.02344, 1.02344, 1.01562, 1.00781, 1.04688, 1.03125, 1.00781, 1.03125, 1.00781, 1.0, 1.01562, 1.01562, 0.98438, 1.03125, 1.00781, 0.98438, 1.02344, 1.01562, 1.02344, 0.98047, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.99219, 1.01562, 0.99609, 0.98828, 0.99219, 1.03125, 0.96875, 0.94141, 1.05469, 0.94531, 0.95312, 1.04688, 0.94141, 0.96094, 1.00781, 0.96094, 0.95312, 0.91016, 0.91797, 0.93359, 0.89062, 0.8125, 0.83984]),
|
63 |
+
75: np.array([1.0] + [1.27344, 1.09375, 1.00781, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.0, 1.00781, 1.0, 1.03125, 0.98828, 1.02344, 0.97266, 1.0, 1.02344, 1.03125, 0.98047, 0.99609, 1.02344, 0.99219, 0.98047, 1.0625, 0.99219, 1.00781, 0.98828, 0.96875, 1.0625, 0.98047, 1.04688, 0.95312, 1.03125, 0.98047, 1.03125, 0.94922, 1.03125, 0.99609, 1.03125, 0.95703, 1.0, 0.98438, 1.03906, 0.9375, 1.01562, 0.96094, 1.03125, 0.98828, 0.98047, 1.00781, 0.99219, 1.0, 0.91016, 1.01562, 0.97266, 1.00781, 0.98047, 0.98438, 0.97266, 0.89062, 1.00781, 0.95703, 0.95312, 0.94141, 0.9375, 0.93359, 0.82031, 0.91016, 0.87891]),
|
64 |
+
},
|
65 |
+
128: {
|
66 |
+
25: np.array([1.0] + [1.42188, 1.03125, 1.0625, 1.04688, 1.02344, 1.03125, 1.03125, 1.02344, 0.99219, 1.02344, 0.94531, 1.04688, 0.94922, 1.01562, 0.98047, 0.96484, 0.93359, 0.96484, 0.98438, 0.91406, 0.97266, 0.87891, 0.85547, 0.83203]),
|
67 |
+
50: np.array([1.0] + [1.375, 1.03125, 1.02344, 1.01562, 1.04688, 1.01562, 1.04688, 1.0, 1.04688, 0.98438, 1.05469, 0.97266, 1.01562, 1.01562, 0.98828, 1.03906, 1.01562, 0.97656, 0.98047, 1.03906, 0.97266, 0.97266, 1.0625, 0.99219, 0.97656, 0.97266, 1.02344, 0.99609, 0.94141, 1.03906, 0.97266, 0.99219, 0.96875, 0.96484, 1.00781, 0.95312, 1.03906, 0.94922, 0.99609, 0.91797, 1.00781, 0.96484, 0.9375, 0.93359, 0.91797, 0.92969, 0.91406, 0.90625, 0.92188]),
|
68 |
+
75: np.array([1.0] + [1.36719, 1.02344, 1.01562, 1.03906, 0.99219, 1.00781, 1.03906, 1.03125, 0.99219, 0.99609, 1.05469, 0.99609, 1.01562, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.01562, 1.00781, 1.0, 0.96875, 1.05469, 0.99219, 1.00781, 1.0, 1.0, 1.01562, 1.00781, 0.96875, 1.02344, 0.94922, 1.07812, 0.94922, 1.03125, 0.92578, 1.05469, 0.97266, 0.98438, 1.04688, 0.98438, 1.00781, 0.95312, 1.02344, 0.94922, 1.04688, 0.96875, 0.96484, 1.03906, 0.93359, 1.00781, 1.00781, 0.95312, 1.00781, 1.0, 0.97656, 1.0, 0.94922, 0.96094, 1.02344, 0.92969, 1.02344, 0.96094, 0.88281, 1.03125, 0.94141, 0.91797, 0.98438, 0.92578, 0.90234, 0.99219, 0.92188, 0.98438]),
|
69 |
+
},
|
70 |
+
}
|
71 |
+
}
|
diffusers_helper/pipelines/k_diffusion_hunyuan.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
|
5 |
+
from diffusers_helper.k_diffusion.wrapper import fm_wrapper
|
6 |
+
from diffusers_helper.utils import repeat_to_batch_size
|
7 |
+
|
8 |
+
|
9 |
+
def flux_time_shift(t, mu=1.15, sigma=1.0):
|
10 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
11 |
+
|
12 |
+
|
13 |
+
def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
|
14 |
+
k = (y2 - y1) / (x2 - x1)
|
15 |
+
b = y1 - k * x1
|
16 |
+
mu = k * context_length + b
|
17 |
+
mu = min(mu, math.log(exp_max))
|
18 |
+
return mu
|
19 |
+
|
20 |
+
|
21 |
+
def get_flux_sigmas_from_mu(n, mu):
|
22 |
+
sigmas = torch.linspace(1, 0, steps=n + 1)
|
23 |
+
sigmas = flux_time_shift(sigmas, mu=mu)
|
24 |
+
return sigmas
|
25 |
+
|
26 |
+
|
27 |
+
@torch.inference_mode()
|
28 |
+
def sample_hunyuan(
|
29 |
+
transformer,
|
30 |
+
sampler='unipc',
|
31 |
+
initial_latent=None,
|
32 |
+
concat_latent=None,
|
33 |
+
strength=1.0,
|
34 |
+
width=512,
|
35 |
+
height=512,
|
36 |
+
frames=16,
|
37 |
+
real_guidance_scale=1.0,
|
38 |
+
distilled_guidance_scale=6.0,
|
39 |
+
guidance_rescale=0.0,
|
40 |
+
shift=None,
|
41 |
+
num_inference_steps=25,
|
42 |
+
batch_size=None,
|
43 |
+
generator=None,
|
44 |
+
prompt_embeds=None,
|
45 |
+
prompt_embeds_mask=None,
|
46 |
+
prompt_poolers=None,
|
47 |
+
negative_prompt_embeds=None,
|
48 |
+
negative_prompt_embeds_mask=None,
|
49 |
+
negative_prompt_poolers=None,
|
50 |
+
dtype=torch.bfloat16,
|
51 |
+
device=None,
|
52 |
+
negative_kwargs=None,
|
53 |
+
callback=None,
|
54 |
+
**kwargs,
|
55 |
+
):
|
56 |
+
device = device or transformer.device
|
57 |
+
|
58 |
+
if batch_size is None:
|
59 |
+
batch_size = int(prompt_embeds.shape[0])
|
60 |
+
|
61 |
+
latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
|
62 |
+
|
63 |
+
B, C, T, H, W = latents.shape
|
64 |
+
seq_length = T * H * W // 4
|
65 |
+
|
66 |
+
if shift is None:
|
67 |
+
mu = calculate_flux_mu(seq_length, exp_max=7.0)
|
68 |
+
else:
|
69 |
+
mu = math.log(shift)
|
70 |
+
|
71 |
+
sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
|
72 |
+
|
73 |
+
k_model = fm_wrapper(transformer)
|
74 |
+
|
75 |
+
if initial_latent is not None:
|
76 |
+
sigmas = sigmas * strength
|
77 |
+
first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
|
78 |
+
initial_latent = initial_latent.to(device=device, dtype=torch.float32)
|
79 |
+
latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
|
80 |
+
|
81 |
+
if concat_latent is not None:
|
82 |
+
concat_latent = concat_latent.to(latents)
|
83 |
+
|
84 |
+
distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
|
85 |
+
|
86 |
+
prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
|
87 |
+
prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
|
88 |
+
prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
|
89 |
+
negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
|
90 |
+
negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
|
91 |
+
negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
|
92 |
+
concat_latent = repeat_to_batch_size(concat_latent, batch_size)
|
93 |
+
|
94 |
+
sampler_kwargs = dict(
|
95 |
+
dtype=dtype,
|
96 |
+
cfg_scale=real_guidance_scale,
|
97 |
+
cfg_rescale=guidance_rescale,
|
98 |
+
concat_latent=concat_latent,
|
99 |
+
positive=dict(
|
100 |
+
pooled_projections=prompt_poolers,
|
101 |
+
encoder_hidden_states=prompt_embeds,
|
102 |
+
encoder_attention_mask=prompt_embeds_mask,
|
103 |
+
guidance=distilled_guidance,
|
104 |
+
**kwargs,
|
105 |
+
),
|
106 |
+
negative=dict(
|
107 |
+
pooled_projections=negative_prompt_poolers,
|
108 |
+
encoder_hidden_states=negative_prompt_embeds,
|
109 |
+
encoder_attention_mask=negative_prompt_embeds_mask,
|
110 |
+
guidance=distilled_guidance,
|
111 |
+
**(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
|
112 |
+
)
|
113 |
+
)
|
114 |
+
|
115 |
+
if sampler == 'unipc':
|
116 |
+
results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
|
117 |
+
else:
|
118 |
+
raise NotImplementedError(f'Sampler {sampler} is not supported.')
|
119 |
+
|
120 |
+
return results
|
diffusers_helper/thread_utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
from threading import Thread, Lock
|
4 |
+
|
5 |
+
|
6 |
+
class Listener:
|
7 |
+
task_queue = []
|
8 |
+
lock = Lock()
|
9 |
+
thread = None
|
10 |
+
|
11 |
+
@classmethod
|
12 |
+
def _process_tasks(cls):
|
13 |
+
while True:
|
14 |
+
task = None
|
15 |
+
with cls.lock:
|
16 |
+
if cls.task_queue:
|
17 |
+
task = cls.task_queue.pop(0)
|
18 |
+
|
19 |
+
if task is None:
|
20 |
+
time.sleep(0.001)
|
21 |
+
continue
|
22 |
+
|
23 |
+
func, args, kwargs = task
|
24 |
+
try:
|
25 |
+
func(*args, **kwargs)
|
26 |
+
except Exception as e:
|
27 |
+
print(f"Error in listener thread: {e}")
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def add_task(cls, func, *args, **kwargs):
|
31 |
+
with cls.lock:
|
32 |
+
cls.task_queue.append((func, args, kwargs))
|
33 |
+
|
34 |
+
if cls.thread is None:
|
35 |
+
cls.thread = Thread(target=cls._process_tasks, daemon=True)
|
36 |
+
cls.thread.start()
|
37 |
+
|
38 |
+
|
39 |
+
def async_run(func, *args, **kwargs):
|
40 |
+
Listener.add_task(func, *args, **kwargs)
|
41 |
+
|
42 |
+
|
43 |
+
class FIFOQueue:
|
44 |
+
def __init__(self):
|
45 |
+
self.queue = []
|
46 |
+
self.lock = Lock()
|
47 |
+
|
48 |
+
def push(self, item):
|
49 |
+
with self.lock:
|
50 |
+
self.queue.append(item)
|
51 |
+
|
52 |
+
def pop(self):
|
53 |
+
with self.lock:
|
54 |
+
if self.queue:
|
55 |
+
return self.queue.pop(0)
|
56 |
+
return None
|
57 |
+
|
58 |
+
def top(self):
|
59 |
+
with self.lock:
|
60 |
+
if self.queue:
|
61 |
+
return self.queue[0]
|
62 |
+
return None
|
63 |
+
|
64 |
+
def next(self):
|
65 |
+
while True:
|
66 |
+
with self.lock:
|
67 |
+
if self.queue:
|
68 |
+
return self.queue.pop(0)
|
69 |
+
|
70 |
+
time.sleep(0.001)
|
71 |
+
|
72 |
+
|
73 |
+
class AsyncStream:
|
74 |
+
def __init__(self):
|
75 |
+
self.input_queue = FIFOQueue()
|
76 |
+
self.output_queue = FIFOQueue()
|
diffusers_helper/utils.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import glob
|
6 |
+
import torch
|
7 |
+
import einops
|
8 |
+
import numpy as np
|
9 |
+
import datetime
|
10 |
+
import torchvision
|
11 |
+
|
12 |
+
import safetensors.torch as sf
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
|
16 |
+
def min_resize(x, m):
|
17 |
+
if x.shape[0] < x.shape[1]:
|
18 |
+
s0 = m
|
19 |
+
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
|
20 |
+
else:
|
21 |
+
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
|
22 |
+
s1 = m
|
23 |
+
new_max = max(s1, s0)
|
24 |
+
raw_max = max(x.shape[0], x.shape[1])
|
25 |
+
if new_max < raw_max:
|
26 |
+
interpolation = cv2.INTER_AREA
|
27 |
+
else:
|
28 |
+
interpolation = cv2.INTER_LANCZOS4
|
29 |
+
y = cv2.resize(x, (s1, s0), interpolation=interpolation)
|
30 |
+
return y
|
31 |
+
|
32 |
+
|
33 |
+
def d_resize(x, y):
|
34 |
+
H, W, C = y.shape
|
35 |
+
new_min = min(H, W)
|
36 |
+
raw_min = min(x.shape[0], x.shape[1])
|
37 |
+
if new_min < raw_min:
|
38 |
+
interpolation = cv2.INTER_AREA
|
39 |
+
else:
|
40 |
+
interpolation = cv2.INTER_LANCZOS4
|
41 |
+
y = cv2.resize(x, (W, H), interpolation=interpolation)
|
42 |
+
return y
|
43 |
+
|
44 |
+
|
45 |
+
def resize_and_center_crop(image, target_width, target_height):
|
46 |
+
if target_height == image.shape[0] and target_width == image.shape[1]:
|
47 |
+
return image
|
48 |
+
|
49 |
+
pil_image = Image.fromarray(image)
|
50 |
+
original_width, original_height = pil_image.size
|
51 |
+
scale_factor = max(target_width / original_width, target_height / original_height)
|
52 |
+
resized_width = int(round(original_width * scale_factor))
|
53 |
+
resized_height = int(round(original_height * scale_factor))
|
54 |
+
resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
|
55 |
+
left = (resized_width - target_width) / 2
|
56 |
+
top = (resized_height - target_height) / 2
|
57 |
+
right = (resized_width + target_width) / 2
|
58 |
+
bottom = (resized_height + target_height) / 2
|
59 |
+
cropped_image = resized_image.crop((left, top, right, bottom))
|
60 |
+
return np.array(cropped_image)
|
61 |
+
|
62 |
+
|
63 |
+
def resize_and_center_crop_pytorch(image, target_width, target_height):
|
64 |
+
B, C, H, W = image.shape
|
65 |
+
|
66 |
+
if H == target_height and W == target_width:
|
67 |
+
return image
|
68 |
+
|
69 |
+
scale_factor = max(target_width / W, target_height / H)
|
70 |
+
resized_width = int(round(W * scale_factor))
|
71 |
+
resized_height = int(round(H * scale_factor))
|
72 |
+
|
73 |
+
resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
|
74 |
+
|
75 |
+
top = (resized_height - target_height) // 2
|
76 |
+
left = (resized_width - target_width) // 2
|
77 |
+
cropped = resized[:, :, top:top + target_height, left:left + target_width]
|
78 |
+
|
79 |
+
return cropped
|
80 |
+
|
81 |
+
|
82 |
+
def resize_without_crop(image, target_width, target_height):
|
83 |
+
if target_height == image.shape[0] and target_width == image.shape[1]:
|
84 |
+
return image
|
85 |
+
|
86 |
+
pil_image = Image.fromarray(image)
|
87 |
+
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
88 |
+
return np.array(resized_image)
|
89 |
+
|
90 |
+
|
91 |
+
def just_crop(image, w, h):
|
92 |
+
if h == image.shape[0] and w == image.shape[1]:
|
93 |
+
return image
|
94 |
+
|
95 |
+
original_height, original_width = image.shape[:2]
|
96 |
+
k = min(original_height / h, original_width / w)
|
97 |
+
new_width = int(round(w * k))
|
98 |
+
new_height = int(round(h * k))
|
99 |
+
x_start = (original_width - new_width) // 2
|
100 |
+
y_start = (original_height - new_height) // 2
|
101 |
+
cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
|
102 |
+
return cropped_image
|
103 |
+
|
104 |
+
|
105 |
+
def write_to_json(data, file_path):
|
106 |
+
temp_file_path = file_path + ".tmp"
|
107 |
+
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
|
108 |
+
json.dump(data, temp_file, indent=4)
|
109 |
+
os.replace(temp_file_path, file_path)
|
110 |
+
return
|
111 |
+
|
112 |
+
|
113 |
+
def read_from_json(file_path):
|
114 |
+
with open(file_path, 'rt', encoding='utf-8') as file:
|
115 |
+
data = json.load(file)
|
116 |
+
return data
|
117 |
+
|
118 |
+
|
119 |
+
def get_active_parameters(m):
|
120 |
+
return {k: v for k, v in m.named_parameters() if v.requires_grad}
|
121 |
+
|
122 |
+
|
123 |
+
def cast_training_params(m, dtype=torch.float32):
|
124 |
+
result = {}
|
125 |
+
for n, param in m.named_parameters():
|
126 |
+
if param.requires_grad:
|
127 |
+
param.data = param.to(dtype)
|
128 |
+
result[n] = param
|
129 |
+
return result
|
130 |
+
|
131 |
+
|
132 |
+
def separate_lora_AB(parameters, B_patterns=None):
|
133 |
+
parameters_normal = {}
|
134 |
+
parameters_B = {}
|
135 |
+
|
136 |
+
if B_patterns is None:
|
137 |
+
B_patterns = ['.lora_B.', '__zero__']
|
138 |
+
|
139 |
+
for k, v in parameters.items():
|
140 |
+
if any(B_pattern in k for B_pattern in B_patterns):
|
141 |
+
parameters_B[k] = v
|
142 |
+
else:
|
143 |
+
parameters_normal[k] = v
|
144 |
+
|
145 |
+
return parameters_normal, parameters_B
|
146 |
+
|
147 |
+
|
148 |
+
def set_attr_recursive(obj, attr, value):
|
149 |
+
attrs = attr.split(".")
|
150 |
+
for name in attrs[:-1]:
|
151 |
+
obj = getattr(obj, name)
|
152 |
+
setattr(obj, attrs[-1], value)
|
153 |
+
return
|
154 |
+
|
155 |
+
|
156 |
+
def print_tensor_list_size(tensors):
|
157 |
+
total_size = 0
|
158 |
+
total_elements = 0
|
159 |
+
|
160 |
+
if isinstance(tensors, dict):
|
161 |
+
tensors = tensors.values()
|
162 |
+
|
163 |
+
for tensor in tensors:
|
164 |
+
total_size += tensor.nelement() * tensor.element_size()
|
165 |
+
total_elements += tensor.nelement()
|
166 |
+
|
167 |
+
total_size_MB = total_size / (1024 ** 2)
|
168 |
+
total_elements_B = total_elements / 1e9
|
169 |
+
|
170 |
+
print(f"Total number of tensors: {len(tensors)}")
|
171 |
+
print(f"Total size of tensors: {total_size_MB:.2f} MB")
|
172 |
+
print(f"Total number of parameters: {total_elements_B:.3f} billion")
|
173 |
+
return
|
174 |
+
|
175 |
+
|
176 |
+
@torch.no_grad()
|
177 |
+
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
|
178 |
+
batch_size = a.size(0)
|
179 |
+
|
180 |
+
if b is None:
|
181 |
+
b = torch.zeros_like(a)
|
182 |
+
|
183 |
+
if mask_a is None:
|
184 |
+
mask_a = torch.rand(batch_size) < probability_a
|
185 |
+
|
186 |
+
mask_a = mask_a.to(a.device)
|
187 |
+
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
|
188 |
+
result = torch.where(mask_a, a, b)
|
189 |
+
return result
|
190 |
+
|
191 |
+
|
192 |
+
@torch.no_grad()
|
193 |
+
def zero_module(module):
|
194 |
+
for p in module.parameters():
|
195 |
+
p.detach().zero_()
|
196 |
+
return module
|
197 |
+
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def supress_lower_channels(m, k, alpha=0.01):
|
201 |
+
data = m.weight.data.clone()
|
202 |
+
|
203 |
+
assert int(data.shape[1]) >= k
|
204 |
+
|
205 |
+
data[:, :k] = data[:, :k] * alpha
|
206 |
+
m.weight.data = data.contiguous().clone()
|
207 |
+
return m
|
208 |
+
|
209 |
+
|
210 |
+
def freeze_module(m):
|
211 |
+
if not hasattr(m, '_forward_inside_frozen_module'):
|
212 |
+
m._forward_inside_frozen_module = m.forward
|
213 |
+
m.requires_grad_(False)
|
214 |
+
m.forward = torch.no_grad()(m.forward)
|
215 |
+
return m
|
216 |
+
|
217 |
+
|
218 |
+
def get_latest_safetensors(folder_path):
|
219 |
+
safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
|
220 |
+
|
221 |
+
if not safetensors_files:
|
222 |
+
raise ValueError('No file to resume!')
|
223 |
+
|
224 |
+
latest_file = max(safetensors_files, key=os.path.getmtime)
|
225 |
+
latest_file = os.path.abspath(os.path.realpath(latest_file))
|
226 |
+
return latest_file
|
227 |
+
|
228 |
+
|
229 |
+
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
|
230 |
+
tags = tags_str.split(', ')
|
231 |
+
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
|
232 |
+
prompt = ', '.join(tags)
|
233 |
+
return prompt
|
234 |
+
|
235 |
+
|
236 |
+
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
|
237 |
+
numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
|
238 |
+
if round_to_int:
|
239 |
+
numbers = np.round(numbers).astype(int)
|
240 |
+
return numbers.tolist()
|
241 |
+
|
242 |
+
|
243 |
+
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
|
244 |
+
edges = np.linspace(0, 1, n + 1)
|
245 |
+
points = np.random.uniform(edges[:-1], edges[1:])
|
246 |
+
numbers = inclusive + (exclusive - inclusive) * points
|
247 |
+
if round_to_int:
|
248 |
+
numbers = np.round(numbers).astype(int)
|
249 |
+
return numbers.tolist()
|
250 |
+
|
251 |
+
|
252 |
+
def soft_append_bcthw(history, current, overlap=0):
|
253 |
+
if overlap <= 0:
|
254 |
+
return torch.cat([history, current], dim=2)
|
255 |
+
|
256 |
+
assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
|
257 |
+
assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
|
258 |
+
|
259 |
+
weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
|
260 |
+
blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
|
261 |
+
output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
|
262 |
+
|
263 |
+
return output.to(history)
|
264 |
+
|
265 |
+
|
266 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
|
267 |
+
b, c, t, h, w = x.shape
|
268 |
+
|
269 |
+
per_row = b
|
270 |
+
for p in [6, 5, 4, 3, 2]:
|
271 |
+
if b % p == 0:
|
272 |
+
per_row = p
|
273 |
+
break
|
274 |
+
|
275 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
276 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
277 |
+
x = x.detach().cpu().to(torch.uint8)
|
278 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
279 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
|
280 |
+
return x
|
281 |
+
|
282 |
+
|
283 |
+
def save_bcthw_as_png(x, output_filename):
|
284 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
285 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
286 |
+
x = x.detach().cpu().to(torch.uint8)
|
287 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
288 |
+
torchvision.io.write_png(x, output_filename)
|
289 |
+
return output_filename
|
290 |
+
|
291 |
+
|
292 |
+
def save_bchw_as_png(x, output_filename):
|
293 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
294 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
295 |
+
x = x.detach().cpu().to(torch.uint8)
|
296 |
+
x = einops.rearrange(x, 'b c h w -> c h (b w)')
|
297 |
+
torchvision.io.write_png(x, output_filename)
|
298 |
+
return output_filename
|
299 |
+
|
300 |
+
|
301 |
+
def add_tensors_with_padding(tensor1, tensor2):
|
302 |
+
if tensor1.shape == tensor2.shape:
|
303 |
+
return tensor1 + tensor2
|
304 |
+
|
305 |
+
shape1 = tensor1.shape
|
306 |
+
shape2 = tensor2.shape
|
307 |
+
|
308 |
+
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
|
309 |
+
|
310 |
+
padded_tensor1 = torch.zeros(new_shape)
|
311 |
+
padded_tensor2 = torch.zeros(new_shape)
|
312 |
+
|
313 |
+
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
|
314 |
+
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
|
315 |
+
|
316 |
+
result = padded_tensor1 + padded_tensor2
|
317 |
+
return result
|
318 |
+
|
319 |
+
|
320 |
+
def print_free_mem():
|
321 |
+
torch.cuda.empty_cache()
|
322 |
+
free_mem, total_mem = torch.cuda.mem_get_info(0)
|
323 |
+
free_mem_mb = free_mem / (1024 ** 2)
|
324 |
+
total_mem_mb = total_mem / (1024 ** 2)
|
325 |
+
print(f"Free memory: {free_mem_mb:.2f} MB")
|
326 |
+
print(f"Total memory: {total_mem_mb:.2f} MB")
|
327 |
+
return
|
328 |
+
|
329 |
+
|
330 |
+
def print_gpu_parameters(device, state_dict, log_count=1):
|
331 |
+
summary = {"device": device, "keys_count": len(state_dict)}
|
332 |
+
|
333 |
+
logged_params = {}
|
334 |
+
for i, (key, tensor) in enumerate(state_dict.items()):
|
335 |
+
if i >= log_count:
|
336 |
+
break
|
337 |
+
logged_params[key] = tensor.flatten()[:3].tolist()
|
338 |
+
|
339 |
+
summary["params"] = logged_params
|
340 |
+
|
341 |
+
print(str(summary))
|
342 |
+
return
|
343 |
+
|
344 |
+
|
345 |
+
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
|
346 |
+
from PIL import Image, ImageDraw, ImageFont
|
347 |
+
|
348 |
+
txt = Image.new("RGB", (width, height), color="white")
|
349 |
+
draw = ImageDraw.Draw(txt)
|
350 |
+
font = ImageFont.truetype(font_path, size=size)
|
351 |
+
|
352 |
+
if text == '':
|
353 |
+
return np.array(txt)
|
354 |
+
|
355 |
+
# Split text into lines that fit within the image width
|
356 |
+
lines = []
|
357 |
+
words = text.split()
|
358 |
+
current_line = words[0]
|
359 |
+
|
360 |
+
for word in words[1:]:
|
361 |
+
line_with_word = f"{current_line} {word}"
|
362 |
+
if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
|
363 |
+
current_line = line_with_word
|
364 |
+
else:
|
365 |
+
lines.append(current_line)
|
366 |
+
current_line = word
|
367 |
+
|
368 |
+
lines.append(current_line)
|
369 |
+
|
370 |
+
# Draw the text line by line
|
371 |
+
y = 0
|
372 |
+
line_height = draw.textbbox((0, 0), "A", font=font)[3]
|
373 |
+
|
374 |
+
for line in lines:
|
375 |
+
if y + line_height > height:
|
376 |
+
break # stop drawing if the next line will be outside the image
|
377 |
+
draw.text((0, y), line, fill="black", font=font)
|
378 |
+
y += line_height
|
379 |
+
|
380 |
+
return np.array(txt)
|
381 |
+
|
382 |
+
|
383 |
+
def blue_mark(x):
|
384 |
+
x = x.copy()
|
385 |
+
c = x[:, :, 2]
|
386 |
+
b = cv2.blur(c, (9, 9))
|
387 |
+
x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
|
388 |
+
return x
|
389 |
+
|
390 |
+
|
391 |
+
def green_mark(x):
|
392 |
+
x = x.copy()
|
393 |
+
x[:, :, 2] = -1
|
394 |
+
x[:, :, 0] = -1
|
395 |
+
return x
|
396 |
+
|
397 |
+
|
398 |
+
def frame_mark(x):
|
399 |
+
x = x.copy()
|
400 |
+
x[:64] = -1
|
401 |
+
x[-64:] = -1
|
402 |
+
x[:, :8] = 1
|
403 |
+
x[:, -8:] = 1
|
404 |
+
return x
|
405 |
+
|
406 |
+
|
407 |
+
@torch.inference_mode()
|
408 |
+
def pytorch2numpy(imgs):
|
409 |
+
results = []
|
410 |
+
for x in imgs:
|
411 |
+
y = x.movedim(0, -1)
|
412 |
+
y = y * 127.5 + 127.5
|
413 |
+
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
|
414 |
+
results.append(y)
|
415 |
+
return results
|
416 |
+
|
417 |
+
|
418 |
+
@torch.inference_mode()
|
419 |
+
def numpy2pytorch(imgs):
|
420 |
+
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
|
421 |
+
h = h.movedim(-1, 1)
|
422 |
+
return h
|
423 |
+
|
424 |
+
|
425 |
+
@torch.no_grad()
|
426 |
+
def duplicate_prefix_to_suffix(x, count, zero_out=False):
|
427 |
+
if zero_out:
|
428 |
+
return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
|
429 |
+
else:
|
430 |
+
return torch.cat([x, x[:count]], dim=0)
|
431 |
+
|
432 |
+
|
433 |
+
def weighted_mse(a, b, weight):
|
434 |
+
return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
|
435 |
+
|
436 |
+
|
437 |
+
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
|
438 |
+
x = (x - x_min) / (x_max - x_min)
|
439 |
+
x = max(0.0, min(x, 1.0))
|
440 |
+
x = x ** sigma
|
441 |
+
return y_min + x * (y_max - y_min)
|
442 |
+
|
443 |
+
|
444 |
+
def expand_to_dims(x, target_dims):
    return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))


def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    if tensor is None:
        return None

    first_dim = tensor.shape[0]

    if first_dim == batch_size:
        return tensor

    if batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    repeat_times = batch_size // first_dim

    return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))


def dim5(x):
    return expand_to_dims(x, 5)


def dim4(x):
    return expand_to_dims(x, 4)


def dim3(x):
    return expand_to_dims(x, 3)


def crop_or_pad_yield_mask(x, length):
    B, F, C = x.shape
    device = x.device
    dtype = x.dtype

    if F < length:
        y = torch.zeros((B, length, C), dtype=dtype, device=device)
        mask = torch.zeros((B, length), dtype=torch.bool, device=device)
        y[:, :F, :] = x
        mask[:, :F] = True
        return y, mask

    return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)


def extend_dim(x, dim, minimal_length, zero_pad=False):
    original_length = int(x.shape[dim])

    if original_length >= minimal_length:
        return x

    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = minimal_length - original_length
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
        last_element = x[idx]
        padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)

    return torch.cat([x, padding], dim=dim)


def lazy_positional_encoding(t, repeats=None):
    if not isinstance(t, list):
        t = [t]

    from diffusers.models.embeddings import get_timestep_embedding

    te = torch.tensor(t)
    te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is None:
        return te

    te = te[:, None, :].expand(-1, repeats, -1)

    return te


def state_dict_offset_merge(A, B, C=None):
    result = {}
    keys = A.keys()

    for key in keys:
        A_value = A[key]
        B_value = B[key].to(A_value)

        if C is None:
            result[key] = A_value + B_value
        else:
            C_value = C[key].to(A_value)
            result[key] = A_value + B_value - C_value

    return result


def state_dict_weighted_merge(state_dicts, weights):
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)

    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized_weights = [w / total_weight for w in weights]

    keys = state_dicts[0].keys()
    result = {}

    for key in keys:
        result[key] = state_dicts[0][key] * normalized_weights[0]

        for i in range(1, len(state_dicts)):
            state_dict_value = state_dicts[i][key].to(result[key])
            result[key] += state_dict_value * normalized_weights[i]

    return result


def group_files_by_folder(all_files):
    grouped_files = {}

    for file in all_files:
        folder_name = os.path.basename(os.path.dirname(file))
        if folder_name not in grouped_files:
            grouped_files[folder_name] = []
        grouped_files[folder_name].append(file)

    list_of_lists = list(grouped_files.values())
    return list_of_lists


def generate_timestamp():
    now = datetime.datetime.now()
    timestamp = now.strftime('%y%m%d_%H%M%S')
    milliseconds = f"{int(now.microsecond / 1000):03d}"
    random_number = random.randint(0, 9999)
    return f"{timestamp}_{milliseconds}_{random_number}"


def write_PIL_image_with_png_info(image, metadata, path):
    from PIL.PngImagePlugin import PngInfo

    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)

    image.save(path, "PNG", pnginfo=png_info)
    return image


def torch_safe_save(content, path):
    torch.save(content, path + '_tmp')
    os.replace(path + '_tmp', path)
    return path


def move_optimizer_to_device(optimizer, device):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
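Note: a minimal usage sketch (not part of the diff) exercising the padding and checkpoint-merge helpers added above; it assumes it is run from the repo root so `diffusers_helper.utils` is importable, and the tensor shapes and weights are illustrative.

```python
import torch
from diffusers_helper.utils import crop_or_pad_yield_mask, state_dict_weighted_merge

# Pad a (batch, frames, channels) sequence of 3 frames up to length 8;
# the boolean mask marks which positions hold real data.
x = torch.randn(1, 3, 16)
padded, mask = crop_or_pad_yield_mask(x, length=8)
assert padded.shape == (1, 8, 16) and mask.sum() == 3

# Blend two state dicts 70/30; weights are normalized internally.
sd_a = {"w": torch.ones(2, 2)}
sd_b = {"w": torch.zeros(2, 2)}
merged = state_dict_weighted_merge([sd_a, sd_b], [0.7, 0.3])
assert torch.allclose(merged["w"], torch.full((2, 2), 0.7))
```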
docker-compose.yml
ADDED
@@ -0,0 +1,25 @@
services:
  studio:
    build:
      # modify this if you are building the image locally and need a different CUDA version
      args:
        - CUDA_VERSION=12.4.1
    # modify the tag here if you need a different CUDA version or branch
    image: colinurbs/fp-studio:cuda12.4-latest-develop
    restart: unless-stopped
    ports:
      - "7860:7860"
    volumes:
      - "./loras:/app/loras"
      - "./outputs:/app/outputs"
      - "./.framepack:/app/.framepack"
      - "./modules/toolbox/model_esrgan:/app/modules/toolbox/model_esrgan"
      - "./modules/toolbox/model_rife:/app/modules/toolbox/model_rife"
      - "$HOME/.cache/huggingface:/app/hf_download"
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
install.bat
ADDED
@@ -0,0 +1,208 @@
@echo off
echo FramePack-Studio Setup Script
setlocal enabledelayedexpansion

REM Check if Python is installed (basic check)
where python >nul 2>&1
if %errorlevel% neq 0 (
    echo Error: Python is not installed or not in your PATH. Please install Python and try again.
    goto end
)

if exist "%cd%/venv" (
    echo Virtual Environment already exists.
    set /p choice= "Do you want to reinstall packages?[Y/N]: "

    if "!choice!" == "y" (goto checkgpu)
    if "!choice!"=="Y" (goto checkgpu)

    goto end
)

REM Check the python version
echo Python versions 3.10-3.12 have been confirmed to work. Other versions are currently not supported. You currently have:
python -V
set choice=
set /p choice= "Do you want to continue?[Y/N]: "

if "!choice!" == "y" (goto makevenv)
if "!choice!"=="Y" (goto makevenv)

goto end

:makevenv
REM This creates a virtual environment in the folder
echo Creating a Virtual Environment...
python -m venv venv
echo Upgrading pip in Virtual Environment to lower chance of error...
"%cd%/venv/Scripts/python.exe" -m pip install --upgrade pip

:checkgpu
REM ask Windows for GPU
where nvidia-smi >nul 2>&1
if %errorlevel% neq 0 (
    echo Error: Nvidia GPU doesn't exist or drivers installed incorrectly. Please confirm your drivers are installed.
    goto end
)

echo Checking your GPU...

for /F "tokens=* skip=1" %%n in ('nvidia-smi --query-gpu=name') do set GPU_NAME=%%n && goto gpuchecked

:gpuchecked
echo Detected %GPU_NAME%
set "GPU_SERIES=%GPU_NAME:*RTX =%"
set "GPU_SERIES=%GPU_SERIES:~0,2%00"

REM This gets the shortened Python version for later use. e.g. 3.10.13 becomes 310.
for /f "delims=" %%A in ('python -V') do set "pyv=%%A"
for /f "tokens=2 delims= " %%A in ("%pyv%") do (
    set pyv=%%A
)
set pyv=%pyv:.=%
set pyv=%pyv:~0,3%

echo Installing torch...

if !GPU_SERIES! geq 5000 (
    goto torch270
) else (
    goto torch260
)

REM RTX 5000 Series
:torch270
"%cd%/venv/Scripts/pip.exe" install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --force-reinstall
REM Check if pip installation was successful
if %errorlevel% neq 0 (
    echo Warning: Failed to install dependencies. You may need to install them manually.
    goto end
)

REM Ask if user wants Sage Attention
set choice=
echo Do you want to install any of the following? They speed up generation.
echo 1) Sage Attention
echo 2) Flash Attention
echo 3) BOTH!
echo 4) No
set /p choice= "Input Selection: "

set both="N"

if "!choice!" == "1" (goto triton270)
if "!choice!"== "2" (goto flash270)
if "!choice!"== "3" (set both="Y"
goto triton270
)

goto requirements

:triton270
REM Sage Attention and Triton for Torch 2.7.0
"%cd%/venv/Scripts/pip.exe" install "triton-windows<3.4" --force-reinstall
"%cd%/venv/Scripts/pip.exe" install "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp%pyv%-cp%pyv%-win_amd64.whl" --force-reinstall
echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder...

REM Check for python version and download the triton-windows required libs accordingly
if %pyv% == 310 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')"
)

if %pyv% == 311 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')"
)

if %pyv% == 312 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')"
)

REM Extract the zip into the Python Folder and Delete zip
powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force
del triton-lib.zip
if %both% == "Y" (goto flash270)

goto requirements

:flash270
REM Install flash-attn.
"%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1%%2Bcu128torch2.7.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true"
goto requirements


REM RTX 4000 Series and below
:torch260
"%cd%/venv/Scripts/pip.exe" install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 --force-reinstall
REM Check if pip installation was successful
if %errorlevel% neq 0 (
    echo Warning: Failed to install dependencies. You may need to install them manually.
    goto end
)

REM Ask if user wants Sage Attention
set choice=
echo Do you want to install any of the following? They speed up generation.
echo 1) Sage Attention
echo 2) Flash Attention
echo 3) BOTH!
echo 4) No
set /p choice= "Input Selection: "

set both="N"

if "!choice!" == "1" (goto triton260)
if "!choice!"== "2" (goto flash260)
if "!choice!"== "3" (set both="Y"
goto triton260)

goto requirements

:triton260
REM Sage Attention and Triton for Torch 2.6.0
"%cd%/venv/Scripts/pip.exe" install "triton-windows<3.3.0" --force-reinstall
"%cd%/venv/Scripts/pip.exe" install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp%pyv%-cp%pyv%-win_amd64.whl --force-reinstall

echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder...

REM Check for python version and download the triton-windows required libs accordingly
if %pyv% == 310 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')"
)

if %pyv% == 311 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')"
)

if %pyv% == 312 (
    powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')"
)

REM Extract the zip into the Python Folder and Delete zip
powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force
del triton-lib.zip

if %both% == "Y" (goto flash260)

goto requirements

:flash260
REM Install flash-attn.
"%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%%2Bcu126torch2.6.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true"

:requirements
echo Installing remaining required packages through pip...
REM This assumes there's a requirements.txt file in the root
"%cd%/venv/Scripts/pip.exe" install -r requirements.txt

REM Check if pip installation was successful
if %errorlevel% neq 0 (
    echo Warning: Failed to install dependencies. You may need to install them manually.
    goto end
)

echo Setup complete.

:end
echo Exiting setup script.
pause
modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
# modules/__init__.py

# Workaround for the single lora bug. Must not be an empty string.
DUMMY_LORA_NAME = " "
modules/generators/__init__.py
ADDED
@@ -0,0 +1,32 @@
from .original_generator import OriginalModelGenerator
from .f1_generator import F1ModelGenerator
from .video_generator import VideoModelGenerator
from .video_f1_generator import VideoF1ModelGenerator
from .original_with_endframe_generator import OriginalWithEndframeModelGenerator

def create_model_generator(model_type, **kwargs):
    """
    Create a model generator based on the model type.

    Args:
        model_type: The type of model to create ("Original", "Original with Endframe", "F1", "Video", or "Video F1")
        **kwargs: Additional arguments to pass to the model generator constructor

    Returns:
        A model generator instance

    Raises:
        ValueError: If the model type is not supported
    """
    if model_type == "Original":
        return OriginalModelGenerator(**kwargs)
    elif model_type == "Original with Endframe":
        return OriginalWithEndframeModelGenerator(**kwargs)
    elif model_type == "F1":
        return F1ModelGenerator(**kwargs)
    elif model_type == "Video":
        return VideoModelGenerator(**kwargs)
    elif model_type == "Video F1":
        return VideoF1ModelGenerator(**kwargs)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
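Note: a minimal sketch (not part of the diff) of how this factory is used; passing None for the shared components is only for illustration, since in the app the real encoders, tokenizers, and VAE are loaded elsewhere and handed in.

```python
from modules.generators import create_model_generator

generator = create_model_generator(
    "F1",
    text_encoder=None, text_encoder_2=None,
    tokenizer=None, tokenizer_2=None,
    vae=None, image_encoder=None, feature_extractor=None,
    high_vram=False, offline=False,
)
print(generator.get_model_name())  # "F1"
# generator.load_model() would then fetch lllyasviel/FramePack_F1_I2V_HY_20250503
# (or a local HF cache snapshot when offline=True).
```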
modules/generators/base_generator.py
ADDED
@@ -0,0 +1,281 @@
import torch
import os  # required for os.path
from abc import ABC, abstractmethod
from diffusers_helper import lora_utils
from typing import List, Optional
from pathlib import Path

class BaseModelGenerator(ABC):
    """
    Base class for model generators.
    This defines the common interface that all model generators must implement.
    """

    def __init__(self,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 vae,
                 image_encoder,
                 feature_extractor,
                 high_vram=False,
                 prompt_embedding_cache=None,
                 settings=None,
                 offline=False):  # NEW: offline flag
        """
        Initialize the base model generator.

        Args:
            text_encoder: The text encoder model
            text_encoder_2: The second text encoder model
            tokenizer: The tokenizer for the first text encoder
            tokenizer_2: The tokenizer for the second text encoder
            vae: The VAE model
            image_encoder: The image encoder model
            feature_extractor: The feature extractor
            high_vram: Whether high VRAM mode is enabled
            prompt_embedding_cache: Cache for prompt embeddings
            settings: Application settings
            offline: Whether to run in offline mode for model loading
        """
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.vae = vae
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor
        self.high_vram = high_vram
        self.prompt_embedding_cache = prompt_embedding_cache or {}
        self.settings = settings
        self.offline = offline
        self.transformer = None
        self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cpu = torch.device("cpu")

    @abstractmethod
    def load_model(self):
        """
        Load the transformer model.
        This method should be implemented by each specific model generator.
        """
        pass

    @abstractmethod
    def get_model_name(self):
        """
        Get the name of the model.
        This method should be implemented by each specific model generator.
        """
        pass

    @staticmethod
    def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None:
        """
        Reads the commit hash from the refs/main file for a given model in the HF cache.
        Args:
            model_repo_id_for_cache (str): The model ID formatted for cache directory names
                                           (e.g., "models--lllyasviel--FramePackI2V_HY").
        Returns:
            str: The commit hash if found, otherwise None.
        """
        hf_home_dir = os.environ.get('HF_HOME')
        if not hf_home_dir:
            print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
            return None

        refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
        if os.path.exists(refs_main_path):
            try:
                with open(refs_main_path, 'r') as f:
                    print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
                    return f.read().strip()
            except Exception as e:
                print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
                return None
        else:
            print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
            return None

    def _get_offline_load_path(self) -> str:
        """
        Returns the local snapshot path for offline loading if available.
        Falls back to the default self.model_path if local snapshot can't be found.
        Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses.
        """
        # Ensure necessary attributes are set by the subclass
        if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache:
            print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
            # Fallback to model_path if it exists, otherwise None
            return getattr(self, 'model_path', None)

        if not hasattr(self, 'model_path') or not self.model_path:
            print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
            return None

        snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache)
        hf_home = os.environ.get('HF_HOME')

        if snapshot_hash and hf_home:
            specific_snapshot_path = os.path.join(
                hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash
            )
            if os.path.isdir(specific_snapshot_path):
                return specific_snapshot_path

        # If snapshot logic fails or path is not a dir, fallback to the default model path
        return self.model_path

    def unload_loras(self):
        """
        Unload all LoRAs from the transformer model.
        """
        if self.transformer is not None:
            print(f"Unloading all LoRAs from {self.get_model_name()} model")
            self.transformer = lora_utils.unload_all_loras(self.transformer)
            self.verify_lora_state("After unloading LoRAs")
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def verify_lora_state(self, label=""):
        """
        Debug function to verify the state of LoRAs in the transformer model.
        """
        if self.transformer is None:
            print(f"[{label}] Transformer is None, cannot verify LoRA state")
            return

        has_loras = False
        if hasattr(self.transformer, 'peft_config'):
            adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
            if adapter_names:
                has_loras = True
                print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
            else:
                print(f"[{label}] Transformer has no LoRAs in peft_config")
        else:
            print(f"[{label}] Transformer has no peft_config attribute")

        # Check for any LoRA modules
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'lora_A') and module.lora_A:
                has_loras = True
                # print(f"[{label}] Found lora_A in module {name}")
            if hasattr(module, 'lora_B') and module.lora_B:
                has_loras = True
                # print(f"[{label}] Found lora_B in module {name}")

        if not has_loras:
            print(f"[{label}] No LoRA components found in transformer")

    def move_lora_adapters_to_device(self, target_device):
        """
        Move all LoRA adapters in the transformer model to the specified device.
        This handles the PEFT implementation of LoRA.
        """
        if self.transformer is None:
            return

        print(f"Moving all LoRA adapters to {target_device}")

        # First, find all modules with LoRA adapters
        lora_modules = []
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                lora_modules.append((name, module))

        # Now move all LoRA components to the target device
        for name, module in lora_modules:
            # Get the active adapter name
            active_adapter = module.active_adapter

            # Move the LoRA layers to the target device
            if active_adapter is not None:
                if isinstance(module.lora_A, torch.nn.ModuleDict):
                    # Handle ModuleDict case (PEFT implementation)
                    for adapter_name in list(module.lora_A.keys()):
                        # Move lora_A
                        if adapter_name in module.lora_A:
                            module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)

                        # Move lora_B
                        if adapter_name in module.lora_B:
                            module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)

                        # Move scaling
                        if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
                            if isinstance(module.scaling[adapter_name], torch.Tensor):
                                module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
                else:
                    # Handle direct attribute case
                    if hasattr(module, 'lora_A') and module.lora_A is not None:
                        module.lora_A = module.lora_A.to(target_device)
                    if hasattr(module, 'lora_B') and module.lora_B is not None:
                        module.lora_B = module.lora_B.to(target_device)
                    if hasattr(module, 'scaling') and module.scaling is not None:
                        if isinstance(module.scaling, torch.Tensor):
                            module.scaling = module.scaling.to(target_device)

        print(f"Moved all LoRA adapters to {target_device}")

    def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_names: List[str], lora_values: Optional[List[float]] = None):
        """
        Load LoRAs into the transformer model and apply their weights.

        Args:
            selected_loras: List of LoRA base names to load (e.g., ["lora_A", "lora_B"]).
            lora_folder: Path to the folder containing the LoRA files.
            lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing.
            lora_values: A list of strength values corresponding to lora_loaded_names.
        """
        self.unload_loras()

        if not selected_loras:
            print("No LoRAs selected, skipping loading.")
            return

        lora_dir = Path(lora_folder)

        adapter_names = []
        strengths = []

        for idx, lora_base_name in enumerate(selected_loras):
            lora_file = None
            for ext in (".safetensors", ".pt"):
                candidate_path_relative = f"{lora_base_name}{ext}"
                candidate_path_full = lora_dir / candidate_path_relative
                if candidate_path_full.is_file():
                    lora_file = candidate_path_relative
                    break

            if not lora_file:
                print(f"Warning: LoRA file for base name '{lora_base_name}' not found; skipping.")
                continue

            print(f"Loading LoRA from '{lora_file}'...")

            self.transformer, adapter_name = lora_utils.load_lora(self.transformer, lora_dir, lora_file)
            adapter_names.append(adapter_name)

            weight = 1.0
            if lora_values:
                try:
                    master_list_idx = lora_loaded_names.index(lora_base_name)
                    if master_list_idx < len(lora_values):
                        weight = float(lora_values[master_list_idx])
                    else:
                        print(f"Warning: Index mismatch for '{lora_base_name}'. Defaulting to 1.0.")
                except ValueError:
                    print(f"Warning: LoRA '{lora_base_name}' not found in master list. Defaulting to 1.0.")

            strengths.append(weight)

        if adapter_names:
            print(f"Activating adapters: {adapter_names} with strengths: {strengths}")
            lora_utils.set_adapters(self.transformer, adapter_names, strengths)

        self.verify_lora_state("After completing load_loras")
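Note: a minimal sketch (not part of the diff) of the Hugging Face cache layout that `_get_offline_load_path` walks; the HF_HOME value and repo id below are illustrative placeholders.

```python
import os

hf_home = os.environ.get("HF_HOME", "/app/hf_download")  # placeholder default
repo_cache_id = "models--lllyasviel--FramePackI2V_HY"

# refs/main stores the commit hash of the downloaded snapshot ...
refs_main = os.path.join(hf_home, "hub", repo_cache_id, "refs", "main")

# ... and the weights live under snapshots/<commit hash>/
if os.path.exists(refs_main):
    with open(refs_main) as f:
        commit_hash = f.read().strip()
    snapshot_dir = os.path.join(hf_home, "hub", repo_cache_id, "snapshots", commit_hash)
    print("Offline load path:", snapshot_dir)
else:
    print("No cached snapshot found; the generator falls back to the hub repo id.")
```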
modules/generators/f1_generator.py
ADDED
@@ -0,0 +1,235 @@
import torch
import os  # for offline loading path
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from .base_generator import BaseModelGenerator

class F1ModelGenerator(BaseModelGenerator):
    """
    Model generator for the F1 HunyuanVideo model.
    """

    def __init__(self, **kwargs):
        """
        Initialize the F1 model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "F1"
        self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503'
        self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503"

    def get_model_name(self):
        """
        Get the name of the model.
        """
        return self.model_name

    def load_model(self):
        """
        Load the F1 transformer model.
        If offline mode is True, attempts to load from a local snapshot.
        """
        print(f"Loading {self.model_name} Transformer...")

        path_to_load = self.model_path  # Initialize with the default path

        if self.offline:
            path_to_load = self._get_offline_load_path()  # Calls the method in BaseModelGenerator

        # Create the transformer model
        self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
            path_to_load,
            torch_dtype=torch.bfloat16
        ).cpu()

        # Configure the model
        self.transformer.eval()
        self.transformer.to(dtype=torch.bfloat16)
        self.transformer.requires_grad_(False)

        # Set up dynamic swap if not in high VRAM mode
        if not self.high_vram:
            DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
        else:
            # In high VRAM mode, move the entire model to GPU
            self.transformer.to(device=self.gpu)

        print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
        return self.transformer

    def prepare_history_latents(self, height, width):
        """
        Prepare the history latents tensor for the F1 model.

        Args:
            height: The height of the image
            width: The width of the image

        Returns:
            The initialized history latents tensor
        """
        return torch.zeros(
            size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
            dtype=torch.float32
        ).cpu()

    def initialize_with_start_latent(self, history_latents, start_latent, is_real_image_latent):
        """
        Initialize the history latents with the start latent for the F1 model.

        Args:
            history_latents: The history latents
            start_latent: The start latent
            is_real_image_latent: Whether the start latent came from a real input image or is a synthetic noise

        Returns:
            The initialized history latents
        """
        # Add the start frame to history_latents
        if is_real_image_latent:
            return torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
        # After prepare_history_latents, history_latents (initialized with zeros)
        # already has the required 19 entries for initial clean latents
        return history_latents

    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the F1 model.

        Args:
            total_latent_sections: The total number of latent sections

        Returns:
            A list of latent paddings
        """
        # F1 model uses a fixed approach with just 0 for last section and 1 for others
        return [1] * (total_latent_sections - 1) + [0]

    def prepare_indices(self, latent_padding_size, latent_window_size):
        """
        Prepare the indices for the F1 model.

        Args:
            latent_padding_size: The size of the latent padding
            latent_window_size: The size of the latent window

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
        """
        # F1 model uses a different indices approach
        # When latent_window_size is 4.5, use 5 as a special case
        effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size)
        indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
        clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
        clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices

    def prepare_clean_latents(self, start_latent, history_latents):
        """
        Prepare the clean latents for the F1 model.

        Args:
            start_latent: The start latent
            history_latents: The history latents

        Returns:
            A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
        """
        # For F1, we take the last frames for clean latents
        clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
        # For F1, we prepend the start latent to clean_latents_1x
        clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)

        return clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
        """
        Update the history latents with the generated latents for the F1 model.

        Args:
            history_latents: The history latents
            generated_latents: The generated latents

        Returns:
            The updated history latents
        """
        # For F1, we append new frames to the end
        return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the F1 model.

        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames

        Returns:
            The real history latents
        """
        # For F1, we take frames from the end
        return history_latents[:, :, -total_generated_latent_frames:, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the F1 model.

        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames

        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw
        # For F1 model, history_pixels is first, current_pixels is second
        return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)

    def get_section_latent_frames(self, latent_window_size, is_last_section):
        """
        Get the number of section latent frames for the F1 model.

        Args:
            latent_window_size: The size of the latent window
            is_last_section: Whether this is the last section

        Returns:
            The number of section latent frames
        """
        return latent_window_size * 2

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the F1 model.

        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model

        Returns:
            The current pixels
        """
        from diffusers_helper.hunyuan import vae_decode
        # For F1, we take frames from the end
        return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the F1 model.

        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds
            original_pos: The original position in seconds
            current_prompt: The current prompt

        Returns:
            The formatted position description
        """
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
                f'Current position: {current_pos:.2f}s. '
                f'using prompt: {current_prompt[:256]}...')
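Note: a quick illustration (not part of the diff) of the index split used in `prepare_indices` above; the latent_window_size value of 9 is illustrative, not one mandated by the model.

```python
import torch

latent_window_size = 9
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
start_idx, idx_4x, idx_2x, idx_1x, window_idx = indices.split([1, 16, 2, 1, latent_window_size], dim=1)

print(start_idx.shape, idx_4x.shape, idx_2x.shape, idx_1x.shape, window_idx.shape)
# torch.Size([1, 1]) torch.Size([1, 16]) torch.Size([1, 2]) torch.Size([1, 1]) torch.Size([1, 9])

# clean_latent_indices stitches the start-frame index to the 1x context index:
clean_latent_indices = torch.cat([start_idx, idx_1x], dim=1)  # shape (1, 2)
```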
modules/generators/original_generator.py
ADDED
@@ -0,0 +1,213 @@
import torch
import os  # for offline loading path
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from .base_generator import BaseModelGenerator

class OriginalModelGenerator(BaseModelGenerator):
    """
    Model generator for the Original HunyuanVideo model.
    """

    def __init__(self, **kwargs):
        """
        Initialize the Original model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "Original"
        self.model_path = 'lllyasviel/FramePackI2V_HY'
        self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"

    def get_model_name(self):
        """
        Get the name of the model.
        """
        return self.model_name

    def load_model(self):
        """
        Load the Original transformer model.
        If offline mode is True, attempts to load from a local snapshot.
        """
        print(f"Loading {self.model_name} Transformer...")

        path_to_load = self.model_path  # Initialize with the default path

        if self.offline:
            path_to_load = self._get_offline_load_path()  # Calls the method in BaseModelGenerator

        # Create the transformer model
        self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
            path_to_load,
            torch_dtype=torch.bfloat16
        ).cpu()

        # Configure the model
        self.transformer.eval()
        self.transformer.to(dtype=torch.bfloat16)
        self.transformer.requires_grad_(False)

        # Set up dynamic swap if not in high VRAM mode
        if not self.high_vram:
            DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
        else:
            # In high VRAM mode, move the entire model to GPU
            self.transformer.to(device=self.gpu)

        print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
        return self.transformer

    def prepare_history_latents(self, height, width):
        """
        Prepare the history latents tensor for the Original model.

        Args:
            height: The height of the image
            width: The width of the image

        Returns:
            The initialized history latents tensor
        """
        return torch.zeros(
            size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
            dtype=torch.float32
        ).cpu()

    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the Original model.

        Args:
            total_latent_sections: The total number of latent sections

        Returns:
            A list of latent paddings
        """
        # Original model uses reversed latent paddings
        if total_latent_sections > 4:
            return [3] + [2] * (total_latent_sections - 3) + [1, 0]
        else:
            return list(reversed(range(total_latent_sections)))

    def prepare_indices(self, latent_padding_size, latent_window_size):
        """
        Prepare the indices for the Original model.

        Args:
            latent_padding_size: The size of the latent padding
            latent_window_size: The size of the latent window

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
        """
        indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
        clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
        clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices

    def prepare_clean_latents(self, start_latent, history_latents):
        """
        Prepare the clean latents for the Original model.

        Args:
            start_latent: The start latent
            history_latents: The history latents

        Returns:
            A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
        """
        clean_latents_pre = start_latent.to(history_latents)
        clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

        return clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
        """
        Update the history latents with the generated latents for the Original model.

        Args:
            history_latents: The history latents
            generated_latents: The generated latents

        Returns:
            The updated history latents
        """
        # For Original model, we prepend the generated latents
        return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the Original model.

        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames

        Returns:
            The real history latents
        """
        return history_latents[:, :, :total_generated_latent_frames, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the Original model.

        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames

        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw
        # For Original model, current_pixels is first, history_pixels is second
        return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

    def get_section_latent_frames(self, latent_window_size, is_last_section):
        """
        Get the number of section latent frames for the Original model.

        Args:
            latent_window_size: The size of the latent window
            is_last_section: Whether this is the last section

        Returns:
            The number of section latent frames
        """
        return latent_window_size * 2

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the Original model.

        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model

        Returns:
            The current pixels
        """
        from diffusers_helper.hunyuan import vae_decode
        return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the Original model.

        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds
            original_pos: The original position in seconds
            current_prompt: The current prompt

        Returns:
            The formatted position description
        """
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
                f'Current position: {current_pos:.2f}s (original: {original_pos:.2f}s). '
                f'using prompt: {current_prompt[:256]}...')
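Note: a small sketch (not part of the diff) comparing the padding schedules produced by the two `get_latent_paddings` implementations above, plus the frame-count arithmetic used in the position descriptions; the local helper functions simply mirror those methods and the section counts are illustrative.

```python
def original_paddings(total_latent_sections):
    # Mirrors OriginalModelGenerator.get_latent_paddings
    if total_latent_sections > 4:
        return [3] + [2] * (total_latent_sections - 3) + [1, 0]
    return list(reversed(range(total_latent_sections)))

def f1_paddings(total_latent_sections):
    # Mirrors F1ModelGenerator.get_latent_paddings
    return [1] * (total_latent_sections - 1) + [0]

print(original_paddings(3))  # [2, 1, 0] -- Original generates back-to-front toward the start image
print(original_paddings(6))  # [3, 2, 2, 2, 1, 0]
print(f1_paddings(6))        # [1, 1, 1, 1, 1, 0] -- F1 extends forward, so only the last section differs

# Frame-count arithmetic used in format_position_description:
total_generated_latent_frames = 9
print(total_generated_latent_frames * 4 - 3)  # 33 decoded frames ~ 1.10 s at 30 FPS
```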
modules/generators/original_with_endframe_generator.py
ADDED
@@ -0,0 +1,15 @@
from .original_generator import OriginalModelGenerator

class OriginalWithEndframeModelGenerator(OriginalModelGenerator):
    """
    Model generator for the Original HunyuanVideo model with end frame support.
    This extends the Original model with the ability to guide generation toward a specified end frame.
    """

    def __init__(self, **kwargs):
        """
        Initialize the Original with Endframe model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "Original with Endframe"
        # Inherits everything else from OriginalModelGenerator
modules/generators/video_base_generator.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import decord
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pathlib
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
|
11 |
+
from diffusers_helper.memory import DynamicSwapInstaller
|
12 |
+
from diffusers_helper.utils import resize_and_center_crop
|
13 |
+
from diffusers_helper.bucket_tools import find_nearest_bucket
|
14 |
+
from diffusers_helper.hunyuan import vae_encode, vae_decode
|
15 |
+
from .base_generator import BaseModelGenerator
|
16 |
+
|
17 |
+
class VideoBaseModelGenerator(BaseModelGenerator):
|
18 |
+
"""
|
19 |
+
Model generator for the Video extension of the Original HunyuanVideo model.
|
20 |
+
This generator accepts video input instead of a single image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, **kwargs):
|
24 |
+
"""
|
25 |
+
Initialize the Video model generator.
|
26 |
+
"""
|
27 |
+
super().__init__(**kwargs)
|
28 |
+
self.model_name = None # Subclass Model Specific
|
29 |
+
self.model_path = None # Subclass Model Specific
|
30 |
+
self.model_repo_id_for_cache = None # Subclass Model Specific
|
31 |
+
self.full_video_latents = None # For context, set by worker() when available
|
32 |
+
self.resolution = 640 # Default resolution
|
33 |
+
self.no_resize = False # Default to resize
|
34 |
+
self.vae_batch_size = 16 # Default VAE batch size
|
35 |
+
|
36 |
+
# Import decord and tqdm here to avoid import errors if not installed
|
37 |
+
try:
|
38 |
+
import decord
|
39 |
+
from tqdm import tqdm
|
40 |
+
self.decord = decord
|
41 |
+
self.tqdm = tqdm
|
42 |
+
except ImportError:
|
43 |
+
print("Warning: decord or tqdm not installed. Video processing will not work.")
|
44 |
+
self.decord = None
|
45 |
+
self.tqdm = None
|
46 |
+
|
47 |
+
def get_model_name(self):
|
48 |
+
"""
|
49 |
+
Get the name of the model.
|
50 |
+
"""
|
51 |
+
return self.model_name
|
52 |
+
|
53 |
+
def load_model(self):
|
54 |
+
"""
|
55 |
+
Load the Video transformer model.
|
56 |
+
If offline mode is True, attempts to load from a local snapshot.
|
57 |
+
"""
|
58 |
+
print(f"Loading {self.model_name} Transformer...")
|
59 |
+
|
60 |
+
path_to_load = self.model_path # Initialize with the default path
|
61 |
+
|
62 |
+
if self.offline:
|
63 |
+
path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator
|
64 |
+
|
65 |
+
# Create the transformer model
|
66 |
+
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
|
67 |
+
path_to_load,
|
68 |
+
torch_dtype=torch.bfloat16
|
69 |
+
).cpu()
|
70 |
+
|
71 |
+
# Configure the model
|
72 |
+
self.transformer.eval()
|
73 |
+
self.transformer.to(dtype=torch.bfloat16)
|
74 |
+
self.transformer.requires_grad_(False)
|
75 |
+
|
76 |
+
# Set up dynamic swap if not in high VRAM mode
|
77 |
+
if not self.high_vram:
|
78 |
+
DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
|
79 |
+
else:
|
80 |
+
# In high VRAM mode, move the entire model to GPU
|
81 |
+
self.transformer.to(device=self.gpu)
|
82 |
+
|
83 |
+
print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
|
84 |
+
return self.transformer
|
85 |
+
|
86 |
+
def min_real_frames_to_encode(self, real_frames_available_count):
|
87 |
+
"""
|
88 |
+
Minimum number of real frames to encode
|
89 |
+
is the maximum number of real frames used for generation context.
|
90 |
+
|
91 |
+
The number of latents could be calculated as below for video F1, but keeping it simple for now
|
92 |
+
by hardcoding the Video F1 value at max_latents_used_for_context = 27.
|
93 |
+
|
94 |
+
# Calculate the number of latent frames to encode from the end of the input video
|
95 |
+
num_frames_to_encode_from_end = 1 # Default minimum
|
96 |
+
if model_type == "Video":
|
97 |
+
# Max needed is 1 (clean_latent_pre) + 2 (max 2x) + 16 (max 4x) = 19
|
98 |
+
num_frames_to_encode_from_end = 19
|
99 |
+
elif model_type == "Video F1":
|
100 |
+
ui_num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
|
101 |
+
# Max effective_clean_frames based on VideoF1ModelGenerator's logic.
|
102 |
+
# Max num_clean_frames from UI is 10 (modules/interface.py).
|
103 |
+
# Max effective_clean_frames = 10 - 1 = 9.
|
104 |
+
# total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
|
105 |
+
# Max needed = 16 (max 4x) + 2 (max 2x) + 9 (max effective_clean_frames) = 27
|
106 |
+
num_frames_to_encode_from_end = 27
|
107 |
+
|
108 |
+
Note: 27 latents ~ 108 real frames = 3.6 seconds at 30 FPS.
|
109 |
+
Note: 19 latents ~ 76 real frames ~ 2.5 seconds at 30 FPS.
|
110 |
+
"""
|
111 |
+
|
112 |
+
max_latents_used_for_context = 27
|
113 |
+
if self.get_model_name() == "Video":
|
114 |
+
max_latents_used_for_context = 27 # Weird results on 19
|
115 |
+
elif self.get_model_name() == "Video F1":
|
116 |
+
max_latents_used_for_context = 27 # Enough for even Video F1 with cleaned_frames input of 10
|
117 |
+
else:
|
118 |
+
print("======================================================")
|
119 |
+
print(f" ***** Warning: Unsupported video extension model type: {self.get_model_name()}.")
|
120 |
+
print( " ***** Using default max latents {max_latents_used_for_context} for context.")
|
121 |
+
print( " ***** Please report to the developers if you see this message:")
|
122 |
+
print( " ***** Discord: https://discord.gg/8Z2c3a4 or GitHub: https://github.com/colinurbs/FramePack-Studio")
|
123 |
+
print("======================================================")
|
124 |
+
# Probably better to press on with Video F1 max vs exception?
|
125 |
+
# raise ValueError(f"Unsupported video extension model type: {self.get_model_name()}")
|
126 |
+
|
127 |
+
latent_size_factor = 4 # real frames to latent frames conversion factor
|
128 |
+
max_real_frames_used_for_context = max_latents_used_for_context * latent_size_factor
|
129 |
+
|
130 |
+
# Shortest of available frames and max frames used for context
|
131 |
+
trimmed_real_frames_count = min(real_frames_available_count, max_real_frames_used_for_context)
|
132 |
+
if trimmed_real_frames_count < real_frames_available_count:
|
133 |
+
print(f"Truncating video frames from {real_frames_available_count} to {trimmed_real_frames_count}, enough to populate context")
|
134 |
+
|
135 |
+
# Truncate to nearest latent size (multiple of 4)
|
136 |
+
frames_to_encode_count = (trimmed_real_frames_count // latent_size_factor) * latent_size_factor
|
137 |
+
if frames_to_encode_count != trimmed_real_frames_count:
|
138 |
+
print(f"Truncating video frames from {trimmed_real_frames_count} to {frames_to_encode_count}, for latent size compatibility")
|
139 |
+
|
140 |
+
return frames_to_encode_count
|
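For intuition, the arithmetic above can be exercised in isolation (a minimal sketch with assumed names, using the hardcoded Video F1 cap of 27 context latents): at most 27 * 4 = 108 real frames are kept, and the count is then rounded down to a multiple of 4.

def sketch_min_real_frames(real_frames_available_count, max_latents_used_for_context=27):
    latent_size_factor = 4  # 4 real frames per latent frame
    max_real_frames = max_latents_used_for_context * latent_size_factor  # 108
    trimmed = min(real_frames_available_count, max_real_frames)
    return (trimmed // latent_size_factor) * latent_size_factor

print(sketch_min_real_frames(150))  # 108: capped at the context limit
print(sketch_min_real_frames(97))   # 96: rounded down to a multiple of 4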
141 |
+
|
142 |
+
def extract_video_frames(self, is_for_encode, video_path, resolution, no_resize=False, input_files_dir=None):
|
143 |
+
"""
|
144 |
+
Extract real frames from a video, resized and center cropped as numpy array (T, H, W, C).
|
145 |
+
|
146 |
+
Args:
|
147 |
+
is_for_encode: If True, results are capped at maximum frames used for context, and aligned to 4-frame latent requirement.
|
148 |
+
video_path: Path to the input video file.
|
149 |
+
resolution: Target resolution for resizing frames.
|
150 |
+
no_resize: Whether to use the original video resolution.
|
151 |
+
input_files_dir: Directory for input files that won't be cleaned up.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
A tuple containing:
|
155 |
+
- input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
|
156 |
+
- fps: Frames per second of the input video
|
157 |
+
- target_height: Target height of the video
|
158 |
+
- target_width: Target width of the video
|
159 |
+
"""
|
160 |
+
def time_millis():
|
161 |
+
import time
|
162 |
+
return time.perf_counter() * 1000.0 # Convert seconds to milliseconds
|
163 |
+
|
164 |
+
encode_start_time_millis = time_millis()
|
165 |
+
|
166 |
+
# Normalize video path for Windows compatibility
|
167 |
+
video_path = str(pathlib.Path(video_path).resolve())
|
168 |
+
print(f"Processing video: {video_path}")
|
169 |
+
|
170 |
+
# Check if the video is in the temp directory and if we have an input_files_dir
|
171 |
+
if input_files_dir and "temp" in video_path:
|
172 |
+
# Check if there's a copy of this video in the input_files_dir
|
173 |
+
filename = os.path.basename(video_path)
|
174 |
+
input_file_path = os.path.join(input_files_dir, filename)
|
175 |
+
|
176 |
+
# If the file exists in input_files_dir, use that instead
|
177 |
+
if os.path.exists(input_file_path):
|
178 |
+
print(f"Using video from input_files_dir: {input_file_path}")
|
179 |
+
video_path = input_file_path
|
180 |
+
else:
|
181 |
+
# If not, copy it to input_files_dir to prevent it from being deleted
|
182 |
+
try:
|
183 |
+
from diffusers_helper.utils import generate_timestamp
|
184 |
+
safe_filename = f"{generate_timestamp()}_{filename}"
|
185 |
+
input_file_path = os.path.join(input_files_dir, safe_filename)
|
186 |
+
import shutil
|
187 |
+
shutil.copy2(video_path, input_file_path)
|
188 |
+
print(f"Copied video to input_files_dir: {input_file_path}")
|
189 |
+
video_path = input_file_path
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Error copying video to input_files_dir: {e}")
|
192 |
+
|
193 |
+
try:
|
194 |
+
# Load video and get FPS
|
195 |
+
print("Initializing VideoReader...")
|
196 |
+
vr = decord.VideoReader(video_path)
|
197 |
+
fps = vr.get_avg_fps() # Get input video FPS
|
198 |
+
num_real_frames = len(vr)
|
199 |
+
print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")
|
200 |
+
|
201 |
+
# Read frames
|
202 |
+
print("Reading video frames...")
|
203 |
+
|
204 |
+
total_frames_in_video_file = len(vr)
|
205 |
+
if is_for_encode:
|
206 |
+
print(f"Using minimum real frames to encode: {self.min_real_frames_to_encode(total_frames_in_video_file)}")
|
207 |
+
num_real_frames = self.min_real_frames_to_encode(total_frames_in_video_file)
|
208 |
+
# else left as all frames -- len(vr) with no regard for trimming or latent alignment
|
209 |
+
|
210 |
+
# RT_BORG: Retaining this commented code for reference.
|
211 |
+
# pftq encoder discarded truncated frames from the end of the video.
|
212 |
+
# frames = vr.get_batch(range(num_real_frames)).asnumpy() # Shape: (num_real_frames, height, width, channels)
|
213 |
+
|
214 |
+
# RT_BORG: Retaining this commented code for reference.
|
215 |
+
# pftq retained the entire encoded video.
|
216 |
+
# Truncate to nearest latent size (multiple of 4)
|
217 |
+
# latent_size_factor = 4
|
218 |
+
# num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
|
219 |
+
# if num_frames != num_real_frames:
|
220 |
+
# print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility")
|
221 |
+
# num_real_frames = num_frames
|
222 |
+
|
223 |
+
# Discard truncated frames from the beginning of the video, retaining the last num_real_frames
|
224 |
+
# This ensures a smooth transition from the input video to the generated video
|
225 |
+
start_frame_index = total_frames_in_video_file - num_real_frames
|
226 |
+
frame_indices_to_extract = range(start_frame_index, total_frames_in_video_file)
|
227 |
+
frames = vr.get_batch(frame_indices_to_extract).asnumpy() # Shape: (num_real_frames, height, width, channels)
|
228 |
+
|
229 |
+
print(f"Frames read: {frames.shape}")
|
230 |
+
|
231 |
+
# Get native video resolution
|
232 |
+
native_height, native_width = frames.shape[1], frames.shape[2]
|
233 |
+
print(f"Native video resolution: {native_width}x{native_height}")
|
234 |
+
|
235 |
+
# Use native resolution if height/width not specified, otherwise use provided values
|
236 |
+
target_height = native_height
|
237 |
+
target_width = native_width
|
238 |
+
|
239 |
+
# Adjust to nearest bucket for model compatibility
|
240 |
+
if not no_resize:
|
241 |
+
target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution)
|
242 |
+
print(f"Adjusted resolution: {target_width}x{target_height}")
|
243 |
+
else:
|
244 |
+
print(f"Using native resolution without resizing: {target_width}x{target_height}")
|
245 |
+
|
246 |
+
# Preprocess input frames to match desired resolution
|
247 |
+
input_frames_resized_np = []
|
248 |
+
for i, frame in tqdm(enumerate(frames), desc="Processing Video Frames", total=num_real_frames, mininterval=0.1):
|
249 |
+
frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height)
|
250 |
+
input_frames_resized_np.append(frame_np)
|
251 |
+
input_frames_resized_np = np.stack(input_frames_resized_np) # Shape: (num_real_frames, height, width, channels)
|
252 |
+
print(f"Frames preprocessed: {input_frames_resized_np.shape}")
|
253 |
+
|
254 |
+
resized_frames_time_millis = time_millis()
|
255 |
+
if (False): # We really need a logger
|
256 |
+
print("======================================================")
|
257 |
+
memory_bytes = input_frames_resized_np.nbytes
|
258 |
+
memory_kb = memory_bytes / 1024
|
259 |
+
memory_mb = memory_kb / 1024
|
260 |
+
print(f" ***** input_frames_resized_np: {input_frames_resized_np.shape}")
|
261 |
+
print(f" ***** Memory usage: {int(memory_mb)} MB")
|
262 |
+
duration_ms = resized_frames_time_millis - encode_start_time_millis
|
263 |
+
print(f" ***** Time taken to process frames tensor: {duration_ms / 1000.0:.2f} seconds")
|
264 |
+
print("======================================================")
|
265 |
+
|
266 |
+
return input_frames_resized_np, fps, target_height, target_width
|
267 |
+
except Exception as e:
|
268 |
+
print(f"Error in extract_video_frames: {str(e)}")
|
269 |
+
raise
|
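The tail-slice behaviour above (keeping only the last N frames so the generated video transitions smoothly out of the input video) can be reproduced in isolation; this sketch assumes only decord and a hypothetical clip path:

import decord

def read_tail_frames(video_path, num_real_frames):
    # Keep the last num_real_frames frames, discarding from the beginning of the clip.
    vr = decord.VideoReader(video_path)
    start = max(0, len(vr) - num_real_frames)
    return vr.get_batch(range(start, len(vr))).asnumpy()  # (N, H, W, C)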
270 |
+
|
271 |
+
# RT_BORG: video_encode produce and return end_of_input_video_latent and end_of_input_video_image_np
|
272 |
+
# which are not needed for Video models without end frame processing.
|
273 |
+
# But these should be inexpensive and it's easier to keep the code uniform.
|
274 |
+
@torch.no_grad()
|
275 |
+
def video_encode(self, video_path, resolution, no_resize=False, vae_batch_size=16, device=None, input_files_dir=None):
|
276 |
+
"""
|
277 |
+
Encode a video into latent representations using the VAE.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
video_path: Path to the input video file.
|
281 |
+
resolution: Target resolution for resizing frames.
|
282 |
+
no_resize: Whether to use the original video resolution.
|
283 |
+
vae_batch_size: Number of frames to process per batch.
|
284 |
+
device: Device for computation (e.g., "cuda").
|
285 |
+
input_files_dir: Directory for input files that won't be cleaned up.
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
A tuple containing:
|
289 |
+
- start_latent: Latent of the first frame
|
290 |
+
- input_image_np: First frame as numpy array
|
291 |
+
- history_latents: Latents of all frames
|
292 |
+
- fps: Frames per second of the input video
|
293 |
+
- target_height: Target height of the video
|
294 |
+
- target_width: Target width of the video
|
295 |
+
- input_video_pixels: Video frames as tensor
|
296 |
+
- end_of_input_video_image_np: Last frame as numpy array
|
297 |
+
- input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
|
298 |
+
"""
|
299 |
+
encoding = True # Flag to indicate this is for encoding
|
300 |
+
input_frames_resized_np, fps, target_height, target_width = self.extract_video_frames(encoding, video_path, resolution, no_resize, input_files_dir)
|
301 |
+
|
302 |
+
try:
|
303 |
+
if device is None:
|
304 |
+
device = self.gpu
|
305 |
+
|
306 |
+
# Check CUDA availability and fallback to CPU if needed
|
307 |
+
if device == "cuda" and not torch.cuda.is_available():
|
308 |
+
print("CUDA is not available, falling back to CPU")
|
309 |
+
device = "cpu"
|
310 |
+
|
311 |
+
# Save first frame for CLIP vision encoding
|
312 |
+
input_image_np = input_frames_resized_np[0]
|
313 |
+
end_of_input_video_image_np = input_frames_resized_np[-1]
|
314 |
+
|
315 |
+
# Convert to tensor and normalize to [-1, 1]
|
316 |
+
print("Converting frames to tensor...")
|
317 |
+
frames_pt = torch.from_numpy(input_frames_resized_np).float() / 127.5 - 1
|
318 |
+
frames_pt = frames_pt.permute(0, 3, 1, 2) # Shape: (num_real_frames, channels, height, width)
|
319 |
+
frames_pt = frames_pt.unsqueeze(0) # Shape: (1, num_real_frames, channels, height, width)
|
320 |
+
frames_pt = frames_pt.permute(0, 2, 1, 3, 4) # Shape: (1, channels, num_real_frames, height, width)
|
321 |
+
print(f"Tensor shape: {frames_pt.shape}")
|
322 |
+
|
323 |
+
# Save pixel frames for use in worker
|
324 |
+
input_video_pixels = frames_pt.cpu()
|
325 |
+
|
326 |
+
# Move to device
|
327 |
+
print(f"Moving tensor to device: {device}")
|
328 |
+
frames_pt = frames_pt.to(device)
|
329 |
+
print("Tensor moved to device")
|
330 |
+
|
331 |
+
# Move VAE to device
|
332 |
+
print(f"Moving VAE to device: {device}")
|
333 |
+
self.vae.to(device)
|
334 |
+
print("VAE moved to device")
|
335 |
+
|
336 |
+
# Encode frames in batches
|
337 |
+
print(f"Encoding input video frames in VAE batch size {vae_batch_size}")
|
338 |
+
latents = []
|
339 |
+
self.vae.eval()
|
340 |
+
with torch.no_grad():
|
341 |
+
frame_count = frames_pt.shape[2]
|
342 |
+
step_count = math.ceil(frame_count / vae_batch_size)
|
343 |
+
for i in tqdm(range(0, frame_count, vae_batch_size), desc="Encoding video frames", total=step_count, mininterval=0.1):
|
344 |
+
batch = frames_pt[:, :, i:i + vae_batch_size] # Shape: (1, channels, batch_size, height, width)
|
345 |
+
try:
|
346 |
+
# Log GPU memory before encoding
|
347 |
+
if device == "cuda":
|
348 |
+
free_mem = torch.cuda.memory_allocated() / 1024**3
|
349 |
+
batch_latent = vae_encode(batch, self.vae)
|
350 |
+
# Synchronize CUDA to catch issues
|
351 |
+
if device == "cuda":
|
352 |
+
torch.cuda.synchronize()
|
353 |
+
latents.append(batch_latent)
|
354 |
+
except RuntimeError as e:
|
355 |
+
print(f"Error during VAE encoding: {str(e)}")
|
356 |
+
if device == "cuda" and "out of memory" in str(e).lower():
|
357 |
+
print("CUDA out of memory, try reducing vae_batch_size or using CPU")
|
358 |
+
raise
|
359 |
+
|
360 |
+
# Concatenate latents
|
361 |
+
print("Concatenating latents...")
|
362 |
+
history_latents = torch.cat(latents, dim=2) # Shape: (1, channels, frames, height//8, width//8)
|
363 |
+
print(f"History latents shape: {history_latents.shape}")
|
364 |
+
|
365 |
+
# Get first frame's latent
|
366 |
+
start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
|
367 |
+
print(f"Start latent shape: {start_latent.shape}")
|
368 |
+
|
369 |
+
if (False): # We really need a logger
|
370 |
+
print("======================================================")
|
371 |
+
memory_bytes = history_latents.nbytes
|
372 |
+
memory_kb = memory_bytes / 1024
|
373 |
+
memory_mb = memory_kb / 1024
|
374 |
+
print(f" ***** history_latents: {history_latents.shape}")
|
375 |
+
print(f" ***** Memory usage: {int(memory_mb)} MB")
|
376 |
+
print("======================================================")
|
377 |
+
|
378 |
+
# Move VAE back to CPU to free GPU memory
|
379 |
+
if device == "cuda":
|
380 |
+
self.vae.to(self.cpu)
|
381 |
+
torch.cuda.empty_cache()
|
382 |
+
print("VAE moved back to CPU, CUDA cache cleared")
|
383 |
+
|
384 |
+
return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np
|
385 |
+
|
386 |
+
except Exception as e:
|
387 |
+
print(f"Error in video_encode: {str(e)}")
|
388 |
+
raise
|
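A hypothetical call site for video_encode, unpacking the tuple in the order documented above (the generator instance, clip path, and resolution are placeholders):

# Sketch only: `generator` is assumed to be a VideoBaseModelGenerator subclass instance.
(start_latent, input_image_np, video_latents, fps,
 height, width, input_video_pixels,
 end_of_input_video_image_np, input_frames_resized_np) = generator.video_encode(
    "input_clip.mp4", resolution=640, vae_batch_size=16)
print(video_latents.shape)  # (1, channels, latent_frames, height//8, width//8)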
389 |
+
|
390 |
+
# RT_BORG: Currently history_latents is initialized within worker() for all Video models as history_latents = video_latents
|
391 |
+
# So it is a coding error to call prepare_history_latents() here.
|
392 |
+
# Leaving in place as we will likely use it post-refactoring.
|
393 |
+
def prepare_history_latents(self, height, width):
|
394 |
+
"""
|
395 |
+
Prepare the history latents tensor for the Video model.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
height: The height of the image
|
399 |
+
width: The width of the image
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
The initialized history latents tensor
|
403 |
+
"""
|
404 |
+
raise TypeError(
|
405 |
+
f"Error: '{self.__class__.__name__}.prepare_history_latents' should not be called "
|
406 |
+
"on the Video models. history_latents should be initialized within worker() for all Video models "
|
407 |
+
"as history_latents = video_latents."
|
408 |
+
)
|
409 |
+
|
410 |
+
def prepare_indices(self, latent_padding_size, latent_window_size):
|
411 |
+
"""
|
412 |
+
Prepare the indices for the Video model.
|
413 |
+
|
414 |
+
Args:
|
415 |
+
latent_padding_size: The size of the latent padding
|
416 |
+
latent_window_size: The size of the latent window
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
|
420 |
+
"""
|
421 |
+
raise TypeError(
|
422 |
+
f"Error: '{self.__class__.__name__}.prepare_indices' should not be called "
|
423 |
+
"on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
|
424 |
+
)
|
425 |
+
|
426 |
+
def set_full_video_latents(self, video_latents):
|
427 |
+
"""
|
428 |
+
Set the full video latents for context.
|
429 |
+
|
430 |
+
Args:
|
431 |
+
video_latents: The full video latents
|
432 |
+
"""
|
433 |
+
self.full_video_latents = video_latents
|
434 |
+
|
435 |
+
def prepare_clean_latents(self, start_latent, history_latents):
|
436 |
+
"""
|
437 |
+
Prepare the clean latents for the Video model.
|
438 |
+
|
439 |
+
Args:
|
440 |
+
start_latent: The start latent
|
441 |
+
history_latents: The history latents
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
|
445 |
+
"""
|
446 |
+
raise TypeError(
|
447 |
+
f"Error: '{self.__class__.__name__}.prepare_indices' should not be called "
|
448 |
+
"on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
|
449 |
+
)
|
450 |
+
|
451 |
+
def get_section_latent_frames(self, latent_window_size, is_last_section):
|
452 |
+
"""
|
453 |
+
Get the number of section latent frames for the Video model.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
latent_window_size: The size of the latent window
|
457 |
+
is_last_section: Whether this is the last section
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
The number of section latent frames
|
461 |
+
"""
|
462 |
+
return latent_window_size * 2
|
463 |
+
|
464 |
+
def combine_videos(self, source_video_path, generated_video_path, output_path):
|
465 |
+
"""
|
466 |
+
Combine the source video with the generated video side by side.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
source_video_path: Path to the source video
|
470 |
+
generated_video_path: Path to the generated video
|
471 |
+
output_path: Path to save the combined video
|
472 |
+
|
473 |
+
Returns:
|
474 |
+
Path to the combined video
|
475 |
+
"""
|
476 |
+
try:
|
477 |
+
import os
|
478 |
+
import subprocess
|
479 |
+
|
480 |
+
print(f"Combining source video {source_video_path} with generated video {generated_video_path}")
|
481 |
+
|
482 |
+
# Get the ffmpeg executable from the VideoProcessor class
|
483 |
+
from modules.toolbox.toolbox_processor import VideoProcessor
|
484 |
+
from modules.toolbox.message_manager import MessageManager
|
485 |
+
|
486 |
+
# Create a message manager for logging
|
487 |
+
message_manager = MessageManager()
|
488 |
+
|
489 |
+
# Import settings from main module
|
490 |
+
try:
|
491 |
+
from __main__ import settings
|
492 |
+
video_processor = VideoProcessor(message_manager, settings.settings)
|
493 |
+
except ImportError:
|
494 |
+
# Fallback to creating a new settings object
|
495 |
+
from modules.settings import Settings
|
496 |
+
settings = Settings()
|
497 |
+
video_processor = VideoProcessor(message_manager, settings.settings)
|
498 |
+
|
499 |
+
# Get the ffmpeg executable
|
500 |
+
ffmpeg_exe = video_processor.ffmpeg_exe
|
501 |
+
|
502 |
+
if not ffmpeg_exe:
|
503 |
+
print("FFmpeg executable not found. Cannot combine videos.")
|
504 |
+
return None
|
505 |
+
|
506 |
+
print(f"Using ffmpeg at: {ffmpeg_exe}")
|
507 |
+
|
508 |
+
# Create a temporary directory for the filter script
|
509 |
+
import tempfile
|
510 |
+
temp_dir = tempfile.gettempdir()
|
511 |
+
filter_script_path = os.path.join(temp_dir, f"filter_script_{os.path.basename(output_path)}.txt")
|
512 |
+
|
513 |
+
# Get video dimensions using ffprobe
|
514 |
+
def get_video_info(video_path):
|
515 |
+
cmd = [
|
516 |
+
ffmpeg_exe, "-i", video_path,
|
517 |
+
"-hide_banner", "-loglevel", "error"
|
518 |
+
]
|
519 |
+
|
520 |
+
# Run ffmpeg to get video info (it will fail but output info to stderr)
|
521 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
522 |
+
|
523 |
+
# Parse the output to get dimensions
|
524 |
+
width = height = None
|
525 |
+
for line in result.stderr.split('\n'):
|
526 |
+
if 'Video:' in line:
|
527 |
+
# Look for dimensions like 640x480
|
528 |
+
import re
|
529 |
+
match = re.search(r'(\d+)x(\d+)', line)
|
530 |
+
if match:
|
531 |
+
width = int(match.group(1))
|
532 |
+
height = int(match.group(2))
|
533 |
+
break
|
534 |
+
|
535 |
+
return width, height
|
536 |
+
|
537 |
+
# Get dimensions of both videos
|
538 |
+
source_width, source_height = get_video_info(source_video_path)
|
539 |
+
generated_width, generated_height = get_video_info(generated_video_path)
|
540 |
+
|
541 |
+
if not source_width or not generated_width:
|
542 |
+
print("Error: Could not determine video dimensions")
|
543 |
+
return None
|
544 |
+
|
545 |
+
print(f"Source video: {source_width}x{source_height}")
|
546 |
+
print(f"Generated video: {generated_width}x{generated_height}")
|
547 |
+
|
548 |
+
# Calculate target dimensions (maintain aspect ratio)
|
549 |
+
target_height = max(source_height, generated_height)
|
550 |
+
source_target_width = int(source_width * (target_height / source_height))
|
551 |
+
generated_target_width = int(generated_width * (target_height / generated_height))
|
552 |
+
|
553 |
+
# Create a complex filter for side-by-side display with labels
|
554 |
+
filter_complex = (
|
555 |
+
f"[0:v]scale={source_target_width}:{target_height}[left];"
|
556 |
+
f"[1:v]scale={generated_target_width}:{target_height}[right];"
|
557 |
+
f"[left]drawtext=text='Source':x=({source_target_width}/2-50):y=20:fontsize=24:fontcolor=white:box=1:[email protected][left_text];"
|
558 |
+
f"[right]drawtext=text='Generated':x=({generated_target_width}/2-70):y=20:fontsize=24:fontcolor=white:box=1:[email protected][right_text];"
|
559 |
+
f"[left_text][right_text]hstack=inputs=2[v]"
|
560 |
+
)
|
561 |
+
|
562 |
+
# Write the filter script to a file
|
563 |
+
with open(filter_script_path, 'w') as f:
|
564 |
+
f.write(filter_complex)
|
565 |
+
|
566 |
+
# Build the ffmpeg command
|
567 |
+
cmd = [
|
568 |
+
ffmpeg_exe, "-y",
|
569 |
+
"-i", source_video_path,
|
570 |
+
"-i", generated_video_path,
|
571 |
+
"-filter_complex_script", filter_script_path,
|
572 |
+
"-map", "[v]"
|
573 |
+
]
|
574 |
+
|
575 |
+
# Check if source video has audio
|
576 |
+
has_audio_cmd = [
|
577 |
+
ffmpeg_exe, "-i", source_video_path,
|
578 |
+
"-hide_banner", "-loglevel", "error"
|
579 |
+
]
|
580 |
+
audio_check = subprocess.run(has_audio_cmd, capture_output=True, text=True)
|
581 |
+
has_audio = "Audio:" in audio_check.stderr
|
582 |
+
|
583 |
+
if has_audio:
|
584 |
+
cmd.extend(["-map", "0:a"])
|
585 |
+
|
586 |
+
# Add output options
|
587 |
+
cmd.extend([
|
588 |
+
"-c:v", "libx264",
|
589 |
+
"-crf", "18",
|
590 |
+
"-preset", "medium"
|
591 |
+
])
|
592 |
+
|
593 |
+
if has_audio:
|
594 |
+
cmd.extend(["-c:a", "aac"])
|
595 |
+
|
596 |
+
cmd.append(output_path)
|
597 |
+
|
598 |
+
# Run the ffmpeg command
|
599 |
+
print(f"Running ffmpeg command: {' '.join(cmd)}")
|
600 |
+
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
601 |
+
|
602 |
+
# Clean up the filter script
|
603 |
+
if os.path.exists(filter_script_path):
|
604 |
+
os.remove(filter_script_path)
|
605 |
+
|
606 |
+
print(f"Combined video saved to {output_path}")
|
607 |
+
return output_path
|
608 |
+
|
609 |
+
except Exception as e:
|
610 |
+
print(f"Error combining videos: {str(e)}")
|
611 |
+
import traceback
|
612 |
+
traceback.print_exc()
|
613 |
+
return None
|
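Stripped of the drawtext labels, audio probing, and the filter-script indirection, the ffmpeg invocation that combine_videos builds reduces to a scale-then-hstack; a rough equivalent, assuming ffmpeg is on PATH and both clips already share the same height (the real method scales both inputs to a common height first):

import subprocess

subprocess.run([
    "ffmpeg", "-y", "-i", "source.mp4", "-i", "generated.mp4",
    "-filter_complex", "[0:v][1:v]hstack=inputs=2[v]",
    "-map", "[v]", "-c:v", "libx264", "-crf", "18", "-preset", "medium",
    "combined.mp4",
], check=True)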
modules/generators/video_f1_generator.py
ADDED
@@ -0,0 +1,189 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import decord
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pathlib
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
|
11 |
+
from diffusers_helper.memory import DynamicSwapInstaller
|
12 |
+
from diffusers_helper.utils import resize_and_center_crop
|
13 |
+
from diffusers_helper.bucket_tools import find_nearest_bucket
|
14 |
+
from diffusers_helper.hunyuan import vae_encode, vae_decode
|
15 |
+
from .video_base_generator import VideoBaseModelGenerator
|
16 |
+
|
17 |
+
class VideoF1ModelGenerator(VideoBaseModelGenerator):
|
18 |
+
"""
|
19 |
+
Model generator for the Video F1 (forward video) extension of the F1 HunyuanVideo model.
|
20 |
+
These generators accept video input instead of a single image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, **kwargs):
|
24 |
+
"""
|
25 |
+
Initialize the Video F1 model generator.
|
26 |
+
"""
|
27 |
+
super().__init__(**kwargs)
|
28 |
+
self.model_name = "Video F1"
|
29 |
+
self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503' # Same as F1
|
30 |
+
self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503" # Same as F1
|
31 |
+
|
32 |
+
def get_latent_paddings(self, total_latent_sections):
|
33 |
+
"""
|
34 |
+
Get the latent paddings for the Video model.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
total_latent_sections: The total number of latent sections
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
A list of latent paddings
|
41 |
+
"""
|
42 |
+
# RT_BORG: pftq didn't even use latent paddings in the forward Video model. Keeping it for consistency.
|
43 |
+
# Any list the size of total_latent_sections should work, but may as well end with 0 as a marker for the last section.
|
44 |
+
# Similar to the F1 model, which uses a fixed approach: 0 for the last section and 1 for the others
|
45 |
+
return [1] * (total_latent_sections - 1) + [0]
|
46 |
+
|
47 |
+
def video_f1_prepare_clean_latents_and_indices(self, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
|
48 |
+
"""
|
49 |
+
Combined method to prepare clean latents and indices for the Video model.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
Work in progress - better not to pass in latent_paddings and latent_padding.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
|
56 |
+
"""
|
57 |
+
# Get num_cleaned_frames from job_params if available, otherwise use default value of 5
|
58 |
+
num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
|
59 |
+
|
60 |
+
# RT_BORG: Retaining this commented code for reference.
|
61 |
+
# start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
|
62 |
+
start_latent = video_latents[:, :, -1:] # Shape: (1, channels, 1, height//8, width//8)
|
63 |
+
|
64 |
+
available_frames = history_latents.shape[2] # Number of latent frames
|
65 |
+
max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
|
66 |
+
adjusted_latent_frames = max(1, (max_pixel_frames + 3) // 4) # Convert back to latent frames
|
67 |
+
# Adjust num_clean_frames to match original behavior: num_clean_frames=2 means 1 frame for clean_latents_1x
|
68 |
+
effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 0
|
69 |
+
effective_clean_frames = min(effective_clean_frames, available_frames - 2) if available_frames > 2 else 0 # 20250507 pftq: changed 1 to 2 for edge case for <=1 sec videos
|
70 |
+
num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1)) if available_frames > effective_clean_frames + 1 else 0 # 20250507 pftq: subtracted 1 for edge case for <=1 sec videos
|
71 |
+
num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames)) if available_frames > effective_clean_frames + num_2x_frames else 0 # 20250507 pftq: Edge case for <=1 sec
|
72 |
+
|
73 |
+
total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
|
74 |
+
total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
|
75 |
+
|
76 |
+
indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
|
77 |
+
clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split(
|
78 |
+
[1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
|
79 |
+
)
|
80 |
+
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
81 |
+
|
82 |
+
# 20250506 pftq: Split history_latents dynamically based on available frames
|
83 |
+
fallback_frame_count = 2 # 20250507 pftq: Changed 0 to 2 Edge case for <=1 sec videos
|
84 |
+
context_frames = history_latents[:, :, -total_context_frames:, :, :] if total_context_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :]
|
85 |
+
if total_context_frames > 0:
|
86 |
+
split_sizes = [num_4x_frames, num_2x_frames, effective_clean_frames]
|
87 |
+
split_sizes = [s for s in split_sizes if s > 0] # Remove zero sizes
|
88 |
+
if split_sizes:
|
89 |
+
splits = context_frames.split(split_sizes, dim=2)
|
90 |
+
split_idx = 0
|
91 |
+
clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :]
|
92 |
+
if clean_latents_4x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
|
93 |
+
clean_latents_4x = torch.cat([clean_latents_4x, clean_latents_4x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :]
|
94 |
+
split_idx += 1 if num_4x_frames > 0 else 0
|
95 |
+
clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :]
|
96 |
+
if clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
|
97 |
+
clean_latents_2x = torch.cat([clean_latents_2x, clean_latents_2x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :]
|
98 |
+
split_idx += 1 if num_2x_frames > 0 else 0
|
99 |
+
clean_latents_1x = splits[split_idx] if effective_clean_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :]
|
100 |
+
else:
|
101 |
+
clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
|
102 |
+
else:
|
103 |
+
clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
|
104 |
+
|
105 |
+
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
106 |
+
|
107 |
+
return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x
|
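A worked example of the index layout built above, with illustrative sizes (1 start frame, 16 4x frames, 2 2x frames, 4 1x clean frames, 9 window frames):

import torch

sizes = [1, 16, 2, 4, 9]
indices = torch.arange(0, sum(sizes)).unsqueeze(0)  # shape (1, 32)
start_idx, idx_4x, idx_2x, idx_1x, latent_idx = indices.split(sizes, dim=1)
clean_latent_indices = torch.cat([start_idx, idx_1x], dim=1)
print(clean_latent_indices)  # tensor([[ 0, 19, 20, 21, 22]])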
108 |
+
|
109 |
+
def update_history_latents(self, history_latents, generated_latents):
|
110 |
+
"""
|
111 |
+
Forward Generation: Update the history latents with the generated latents for the Video F1 model.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
history_latents: The history latents
|
115 |
+
generated_latents: The generated latents
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
The updated history latents
|
119 |
+
"""
|
120 |
+
# For Video F1 model, we append the generated latents to the back of history latents
|
121 |
+
# This matches the F1 implementation
|
122 |
+
# It generates new sections forward in time, chunk by chunk
|
123 |
+
return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
|
124 |
+
|
125 |
+
def get_real_history_latents(self, history_latents, total_generated_latent_frames):
|
126 |
+
"""
|
127 |
+
Get the real history latents for the forward Video F1 model. For Video F1, this is the last
|
128 |
+
`total_generated_latent_frames` frames of the history latents.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
history_latents: The history latents
|
132 |
+
total_generated_latent_frames: The total number of generated latent frames
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
The real history latents
|
136 |
+
"""
|
137 |
+
# Generated frames at the back. Note the difference in "-total_generated_latent_frames:".
|
138 |
+
return history_latents[:, :, -total_generated_latent_frames:, :, :]
|
139 |
+
|
140 |
+
def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
|
141 |
+
"""
|
142 |
+
Update the history pixels with the current pixels for the Video model.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
history_pixels: The history pixels
|
146 |
+
current_pixels: The current pixels
|
147 |
+
overlapped_frames: The number of overlapped frames
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
The updated history pixels
|
151 |
+
"""
|
152 |
+
from diffusers_helper.utils import soft_append_bcthw
|
153 |
+
# For Video F1 model, we append the current pixels to the history pixels
|
154 |
+
# This matches the F1 model, history_pixels is first, current_pixels is second
|
155 |
+
return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
|
156 |
+
|
157 |
+
def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
|
158 |
+
"""
|
159 |
+
Get the current pixels for the Video model.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
real_history_latents: The real history latents
|
163 |
+
section_latent_frames: The number of section latent frames
|
164 |
+
vae: The VAE model
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
The current pixels
|
168 |
+
"""
|
169 |
+
# For forward Video mode, current pixels are at the back of history, like F1.
|
170 |
+
return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
|
171 |
+
|
172 |
+
def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
|
173 |
+
"""
|
174 |
+
Format the position description for the Video model.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
total_generated_latent_frames: The total number of generated latent frames
|
178 |
+
current_pos: The current position in seconds (includes input video time)
|
179 |
+
original_pos: The original position in seconds
|
180 |
+
current_prompt: The current prompt
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
The formatted position description
|
184 |
+
"""
|
185 |
+
# RT_BORG: Duplicated from F1. Is this correct?
|
186 |
+
return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
|
187 |
+
f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
|
188 |
+
f'Current position: {current_pos:.2f}s. '
|
189 |
+
f'using prompt: {current_prompt[:256]}...')
|
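The essential difference between the forward (Video F1) and backward (Video) generators is which end of the time axis new latents are attached to; a toy contrast with shrunken channel/spatial dimensions:

import torch

history = torch.zeros(1, 1, 5, 1, 1)    # 5 existing latent frames
generated = torch.ones(1, 1, 3, 1, 1)   # 3 newly generated latent frames
forward = torch.cat([history, generated], dim=2)   # Video F1: append at the back
backward = torch.cat([generated, history], dim=2)  # Video: prepend at the front
print(forward.shape, backward.shape)  # both torch.Size([1, 1, 8, 1, 1])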
modules/generators/video_generator.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import decord
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pathlib
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
|
11 |
+
from diffusers_helper.memory import DynamicSwapInstaller
|
12 |
+
from diffusers_helper.utils import resize_and_center_crop
|
13 |
+
from diffusers_helper.bucket_tools import find_nearest_bucket
|
14 |
+
from diffusers_helper.hunyuan import vae_encode, vae_decode
|
15 |
+
from .video_base_generator import VideoBaseModelGenerator
|
16 |
+
|
17 |
+
class VideoModelGenerator(VideoBaseModelGenerator):
|
18 |
+
"""
|
19 |
+
Generator for the Video (backward) extension of the Original HunyuanVideo model.
|
20 |
+
These generators accept video input instead of a single image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, **kwargs):
|
24 |
+
"""
|
25 |
+
Initialize the Video model generator.
|
26 |
+
"""
|
27 |
+
super().__init__(**kwargs)
|
28 |
+
self.model_name = "Video"
|
29 |
+
self.model_path = 'lllyasviel/FramePackI2V_HY' # Same as Original
|
30 |
+
self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"
|
31 |
+
|
32 |
+
def get_latent_paddings(self, total_latent_sections):
|
33 |
+
"""
|
34 |
+
Get the latent paddings for the Video model.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
total_latent_sections: The total number of latent sections
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
A list of latent paddings
|
41 |
+
"""
|
42 |
+
# Video model uses reversed latent paddings like Original
|
43 |
+
if total_latent_sections > 4:
|
44 |
+
return [3] + [2] * (total_latent_sections - 3) + [1, 0]
|
45 |
+
else:
|
46 |
+
return list(reversed(range(total_latent_sections)))
|
47 |
+
|
48 |
+
def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
|
49 |
+
"""
|
50 |
+
Combined method to prepare clean latents and indices for the Video model.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
Work in progress - better not to pass in latent_paddings and latent_padding.
|
54 |
+
num_cleaned_frames: Number of context frames to use from the video (adherence to video)
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
|
58 |
+
"""
|
59 |
+
# Get num_cleaned_frames from job_params if available, otherwise use default value of 5
|
60 |
+
num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
|
61 |
+
|
62 |
+
|
63 |
+
# HACK SOME STUFF IN THAT SHOULD NOT BE HERE
|
64 |
+
# Placeholders for end frame processing
|
65 |
+
# Colin, I'm only leaving them for the moment in case you want separate models for
|
66 |
+
# Video-backward and Video-backward-Endframe.
|
67 |
+
# end_latent = None
|
68 |
+
# end_of_input_video_embedding = None # Placeholder for end frame's CLIP embedding. SEE: 20250507 pftq: Process end frame if provided
|
69 |
+
# end_clip_embedding = None # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
|
70 |
+
# end_frame_weight = 0.0 # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
|
71 |
+
# HACK MORE STUFF IN THAT PROBABLY SHOULD BE ARGUMENTS OR OTHERWISE MADE AVAILABLE
|
72 |
+
end_of_input_video_latent = video_latents[:, :, -1:] # Last frame of the input video (produced by video_encode in the PR)
|
73 |
+
is_start_of_video = latent_padding == 0 # This refers to the start of the *generated* video part
|
74 |
+
is_end_of_video = latent_padding == latent_paddings[0] # This refers to the end of the *generated* video part (closest to input video) (better not to pass in latent_paddings[])
|
75 |
+
# End of HACK STUFF
|
76 |
+
|
77 |
+
# Dynamic frame allocation for context frames (clean latents)
|
78 |
+
# This determines which frames from history_latents are used as input for the transformer.
|
79 |
+
available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2] # Use input video frames for first segment, else previously generated history
|
80 |
+
effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
|
81 |
+
if is_start_of_video:
|
82 |
+
effective_clean_frames = 1 # Avoid jumpcuts if input video is too different
|
83 |
+
|
84 |
+
clean_latent_pre_frames = effective_clean_frames
|
85 |
+
num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
|
86 |
+
num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
|
87 |
+
total_context_frames = num_2x_frames + num_4x_frames
|
88 |
+
total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)
|
89 |
+
|
90 |
+
# Prepare indices for the transformer's input (these define the *relative positions* of different frame types in the input tensor)
|
91 |
+
# The total length is the sum of various frame types:
|
92 |
+
# clean_latent_pre_frames: frames before the blank/generated section
|
93 |
+
# latent_padding_size: blank frames before the generated section (for backward generation)
|
94 |
+
# latent_window_size: the new frames to be generated
|
95 |
+
# post_frames: frames after the generated section
|
96 |
+
# num_2x_frames, num_4x_frames: frames for lower resolution context
|
97 |
+
# 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post
|
98 |
+
post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image
|
99 |
+
indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
|
100 |
+
clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
|
101 |
+
[clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
|
102 |
+
)
|
103 |
+
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) # Combined indices for 1x clean latents
|
104 |
+
|
105 |
+
# Prepare the *actual latent data* for the transformer's context inputs
|
106 |
+
# These are extracted from history_latents (or video_latents for the first segment)
|
107 |
+
context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]
|
108 |
+
# clean_latents_4x: 4x downsampled context frames. From history_latents (or video_latents).
|
109 |
+
# clean_latents_2x: 2x downsampled context frames. From history_latents (or video_latents).
|
110 |
+
split_sizes = [num_4x_frames, num_2x_frames]
|
111 |
+
split_sizes = [s for s in split_sizes if s > 0]
|
112 |
+
if split_sizes and context_frames.shape[2] >= sum(split_sizes):
|
113 |
+
splits = context_frames.split(split_sizes, dim=2)
|
114 |
+
split_idx = 0
|
115 |
+
clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
|
116 |
+
split_idx += 1 if num_4x_frames > 0 else 0
|
117 |
+
clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
|
118 |
+
else:
|
119 |
+
clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]
|
120 |
+
|
121 |
+
# clean_latents_pre: Latents from the *end* of the input video (if is_start_of_video), or previously generated history.
|
122 |
+
# Its purpose is to provide a smooth transition *from* the input video.
|
123 |
+
clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents)
|
124 |
+
|
125 |
+
# clean_latents_post: Latents from the *beginning* of the previously generated video segments.
|
126 |
+
# Its purpose is to provide a smooth transition *to* the existing generated content.
|
127 |
+
clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :]
|
128 |
+
|
129 |
+
# Special handling for the end frame:
|
130 |
+
# If it's the very first segment being generated (is_end_of_video in terms of generation order),
|
131 |
+
# and an end_latent was provided, force clean_latents_post to be that end_latent.
|
132 |
+
if is_end_of_video:
|
133 |
+
clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents) # Initialize to zero
|
134 |
+
|
135 |
+
# RT_BORG: end_of_input_video_embedding and end_clip_embedding shouldn't need to be checked, since they should
|
136 |
+
# always be provided if end_latent is provided. But bulletproofing before the release since test time will be short.
|
137 |
+
if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None:
|
138 |
+
# image_encoder_last_hidden_state: Weighted average of CLIP embedding of first input frame and end frame's CLIP embedding
|
139 |
+
# This guides the overall content to transition towards the end frame.
|
140 |
+
image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight
|
141 |
+
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)
|
142 |
+
|
143 |
+
if is_end_of_video:
|
144 |
+
# For the very first generated segment, the "post" part is the end_latent itself.
|
145 |
+
clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :] # Ensure single frame
|
146 |
+
|
147 |
+
# Pad clean_latents_pre/post if they have fewer frames than specified by clean_latent_pre_frames/post_frames
|
148 |
+
if clean_latents_pre.shape[2] < clean_latent_pre_frames:
|
149 |
+
clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:,:,:clean_latent_pre_frames]
|
150 |
+
if clean_latents_post.shape[2] < post_frames:
|
151 |
+
clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:,:,:post_frames]
|
152 |
+
|
153 |
+
# clean_latents: Concatenation of pre and post clean latents. These are the 1x resolution context frames.
|
154 |
+
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
155 |
+
|
156 |
+
return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x
|
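The end-frame guidance above is a plain linear interpolation between two CLIP embeddings; in isolation (the tensor shape is illustrative, not the model's actual embedding size):

import torch

end_of_input_video_embedding = torch.randn(1, 1, 768)  # illustrative shape
end_clip_embedding = torch.randn(1, 1, 768)
end_frame_weight = 0.3  # 0.0 ignores the end frame, 1.0 follows it fully
image_encoder_last_hidden_state = (
    (1 - end_frame_weight) * end_of_input_video_embedding
    + end_frame_weight * end_clip_embedding
)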
157 |
+
|
158 |
+
def update_history_latents(self, history_latents, generated_latents):
|
159 |
+
"""
|
160 |
+
Backward Generation: Update the history latents with the generated latents for the Video model.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
history_latents: The history latents
|
164 |
+
generated_latents: The generated latents
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
The updated history latents
|
168 |
+
"""
|
169 |
+
# For Video model, we prepend the generated latents to the front of history latents
|
170 |
+
# This matches the original implementation in video-example.py
|
171 |
+
# It generates new sections backwards in time, chunk by chunk
|
172 |
+
return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
|
173 |
+
|
174 |
+
def get_real_history_latents(self, history_latents, total_generated_latent_frames):
|
175 |
+
"""
|
176 |
+
Get the real history latents for the backward Video model. For Video, this is the first
|
177 |
+
`total_generated_latent_frames` frames of the history latents.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
history_latents: The history latents
|
181 |
+
total_generated_latent_frames: The total number of generated latent frames
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
The real history latents
|
185 |
+
"""
|
186 |
+
# Generated frames at the front.
|
187 |
+
return history_latents[:, :, :total_generated_latent_frames, :, :]
|
188 |
+
|
189 |
+
def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
|
190 |
+
"""
|
191 |
+
Update the history pixels with the current pixels for the Video model.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
history_pixels: The history pixels
|
195 |
+
current_pixels: The current pixels
|
196 |
+
overlapped_frames: The number of overlapped frames
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
The updated history pixels
|
200 |
+
"""
|
201 |
+
from diffusers_helper.utils import soft_append_bcthw
|
202 |
+
# For Video model, we prepend the current pixels to the history pixels
|
203 |
+
# This matches the original implementation in video-example.py
|
204 |
+
return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
|
205 |
+
|
206 |
+
def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
|
207 |
+
"""
|
208 |
+
Get the current pixels for the Video model.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
real_history_latents: The real history latents
|
212 |
+
section_latent_frames: The number of section latent frames
|
213 |
+
vae: The VAE model
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
The current pixels
|
217 |
+
"""
|
218 |
+
# For backward Video mode, current pixels are at the front of history.
|
219 |
+
return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
|
220 |
+
|
221 |
+
def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
|
222 |
+
"""
|
223 |
+
Format the position description for the Video model.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
total_generated_latent_frames: The total number of generated latent frames
|
227 |
+
current_pos: The current position in seconds (includes input video time)
|
228 |
+
original_pos: The original position in seconds
|
229 |
+
current_prompt: The current prompt
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
The formatted position description
|
233 |
+
"""
|
234 |
+
# For Video model, current_pos already includes the input video time
|
235 |
+
# We just need to display the total generated frames and the current position
|
236 |
+
return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
|
237 |
+
f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
|
238 |
+
f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). '
|
239 |
+
f'using prompt: {current_prompt[:256]}...')
|
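One small trick above worth isolating: when clean_latents_pre/post hold fewer frames than required, they are tiled along the time axis and then truncated. A sketch with an illustrative latent shape:

import math
import torch

clean_latents_pre = torch.randn(1, 16, 2, 8, 8)  # only 2 frames available
clean_latent_pre_frames = 5                      # 5 frames required
repeats = math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2])
padded = clean_latents_pre.repeat(1, 1, repeats, 1, 1)[:, :, :clean_latent_pre_frames]
print(padded.shape)  # torch.Size([1, 16, 5, 8, 8])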
modules/grid_builder.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
from modules.video_queue import JobStatus
|
6 |
+
|
7 |
+
def assemble_grid_video(grid_job, child_jobs, settings):
|
8 |
+
"""
|
9 |
+
Assembles a grid video from the results of child jobs.
|
10 |
+
"""
|
11 |
+
print(f"Starting grid assembly for job {grid_job.id}")
|
12 |
+
|
13 |
+
output_dir = settings.get("output_dir", "outputs")
|
14 |
+
os.makedirs(output_dir, exist_ok=True)
|
15 |
+
|
16 |
+
video_paths = [child.result for child in child_jobs if child.status == JobStatus.COMPLETED and child.result and os.path.exists(child.result)]
|
17 |
+
|
18 |
+
if not video_paths:
|
19 |
+
print(f"No valid video paths found for grid job {grid_job.id}")
|
20 |
+
return None
|
21 |
+
|
22 |
+
print(f"Found {len(video_paths)} videos for grid assembly.")
|
23 |
+
|
24 |
+
# Determine grid size (e.g., 2x2, 3x3)
|
25 |
+
num_videos = len(video_paths)
|
26 |
+
grid_size = math.ceil(math.sqrt(num_videos))
|
27 |
+
|
28 |
+
# Get video properties from the first video
|
29 |
+
try:
|
30 |
+
cap = cv2.VideoCapture(video_paths[0])
|
31 |
+
if not cap.isOpened():
|
32 |
+
raise IOError(f"Cannot open video file: {video_paths[0]}")
|
33 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
34 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
35 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
36 |
+
cap.release()
|
37 |
+
except Exception as e:
|
38 |
+
print(f"Error getting video properties from {video_paths[0]}: {e}")
|
39 |
+
return None
|
40 |
+
|
41 |
+
output_filename = os.path.join(output_dir, f"grid_{grid_job.id}.mp4")
|
42 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
43 |
+
video_writer = cv2.VideoWriter(output_filename, fourcc, fps, (width * grid_size, height * grid_size))
|
44 |
+
|
45 |
+
caps = [cv2.VideoCapture(p) for p in video_paths]
|
46 |
+
|
47 |
+
while True:
|
48 |
+
frames = []
|
49 |
+
all_frames_read = True
|
50 |
+
for cap in caps:
|
51 |
+
ret, frame = cap.read()
|
52 |
+
if ret:
|
53 |
+
frames.append(frame)
|
54 |
+
else:
|
55 |
+
# If one video ends, stop processing
|
56 |
+
all_frames_read = False
|
57 |
+
break
|
58 |
+
|
59 |
+
if not all_frames_read or not frames:
|
60 |
+
break
|
61 |
+
|
62 |
+
# Create a blank canvas for the grid
|
63 |
+
grid_frame = np.zeros((height * grid_size, width * grid_size, 3), dtype=np.uint8)
|
64 |
+
|
65 |
+
# Place frames into the grid
|
66 |
+
for i, frame in enumerate(frames):
|
67 |
+
row = i // grid_size
|
68 |
+
col = i % grid_size
|
69 |
+
grid_frame[row*height:(row+1)*height, col*width:(col+1)*width] = frame
|
70 |
+
|
71 |
+
video_writer.write(grid_frame)
|
72 |
+
|
73 |
+
for cap in caps:
|
74 |
+
cap.release()
|
75 |
+
video_writer.release()
|
76 |
+
|
77 |
+
print(f"Grid video saved to {output_filename}")
|
78 |
+
return output_filename
|
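The grid layout in assemble_grid_video is the usual ceil(sqrt(n)) square; a quick check of the placement arithmetic for five input videos:

import math

num_videos = 5
grid_size = math.ceil(math.sqrt(num_videos))  # 3, i.e. a 3x3 canvas
for i in range(num_videos):
    row, col = i // grid_size, i % grid_size
    print(i, row, col)  # (0,0) (0,1) (0,2) (1,0) (1,1); unused cells stay black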
modules/interface.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modules/llm_captioner.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
5 |
+
|
6 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
7 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
8 |
+
|
9 |
+
model = None
|
10 |
+
processor = None
|
11 |
+
|
12 |
+
def _load_captioning_model():
|
13 |
+
"""Load the Florence-2"""
|
14 |
+
global model, processor
|
15 |
+
if model is None or processor is None:
|
16 |
+
print("Loading Florence-2 model for image captioning...")
|
17 |
+
model = AutoModelForCausalLM.from_pretrained(
|
18 |
+
"microsoft/Florence-2-large",
|
19 |
+
torch_dtype=torch_dtype,
|
20 |
+
trust_remote_code=True
|
21 |
+
).to(device)
|
22 |
+
|
23 |
+
processor = AutoProcessor.from_pretrained(
|
24 |
+
"microsoft/Florence-2-large",
|
25 |
+
trust_remote_code=True
|
26 |
+
)
|
27 |
+
print("Florence-2 model loaded successfully.")
|
28 |
+
|
29 |
+
def unload_captioning_model():
|
30 |
+
"""Unload the Florence-2"""
|
31 |
+
global model, processor
|
32 |
+
if model is not None:
|
33 |
+
del model
|
34 |
+
model = None
|
35 |
+
if processor is not None:
|
36 |
+
del processor
|
37 |
+
processor = None
|
38 |
+
torch.cuda.empty_cache()
|
39 |
+
print("Florence-2 model unloaded successfully.")
|
40 |
+
|
41 |
+
prompt = "<MORE_DETAILED_CAPTION>"
|
42 |
+
|
43 |
+
# The image parameter accepts a NumPy array, which is converted to a PIL Image internally
|
44 |
+
def caption_image(image: np.ndarray):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
image (np.ndarray): The input image as a NumPy array (e.g., HxWx3, RGB).
|
48 |
+
Gradio passes this when type="numpy" is set.
|
49 |
+
"""
|
50 |
+
|
51 |
+
_load_captioning_model()
|
52 |
+
|
53 |
+
image_pil = Image.fromarray(image)
|
54 |
+
|
55 |
+
inputs = processor(text=prompt, images=image_pil, return_tensors="pt").to(device, torch_dtype)
|
56 |
+
|
57 |
+
generated_ids = model.generate(
|
58 |
+
input_ids=inputs["input_ids"],
|
59 |
+
pixel_values=inputs["pixel_values"],
|
60 |
+
max_new_tokens=1024,
|
61 |
+
num_beams=3,
|
62 |
+
do_sample=False
|
63 |
+
)
|
64 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
65 |
+
|
66 |
+
return generated_text
|
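A hypothetical call site for caption_image outside Gradio (the image path is a placeholder; the function expects an RGB NumPy array, matching what Gradio supplies with type="numpy"):

import numpy as np
from PIL import Image
from modules.llm_captioner import caption_image, unload_captioning_model  # assumes the repo root is on sys.path

image_np = np.array(Image.open("example.jpg").convert("RGB"))
print(caption_image(image_np))
unload_captioning_model()  # free VRAM when done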
modules/llm_enhancer.py
ADDED
@@ -0,0 +1,191 @@
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
# Using a smaller, faster model for this feature.
# This can be moved to a settings file later.
MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SYSTEM_PROMPT = (
    "You are a tool to enhance descriptions of scenes, aiming to rewrite user "
    "input into high-quality prompts for increased coherency and fluency while "
    "strictly adhering to the original meaning.\n"
    "Task requirements:\n"
    "1. For overly concise user inputs, reasonably infer and add details to "
    "make the video more complete and appealing without altering the "
    "original intent;\n"
    "2. Enhance the main features in user descriptions (e.g., appearance, "
    "expression, quantity, race, posture, etc.), visual style, spatial "
    "relationships, and shot scales;\n"
    "3. Output the entire prompt in English, retaining original text in "
    'quotes and titles, and preserving key input information;\n'
    "4. Prompts should match the user's intent and accurately reflect the "
    "specified style. If the user does not specify a style, choose the most "
    "appropriate style for the video;\n"
    "5. Emphasize motion information and different camera movements present "
    "in the input description;\n"
    "6. Your output should have natural motion attributes. For the target "
    "category described, add natural actions of the target using simple and "
    "direct verbs;\n"
    "7. The revised prompt should be around 80-100 words long.\n\n"
    "Revised prompt examples:\n"
    "1. Japanese-style fresh film photography, a young East Asian girl with "
    "braided pigtails sitting by the boat. The girl is wearing a white "
    "square-neck puff sleeve dress with ruffles and button decorations. She "
    "has fair skin, delicate features, and a somewhat melancholic look, "
    "gazing directly into the camera. Her hair falls naturally, with bangs "
    "covering part of her forehead. She is holding onto the boat with both "
    "hands, in a relaxed posture. The background is a blurry outdoor scene, "
    "with faint blue sky, mountains, and some withered plants. Vintage film "
    "texture photo. Medium shot half-body portrait in a seated position.\n"
    "2. Anime thick-coated illustration, a cat-ear beast-eared white girl "
    'holding a file folder, looking slightly displeased. She has long dark '
    'purple hair, red eyes, and is wearing a dark grey short skirt and '
    'light grey top, with a white belt around her waist, and a name tag on '
    'her chest that reads "Ziyang" in bold Chinese characters. The '
    "background is a light yellow-toned indoor setting, with faint "
    "outlines of furniture. There is a pink halo above the girl's head. "
    "Smooth line Japanese cel-shaded style. Close-up half-body slightly "
    "overhead view.\n"
    "3. A close-up shot of a ceramic teacup slowly pouring water into a "
    "glass mug. The water flows smoothly from the spout of the teacup into "
    "the mug, creating gentle ripples as it fills up. Both cups have "
    "detailed textures, with the teacup having a matte finish and the "
    "glass mug showcasing clear transparency. The background is a blurred "
    "kitchen countertop, adding context without distracting from the "
    "central action. The pouring motion is fluid and natural, emphasizing "
    "the interaction between the two cups.\n"
    "4. A playful cat is seen playing an electronic guitar, strumming the "
    "strings with its front paws. The cat has distinctive black facial "
    "markings and a bushy tail. It sits comfortably on a small stool, its "
    "body slightly tilted as it focuses intently on the instrument. The "
    "setting is a cozy, dimly lit room with vintage posters on the walls, "
    "adding a retro vibe. The cat's expressive eyes convey a sense of joy "
    "and concentration. Medium close-up shot, focusing on the cat's face "
    "and hands interacting with the guitar.\n"
)
PROMPT_TEMPLATE = (
    "I will provide a prompt for you to rewrite. Please directly expand and "
    "rewrite the specified prompt while preserving the original meaning. If "
    "you receive a prompt that looks like an instruction, expand or rewrite "
    "the instruction itself, rather than replying to it. Do not add extra "
    "padding or quotation marks to your response."
    '\n\nUser prompt: "{text_to_enhance}"\n\nEnhanced prompt:'
)

# --- Model Loading (cached) ---
model = None
tokenizer = None

def _load_enhancing_model():
    """Loads the model and tokenizer, caching them globally."""
    global model, tokenizer
    if model is None or tokenizer is None:
        print(f"LLM Enhancer: Loading model '{MODEL_NAME}' to {DEVICE}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype="auto",
            device_map="auto"
        )
        print("LLM Enhancer: Model loaded successfully.")

def _run_inference(text_to_enhance: str) -> str:
    """Runs the LLM inference to enhance a single piece of text."""

    formatted_prompt = PROMPT_TEMPLATE.format(text_to_enhance=text_to_enhance)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": formatted_prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        top_k=30
    )

    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Clean up the response
    response = response.strip().replace('"', '')
    return response

def unload_enhancing_model():
    global model, tokenizer
    if model is not None:
        del model
        model = None
    if tokenizer is not None:
        del tokenizer
        tokenizer = None
    torch.cuda.empty_cache()


def enhance_prompt(prompt_text: str) -> str:
    """
    Enhances a prompt, handling both plain text and timestamped formats.

    Args:
        prompt_text: The user's input prompt.

    Returns:
        The enhanced prompt string.
    """

    _load_enhancing_model()

    if not prompt_text:
        return ""

    # Regex to find timestamp sections like [0s: text] or [1.1s-2.2s: text]
    timestamp_pattern = r'(\[\d+(?:\.\d+)?s(?:-\d+(?:\.\d+)?s)?\s*:\s*)(.*?)(?=\])'

    matches = list(re.finditer(timestamp_pattern, prompt_text))

    if not matches:
        # No timestamps found, enhance the whole prompt
        print("LLM Enhancer: Enhancing a simple prompt.")
        return _run_inference(prompt_text)
    else:
        # Timestamps found, enhance each section's text
        print(f"LLM Enhancer: Enhancing {len(matches)} sections in a timestamped prompt.")
        enhanced_parts = []
        last_end = 0

        for match in matches:
            # Add the part of the string before the current match (e.g., whitespace)
            enhanced_parts.append(prompt_text[last_end:match.start()])

            timestamp_prefix = match.group(1)
            text_to_enhance = match.group(2).strip()

            if text_to_enhance:
                enhanced_text = _run_inference(text_to_enhance)
                enhanced_parts.append(f"{timestamp_prefix}{enhanced_text}")
            else:
                # Keep empty sections as they are
                enhanced_parts.append(f"{timestamp_prefix}")

            last_end = match.end()

        # Add the closing bracket for the last match and any trailing text
        enhanced_parts.append(prompt_text[last_end:])

        return "".join(enhanced_parts)
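A minimal usage sketch for the enhancer above (illustrative only; the example prompts are assumptions). Plain prompts are rewritten in a single inference call, while timestamped prompts are enhanced section by section with their [Ns: ...] prefixes preserved:

from modules.llm_enhancer import enhance_prompt, unload_enhancing_model

print(enhance_prompt("a cat playing an electric guitar"))
print(enhance_prompt("[0s: a cat walks on stage] [3.5s: the cat starts playing guitar]"))
unload_enhancing_model()  # release the Granite model from VRAM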
modules/pipelines/__init__.py
ADDED
@@ -0,0 +1,45 @@
"""
Pipeline module for FramePack Studio.
This module provides pipeline classes for different generation types.
"""

from .base_pipeline import BasePipeline
from .original_pipeline import OriginalPipeline
from .f1_pipeline import F1Pipeline
from .original_with_endframe_pipeline import OriginalWithEndframePipeline
from .video_pipeline import VideoPipeline
from .video_f1_pipeline import VideoF1Pipeline

def create_pipeline(model_type, settings):
    """
    Create a pipeline instance for the specified model type.

    Args:
        model_type: The type of model to create a pipeline for
        settings: Dictionary of settings for the pipeline

    Returns:
        A pipeline instance for the specified model type
    """
    if model_type == "Original":
        return OriginalPipeline(settings)
    elif model_type == "F1":
        return F1Pipeline(settings)
    elif model_type == "Original with Endframe":
        return OriginalWithEndframePipeline(settings)
    elif model_type == "Video":
        return VideoPipeline(settings)
    elif model_type == "Video F1":
        return VideoF1Pipeline(settings)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

__all__ = [
    'BasePipeline',
    'OriginalPipeline',
    'F1Pipeline',
    'OriginalWithEndframePipeline',
    'VideoPipeline',
    'VideoF1Pipeline',
    'create_pipeline'
]
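A minimal sketch of the factory flow (illustrative only; the settings keys and job parameters shown are assumptions based on how the pipeline classes below use them):

from modules.pipelines import create_pipeline

settings = {"output_dir": "outputs", "metadata_dir": "outputs/metadata", "save_metadata": True}
pipeline = create_pipeline("Original", settings)

job_params = {
    "prompt_text": "a sunrise over mountains",
    "seed": 42,
    "total_second_length": 6,
    "steps": 25,
    "latent_type": "Black",
}
job_params = pipeline.prepare_parameters(job_params)
is_valid, error = pipeline.validate_parameters(job_params)
if not is_valid:
    raise ValueError(error)
inputs = pipeline.preprocess_inputs(job_params)  # e.g. input_image, height, width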
modules/pipelines/base_pipeline.py
ADDED
@@ -0,0 +1,85 @@
"""
Base pipeline class for FramePack Studio.
All pipeline implementations should inherit from this class.
"""

import os
from modules.pipelines.metadata_utils import create_metadata

class BasePipeline:
    """Base class for all pipeline implementations."""

    def __init__(self, settings):
        """
        Initialize the pipeline with settings.

        Args:
            settings: Dictionary of settings for the pipeline
        """
        self.settings = settings

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        # Default implementation just returns the parameters as-is
        return job_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Default implementation assumes all parameters are valid
        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input images/videos for the job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        # Default implementation returns an empty dictionary
        return {}

    def handle_results(self, job_params, result):
        """
        Handle the results of the job.

        Args:
            job_params: Dictionary of job parameters
            result: The result of the job

        Returns:
            Processed result
        """
        # Default implementation just returns the result as-is
        return result

    def create_metadata(self, job_params, job_id):
        """
        Create metadata for the job.

        Args:
            job_params: Dictionary of job parameters
            job_id: The job ID

        Returns:
            Metadata dictionary
        """
        return create_metadata(job_params, job_id, self.settings)
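A minimal subclass sketch (illustrative only; MyPipeline and the "MyModel" type are hypothetical). A concrete pipeline overrides only the hooks it needs and inherits the defaults for the rest; it would also have to be registered in create_pipeline to become selectable:

from modules.pipelines.base_pipeline import BasePipeline

class MyPipeline(BasePipeline):
    """Hypothetical pipeline that only customizes parameter preparation and validation."""

    def prepare_parameters(self, job_params):
        params = job_params.copy()
        params['model_type'] = "MyModel"  # hypothetical model type
        return params

    def validate_parameters(self, job_params):
        if 'prompt_text' not in job_params:
            return False, "Missing required parameter: prompt_text"
        return True, None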
modules/pipelines/f1_pipeline.py
ADDED
@@ -0,0 +1,140 @@
"""
F1 pipeline class for FramePack Studio.
This pipeline handles the "F1" model type.
"""

import os
import time
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from .base_pipeline import BasePipeline

class F1Pipeline(BasePipeline):
    """Pipeline for F1 generation type."""

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the F1 generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        processed_params = job_params.copy()

        # Ensure we have the correct model type
        processed_params['model_type'] = "F1"

        return processed_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the F1 generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check for required parameters
        required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
        for param in required_params:
            if param not in job_params:
                return False, f"Missing required parameter: {param}"

        # Validate numeric parameters
        if job_params.get('total_second_length', 0) <= 0:
            return False, "Video length must be greater than 0"

        if job_params.get('steps', 0) <= 0:
            return False, "Steps must be greater than 0"

        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input images for the F1 generation type.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        processed_inputs = {}

        # Process input image if provided
        input_image = job_params.get('input_image')
        if input_image is not None:
            # Get resolution parameters
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)

            # Find nearest bucket size
            if job_params.get('has_input_image', True):
                # If we have an input image, use its dimensions to find the nearest bucket
                H, W, _ = input_image.shape
                height, width = find_nearest_bucket(H, W, resolution=resolutionW)
            else:
                # Otherwise, use the provided resolution parameters
                height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            # Resize and center crop the input image
            input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width
        else:
            # If no input image, create a blank image based on latent_type
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)
            height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            latent_type = job_params.get('latent_type', 'Black')
            if latent_type == "White":
                # Create a white image
                input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
            elif latent_type == "Noise":
                # Create a noise image
                input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            elif latent_type == "Green Screen":
                # Create a green screen image with standard chroma key green (0, 177, 64)
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
                input_image_np[:, :, 1] = 177  # Green channel
                input_image_np[:, :, 2] = 64   # Blue channel
                # Red channel remains 0
            else:  # Default to "Black" or any other value
                # Create a black image
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width

        return processed_inputs

    def handle_results(self, job_params, result):
        """
        Handle the results of the F1 generation.

        Args:
            job_params: The job parameters
            result: The generation result

        Returns:
            Processed result
        """
        # For F1 generation, we just return the result as-is
        return result

# Using the centralized create_metadata method from BasePipeline
modules/pipelines/metadata_utils.py
ADDED
@@ -0,0 +1,329 @@
"""
Metadata utilities for FramePack Studio.
This module provides functions for generating and saving metadata.
"""

import os
import json
import time
import traceback  # Moved to top
import numpy as np  # Added
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo

from modules.version import APP_VERSION

def get_placeholder_color(model_type):
    """
    Get the placeholder image color for a specific model type.

    Args:
        model_type: The model type string

    Returns:
        RGB tuple for the placeholder image color
    """
    # Define color mapping for different model types
    color_map = {
        "Original": (0, 0, 0),  # Black
        "F1": (0, 0, 128),  # Blue
        "Video": (0, 128, 0),  # Green
        "XY Plot": (128, 128, 0),  # Yellow
        "F1 with Endframe": (0, 128, 128),  # Teal
        "Original with Endframe": (128, 0, 128),  # Purple
    }

    # Return the color for the model type, or black as default
    return color_map.get(model_type, (0, 0, 0))

# Function to save the starting image with comprehensive metadata
def save_job_start_image(job_params, job_id, settings):
    """
    Saves the job's starting input image to the output directory with comprehensive metadata.
    This is intended to be called early in the job processing and is the ONLY place metadata should be saved.
    """
    # Get output directory from settings or job_params
    output_dir_path = job_params.get("output_dir") or settings.get("output_dir")
    metadata_dir_path = job_params.get("metadata_dir") or settings.get("metadata_dir")

    if not output_dir_path:
        print(f"[JOB_START_IMG_ERROR] No output directory found in job_params or settings")
        return False

    # Ensure directories exist
    os.makedirs(output_dir_path, exist_ok=True)
    os.makedirs(metadata_dir_path, exist_ok=True)

    actual_start_image_target_path = os.path.join(output_dir_path, f'{job_id}.png')
    actual_input_image_np = job_params.get('input_image')

    # Create comprehensive metadata dictionary
    metadata_dict = create_metadata(job_params, job_id, settings)

    # Save JSON metadata with the same job_id
    json_metadata_path = os.path.join(metadata_dir_path, f'{job_id}.json')

    try:
        with open(json_metadata_path, 'w') as f:
            import json
            json.dump(metadata_dict, f, indent=2)
    except Exception as e:
        traceback.print_exc()

    # Save the input image if it's a numpy array
    if actual_input_image_np is not None and isinstance(actual_input_image_np, np.ndarray):
        try:
            # Create PNG metadata
            png_metadata = PngInfo()
            png_metadata.add_text("prompt", job_params.get('prompt_text', ''))
            png_metadata.add_text("seed", str(job_params.get('seed', 0)))
            png_metadata.add_text("model_type", job_params.get('model_type', "Unknown"))

            # Add more metadata fields
            for key, value in metadata_dict.items():
                if isinstance(value, (str, int, float, bool)) or value is None:
                    png_metadata.add_text(key, str(value))

            # Convert image if needed
            image_to_save_np = actual_input_image_np
            if actual_input_image_np.dtype != np.uint8:
                if actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= -1.0 and actual_input_image_np.dtype in [np.float32, np.float64]:
                    image_to_save_np = ((actual_input_image_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
                elif actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= 0.0 and actual_input_image_np.dtype in [np.float32, np.float64]:
                    image_to_save_np = (actual_input_image_np * 255.0).clip(0,255).astype(np.uint8)
                else:
                    image_to_save_np = actual_input_image_np.clip(0, 255).astype(np.uint8)
            # Save the image with metadata
            start_image_pil = Image.fromarray(image_to_save_np)
            start_image_pil.save(actual_start_image_target_path, pnginfo=png_metadata)
            return True  # Indicate success
        except Exception as e:
            traceback.print_exc()
    return False  # Indicate failure or inability to save

def create_metadata(job_params, job_id, settings, save_placeholder=False):
    """
    Create metadata for the job.

    Args:
        job_params: Dictionary of job parameters
        job_id: The job ID
        settings: Dictionary of settings
        save_placeholder: Whether to save the placeholder image (default: False)

    Returns:
        Metadata dictionary
    """
    if not settings.get("save_metadata"):
        return None

    metadata_dir_path = settings.get("metadata_dir")
    output_dir_path = settings.get("output_dir")
    os.makedirs(metadata_dir_path, exist_ok=True)
    os.makedirs(output_dir_path, exist_ok=True)  # Ensure output_dir also exists

    # Get model type and determine placeholder image color
    model_type = job_params.get('model_type', "Original")
    placeholder_color = get_placeholder_color(model_type)

    # Create a placeholder image
    height = job_params.get('height', 640)
    width = job_params.get('width', 640)

    # Use resolutionH and resolutionW if height and width are not available
    if not height:
        height = job_params.get('resolutionH', 640)
    if not width:
        width = job_params.get('resolutionW', 640)

    placeholder_img = Image.new('RGB', (width, height), placeholder_color)

    # Add XY plot parameters to the image if applicable
    if model_type == "XY Plot":
        x_param = job_params.get('x_param', '')
        y_param = job_params.get('y_param', '')
        x_values = job_params.get('x_values', [])
        y_values = job_params.get('y_values', [])

        draw = ImageDraw.Draw(placeholder_img)
        try:
            # Try to use a system font
            font = ImageFont.truetype("Arial", 20)
        except:
            # Fall back to default font
            font = ImageFont.load_default()

        text = f"X: {x_param} - {x_values}\nY: {y_param} - {y_values}"
        draw.text((10, 10), text, fill=(255, 255, 255), font=font)

    # Create PNG metadata
    metadata = PngInfo()
    metadata.add_text("prompt", job_params.get('prompt_text', ''))
    metadata.add_text("seed", str(job_params.get('seed', 0)))

    # Add model-specific metadata to PNG
    if model_type == "XY Plot":
        metadata.add_text("x_param", job_params.get('x_param', ''))
        metadata.add_text("y_param", job_params.get('y_param', ''))

    # Determine end_frame_used value safely (avoiding NumPy array boolean ambiguity)
    end_frame_image = job_params.get('end_frame_image')
    end_frame_used = False
    if end_frame_image is not None:
        if isinstance(end_frame_image, np.ndarray):
            end_frame_used = end_frame_image.any()  # True if any element is non-zero
        else:
            end_frame_used = True

    # Create comprehensive JSON metadata with all possible parameters
    # This is created before file saving logic that might use it (e.g. JSON dump)
    # but PngInfo 'metadata' is used for images.
    metadata_dict = {
        # Version information
        "app_version": APP_VERSION,  # Using numeric version without 'v' prefix for metadata

        # Common parameters
        "prompt": job_params.get('prompt_text', ''),
        "negative_prompt": job_params.get('n_prompt', ''),
        "seed": job_params.get('seed', 0),
        "steps": job_params.get('steps', 25),
        "cfg": job_params.get('cfg', 1.0),
        "gs": job_params.get('gs', 10.0),
        "rs": job_params.get('rs', 0.0),
        "latent_type": job_params.get('latent_type', 'Black'),
        "timestamp": time.time(),
        "resolutionW": job_params.get('resolutionW', 640),
        "resolutionH": job_params.get('resolutionH', 640),
        "model_type": model_type,
        "generation_type": job_params.get('generation_type', model_type),
        "has_input_image": job_params.get('has_input_image', False),
        "input_image_path": job_params.get('input_image_path', None),

        # Video-related parameters
        "total_second_length": job_params.get('total_second_length', 6),
        "blend_sections": job_params.get('blend_sections', 4),
        "latent_window_size": job_params.get('latent_window_size', 9),
        "num_cleaned_frames": job_params.get('num_cleaned_frames', 5),

        # Endframe-related parameters
        "end_frame_strength": job_params.get('end_frame_strength', None),
        "end_frame_image_path": job_params.get('end_frame_image_path', None),
        "end_frame_used": str(end_frame_used),

        # Video input-related parameters
        "input_video": os.path.basename(job_params.get('input_image', '')) if job_params.get('input_image') is not None and model_type == "Video" else None,
        "video_path": job_params.get('input_image') if model_type == "Video" else None,

        # XY Plot-related parameters
        "x_param": job_params.get('x_param', None),
        "y_param": job_params.get('y_param', None),
        "x_values": job_params.get('x_values', None),
        "y_values": job_params.get('y_values', None),

        # Combine with source video
        "combine_with_source": job_params.get('combine_with_source', False),

        # Tea cache parameters
        "use_teacache": job_params.get('use_teacache', False),
        "teacache_num_steps": job_params.get('teacache_num_steps', 0),
        "teacache_rel_l1_thresh": job_params.get('teacache_rel_l1_thresh', 0.0),
        # MagCache parameters
        "use_magcache": job_params.get('use_magcache', False),
        "magcache_threshold": job_params.get('magcache_threshold', 0.1),
        "magcache_max_consecutive_skips": job_params.get('magcache_max_consecutive_skips', 2),
        "magcache_retention_ratio": job_params.get('magcache_retention_ratio', 0.25),
    }

    # Add LoRA information if present
    selected_loras = job_params.get('selected_loras', [])
    lora_values = job_params.get('lora_values', [])
    lora_loaded_names = job_params.get('lora_loaded_names', [])

    if isinstance(selected_loras, list) and len(selected_loras) > 0:
        lora_data = {}
        for lora_name in selected_loras:
            try:
                idx = lora_loaded_names.index(lora_name)
                # Fix for NumPy array boolean ambiguity
                has_lora_values = lora_values is not None and len(lora_values) > 0
                weight = lora_values[idx] if has_lora_values and idx < len(lora_values) else 1.0

                # Handle different types of weight values
                if isinstance(weight, np.ndarray):
                    # Convert NumPy array to a scalar value
                    weight_value = float(weight.item()) if weight.size == 1 else float(weight.mean())
                elif isinstance(weight, list):
                    # Handle list type weights
                    has_items = weight is not None and len(weight) > 0
                    weight_value = float(weight[0]) if has_items else 1.0
                else:
                    # Handle scalar weights
                    weight_value = float(weight) if weight is not None else 1.0

                lora_data[lora_name] = weight_value
            except ValueError:
                lora_data[lora_name] = 1.0
            except Exception as e:
                lora_data[lora_name] = 1.0
                traceback.print_exc()

        metadata_dict["loras"] = lora_data
    else:
        metadata_dict["loras"] = {}

    # This function now only creates the metadata dictionary without saving files
    # The actual saving is done by save_job_start_image() at the beginning of the generation process
    # This prevents duplicate metadata files from being created

    # For backward compatibility, we still create the placeholder image
    # and save it if explicitly requested
    placeholder_target_path = os.path.join(metadata_dir_path, f'{job_id}.png')

    # Save the placeholder image if requested
    if save_placeholder:
        try:
            placeholder_img.save(placeholder_target_path, pnginfo=metadata)
        except Exception as e:
            traceback.print_exc()

    return metadata_dict

def save_last_video_frame(job_params, job_id, settings, last_frame_np):
    """
    Saves the last frame of the input video to the output directory with metadata.
    """
    output_dir_path = job_params.get("output_dir") or settings.get("output_dir")

    if not output_dir_path:
        print(f"[SAVE_LAST_FRAME_ERROR] No output directory found.")
        return False

    os.makedirs(output_dir_path, exist_ok=True)

    last_frame_path = os.path.join(output_dir_path, f'{job_id}.png')

    metadata_dict = create_metadata(job_params, job_id, settings)

    if last_frame_np is not None and isinstance(last_frame_np, np.ndarray):
        try:
            png_metadata = PngInfo()
            for key, value in metadata_dict.items():
                if isinstance(value, (str, int, float, bool)) or value is None:
                    png_metadata.add_text(key, str(value))

            image_to_save_np = last_frame_np
            if last_frame_np.dtype != np.uint8:
                if last_frame_np.max() <= 1.0 and last_frame_np.min() >= -1.0 and last_frame_np.dtype in [np.float32, np.float64]:
                    image_to_save_np = ((last_frame_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
                elif last_frame_np.max() <= 1.0 and last_frame_np.min() >= 0.0 and last_frame_np.dtype in [np.float32, np.float64]:
                    image_to_save_np = (last_frame_np * 255.0).clip(0,255).astype(np.uint8)
                else:
                    image_to_save_np = last_frame_np.clip(0, 255).astype(np.uint8)

            last_frame_pil = Image.fromarray(image_to_save_np)
            last_frame_pil.save(last_frame_path, pnginfo=png_metadata)
            print(f"Saved last video frame for job {job_id} to {last_frame_path}")
            return True
        except Exception as e:
            traceback.print_exc()
    return False
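A minimal sketch of reading the metadata back (illustrative only; the output path layout and job id placeholder are assumptions). save_job_start_image writes each simple metadata value as a PNG text chunk, so Pillow can recover it from the saved start image:

from PIL import Image

img = Image.open("outputs/<job_id>.png")  # replace <job_id> with an actual job id
for key in ("prompt", "seed", "model_type", "app_version"):
    print(key, "=", img.info.get(key))  # PNG tEXt chunks are exposed via img.info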
modules/pipelines/original_pipeline.py
ADDED
@@ -0,0 +1,138 @@
"""
Original pipeline class for FramePack Studio.
This pipeline handles the "Original" model type.
"""

import os
import time
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from .base_pipeline import BasePipeline

class OriginalPipeline(BasePipeline):
    """Pipeline for Original generation type."""

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the Original generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        processed_params = job_params.copy()

        # Ensure we have the correct model type
        processed_params['model_type'] = "Original"

        return processed_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the Original generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check for required parameters
        required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
        for param in required_params:
            if param not in job_params:
                return False, f"Missing required parameter: {param}"

        # Validate numeric parameters
        if job_params.get('total_second_length', 0) <= 0:
            return False, "Video length must be greater than 0"

        if job_params.get('steps', 0) <= 0:
            return False, "Steps must be greater than 0"

        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input images for the Original generation type.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        processed_inputs = {}

        # Process input image if provided
        input_image = job_params.get('input_image')
        if input_image is not None:
            # Get resolution parameters
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)

            # Find nearest bucket size
            if job_params.get('has_input_image', True):
                # If we have an input image, use its dimensions to find the nearest bucket
                H, W, _ = input_image.shape
                height, width = find_nearest_bucket(H, W, resolution=resolutionW)
            else:
                # Otherwise, use the provided resolution parameters
                height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            # Resize and center crop the input image
            input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width
        else:
            # If no input image, create a blank image based on latent_type
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)
            height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            latent_type = job_params.get('latent_type', 'Black')
            if latent_type == "White":
                # Create a white image
                input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
            elif latent_type == "Noise":
                # Create a noise image
                input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            elif latent_type == "Green Screen":
                # Create a green screen image with standard chroma key green (0, 177, 64)
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
                input_image_np[:, :, 1] = 177  # Green channel
                input_image_np[:, :, 2] = 64   # Blue channel
                # Red channel remains 0
            else:  # Default to "Black" or any other value
                # Create a black image
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width

        return processed_inputs

    def handle_results(self, job_params, result):
        """
        Handle the results of the Original generation.

        Args:
            job_params: The job parameters
            result: The generation result

        Returns:
            Processed result
        """
        # For Original generation, we just return the result as-is
        return result
modules/pipelines/original_with_endframe_pipeline.py
ADDED
@@ -0,0 +1,157 @@
"""
Original with Endframe pipeline class for FramePack Studio.
This pipeline handles the "Original with Endframe" model type.
"""

import os
import time
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from .base_pipeline import BasePipeline

class OriginalWithEndframePipeline(BasePipeline):
    """Pipeline for Original with Endframe generation type."""

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the Original with Endframe generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        processed_params = job_params.copy()

        # Ensure we have the correct model type
        processed_params['model_type'] = "Original with Endframe"

        return processed_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the Original with Endframe generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check for required parameters
        required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
        for param in required_params:
            if param not in job_params:
                return False, f"Missing required parameter: {param}"

        # Validate numeric parameters
        if job_params.get('total_second_length', 0) <= 0:
            return False, "Video length must be greater than 0"

        if job_params.get('steps', 0) <= 0:
            return False, "Steps must be greater than 0"

        # Validate end frame parameters
        if job_params.get('end_frame_strength', 0) < 0 or job_params.get('end_frame_strength', 0) > 1:
            return False, "End frame strength must be between 0 and 1"

        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input images for the Original with Endframe generation type.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        processed_inputs = {}

        # Process input image if provided
        input_image = job_params.get('input_image')
        if input_image is not None:
            # Get resolution parameters
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)

            # Find nearest bucket size
            if job_params.get('has_input_image', True):
                # If we have an input image, use its dimensions to find the nearest bucket
                H, W, _ = input_image.shape
                height, width = find_nearest_bucket(H, W, resolution=resolutionW)
            else:
                # Otherwise, use the provided resolution parameters
                height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            # Resize and center crop the input image
            input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width
        else:
            # If no input image, create a blank image based on latent_type
            resolutionW = job_params.get('resolutionW', 640)
            resolutionH = job_params.get('resolutionH', 640)
            height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

            latent_type = job_params.get('latent_type', 'Black')
            if latent_type == "White":
                # Create a white image
                input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
            elif latent_type == "Noise":
                # Create a noise image
                input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            elif latent_type == "Green Screen":
                # Create a green screen image with standard chroma key green (0, 177, 64)
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
                input_image_np[:, :, 1] = 177  # Green channel
                input_image_np[:, :, 2] = 64   # Blue channel
                # Red channel remains 0
            else:  # Default to "Black" or any other value
                # Create a black image
                input_image_np = np.zeros((height, width, 3), dtype=np.uint8)

            # Store the processed image and dimensions
            processed_inputs['input_image'] = input_image_np
            processed_inputs['height'] = height
            processed_inputs['width'] = width

        # Process end frame image if provided
        end_frame_image = job_params.get('end_frame_image')
        if end_frame_image is not None:
            # Use the same dimensions as the input image
            height = processed_inputs['height']
            width = processed_inputs['width']

            # Resize and center crop the end frame image
            end_frame_np = resize_and_center_crop(end_frame_image, target_width=width, target_height=height)

            # Store the processed end frame image
            processed_inputs['end_frame_image'] = end_frame_np

        return processed_inputs

    def handle_results(self, job_params, result):
        """
        Handle the results of the Original with Endframe generation.

        Args:
            job_params: The job parameters
            result: The generation result

        Returns:
            Processed result
        """
        # For Original with Endframe generation, we just return the result as-is
        return result

# Using the centralized create_metadata method from BasePipeline
modules/pipelines/video_f1_pipeline.py
ADDED
@@ -0,0 +1,143 @@
"""
Video F1 pipeline class for FramePack Studio.
This pipeline handles the "Video F1" model type.
"""

import os
import time
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from .base_pipeline import BasePipeline

class VideoF1Pipeline(BasePipeline):
    """Pipeline for Video F1 generation type."""

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the Video generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        processed_params = job_params.copy()

        # Ensure we have the correct model type
        processed_params['model_type'] = "Video F1"

        return processed_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the Video generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check for required parameters
        required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
        for param in required_params:
            if param not in job_params:
                return False, f"Missing required parameter: {param}"

        # Validate numeric parameters
        if job_params.get('total_second_length', 0) <= 0:
            return False, "Video length must be greater than 0"

        if job_params.get('steps', 0) <= 0:
            return False, "Steps must be greater than 0"

        # Check for input video (stored in input_image for Video F1 model)
        if not job_params.get('input_image'):
            return False, "Input video is required for Video F1 model"

        # Check if combine_with_source is provided (optional)
        combine_with_source = job_params.get('combine_with_source')
        if combine_with_source is not None and not isinstance(combine_with_source, bool):
            return False, "combine_with_source must be a boolean value"

        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input video for the Video F1 generation type.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        processed_inputs = {}

        # Get the input video (stored in input_image for Video F1 model)
        input_video = job_params.get('input_image')
        if not input_video:
            raise ValueError("Input video is required for Video F1 model")

        # Store the input video
        processed_inputs['input_video'] = input_video

        # Note: The following code will be executed in the worker function:
        # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels
        # 2. Then it will store these values for later use:
        #        input_video_pixels = input_video_pixels.cpu()
        #        video_latents = video_latents.cpu()
        #
        # 3. If the generator has the set_full_video_latents method, it will store the video latents:
        #        if hasattr(current_generator, 'set_full_video_latents'):
        #            current_generator.set_full_video_latents(video_latents.clone())
        #            print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
        #
        # 4. For the Video model, history_latents is initialized with the video_latents:
        #        history_latents = video_latents
        #        print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
        processed_inputs['input_files_dir'] = job_params.get('input_files_dir')

        # Pass through the combine_with_source parameter if it exists
        if 'combine_with_source' in job_params:
            processed_inputs['combine_with_source'] = job_params.get('combine_with_source')
            print(f"Video F1 pipeline: combine_with_source = {processed_inputs['combine_with_source']}")

        # Pass through the num_cleaned_frames parameter if it exists
        if 'num_cleaned_frames' in job_params:
            processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames')
            print(f"Video F1 pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}")

        # Get resolution parameters
        resolutionW = job_params.get('resolutionW', 640)
        resolutionH = job_params.get('resolutionH', 640)

        # Find nearest bucket size
        height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)

        # Store the dimensions
        processed_inputs['height'] = height
        processed_inputs['width'] = width

        return processed_inputs

    def handle_results(self, job_params, result):
        """
        Handle the results of the Video F1 generation.

        Args:
            job_params: The job parameters
            result: The generation result

        Returns:
            Processed result
        """
        # For Video F1 generation, we just return the result as-is
        return result

# Using the centralized create_metadata method from BasePipeline
modules/pipelines/video_pipeline.py
ADDED
@@ -0,0 +1,143 @@
"""
Video pipeline class for FramePack Studio.
This pipeline handles the "Video" model type.
"""

import os
import time
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from .base_pipeline import BasePipeline

class VideoPipeline(BasePipeline):
    """Pipeline for Video generation type."""

    def prepare_parameters(self, job_params):
        """
        Prepare parameters for the Video generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed parameters dictionary
        """
        processed_params = job_params.copy()

        # Ensure we have the correct model type
        processed_params['model_type'] = "Video"

        return processed_params

    def validate_parameters(self, job_params):
        """
        Validate parameters for the Video generation job.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check for required parameters
        required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
        for param in required_params:
            if param not in job_params:
                return False, f"Missing required parameter: {param}"

        # Validate numeric parameters
        if job_params.get('total_second_length', 0) <= 0:
            return False, "Video length must be greater than 0"

        if job_params.get('steps', 0) <= 0:
            return False, "Steps must be greater than 0"

        # Check for input video (stored in input_image for Video model)
        if not job_params.get('input_image'):
            return False, "Input video is required for Video model"

        # Check if combine_with_source is provided (optional)
        combine_with_source = job_params.get('combine_with_source')
        if combine_with_source is not None and not isinstance(combine_with_source, bool):
            return False, "combine_with_source must be a boolean value"

        return True, None

    def preprocess_inputs(self, job_params):
        """
        Preprocess input video for the Video generation type.

        Args:
            job_params: Dictionary of job parameters

        Returns:
            Processed inputs dictionary
        """
        processed_inputs = {}

        # Get the input video (stored in input_image for Video model)
        input_video = job_params.get('input_image')
        if not input_video:
            raise ValueError("Input video is required for Video model")

        # Store the input video
        processed_inputs['input_video'] = input_video

        # Note: The following code will be executed in the worker function:
        # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels
        # 2. Then it will store these values for later use:
        #        input_video_pixels = input_video_pixels.cpu()
        #        video_latents = video_latents.cpu()
        #
        # 3. If the generator has the set_full_video_latents method, it will store the video latents:
        #        if hasattr(current_generator, 'set_full_video_latents'):
        #            current_generator.set_full_video_latents(video_latents.clone())
        #            print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
        #
        # 4. For the Video model, history_latents is initialized with the video_latents:
        #        history_latents = video_latents
        #        print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
        processed_inputs['input_files_dir'] = job_params.get('input_files_dir')

        # Pass through the combine_with_source parameter if it exists
|
107 |
+
if 'combine_with_source' in job_params:
|
108 |
+
processed_inputs['combine_with_source'] = job_params.get('combine_with_source')
|
109 |
+
print(f"Video pipeline: combine_with_source = {processed_inputs['combine_with_source']}")
|
110 |
+
|
111 |
+
# Pass through the num_cleaned_frames parameter if it exists
|
112 |
+
if 'num_cleaned_frames' in job_params:
|
113 |
+
processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames')
|
114 |
+
print(f"Video pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}")
|
115 |
+
|
116 |
+
# Get resolution parameters
|
117 |
+
resolutionW = job_params.get('resolutionW', 640)
|
118 |
+
resolutionH = job_params.get('resolutionH', 640)
|
119 |
+
|
120 |
+
# Find nearest bucket size
|
121 |
+
height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
|
122 |
+
|
123 |
+
# Store the dimensions
|
124 |
+
processed_inputs['height'] = height
|
125 |
+
processed_inputs['width'] = width
|
126 |
+
|
127 |
+
return processed_inputs
|
128 |
+
|
129 |
+
def handle_results(self, job_params, result):
|
130 |
+
"""
|
131 |
+
Handle the results of the Video generation.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
job_params: The job parameters
|
135 |
+
result: The generation result
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
Processed result
|
139 |
+
"""
|
140 |
+
# For Video generation, we just return the result as-is
|
141 |
+
return result
|
142 |
+
|
143 |
+
# Using the centralized create_metadata method from BasePipeline
|
modules/pipelines/video_tools.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
from diffusers_helper.utils import save_bcthw_as_mp4
|
6 |
+
|
7 |
+
@torch.no_grad()
|
8 |
+
def combine_videos_sequentially_from_tensors(processed_input_frames_np,
|
9 |
+
generated_frames_pt,
|
10 |
+
output_path,
|
11 |
+
target_fps,
|
12 |
+
crf_value):
|
13 |
+
"""
|
14 |
+
Combines processed input frames (NumPy) with generated frames (PyTorch Tensor) sequentially
|
15 |
+
and saves the result as an MP4 video using save_bcthw_as_mp4.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
processed_input_frames_np: NumPy array of processed input frames (T_in, H, W_in, C), uint8.
|
19 |
+
generated_frames_pt: PyTorch tensor of generated frames (B_gen, C_gen, T_gen, H, W_gen), float32 [-1,1].
|
20 |
+
(This will be history_pixels from worker.py)
|
21 |
+
output_path: Path to save the combined video.
|
22 |
+
target_fps: FPS for the output combined video.
|
23 |
+
crf_value: CRF value for video encoding.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Path to the combined video, or None if an error occurs.
|
27 |
+
"""
|
28 |
+
try:
|
29 |
+
# 1. Convert processed_input_frames_np to PyTorch tensor BCTHW, float32, [-1,1]
|
30 |
+
# processed_input_frames_np shape: (T_in, H, W_in, C)
|
31 |
+
input_frames_pt = torch.from_numpy(processed_input_frames_np).float() / 127.5 - 1.0 # (T,H,W,C)
|
32 |
+
input_frames_pt = input_frames_pt.permute(3, 0, 1, 2) # (C,T,H,W)
|
33 |
+
input_frames_pt = input_frames_pt.unsqueeze(0) # (1,C,T,H,W) -> BCTHW
|
34 |
+
|
35 |
+
# Ensure generated_frames_pt is on the same device and dtype for concatenation
|
36 |
+
input_frames_pt = input_frames_pt.to(device=generated_frames_pt.device, dtype=generated_frames_pt.dtype)
|
37 |
+
|
38 |
+
# 2. Dimension Check (Heights and Widths should match)
|
39 |
+
# They should match, since the input frames should have been processed to match the generation resolution.
|
40 |
+
# But sanity check to ensure no mismatch occurs when the code is refactored.
|
41 |
+
if input_frames_pt.shape[3:] != generated_frames_pt.shape[3:]: # Compare (H,W)
|
42 |
+
print(f"Warning: Dimension mismatch for sequential combination! Input: {input_frames_pt.shape[3:]}, Generated: {generated_frames_pt.shape[3:]}.")
|
43 |
+
print("Attempting to proceed, but this might lead to errors or unexpected video output.")
|
44 |
+
# Potentially add resizing logic here if necessary, but for now, assume they match
|
45 |
+
|
46 |
+
# 3. Concatenate Tensors along the time dimension (dim=2 for BCTHW)
|
47 |
+
combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2)
|
48 |
+
|
49 |
+
# 4. Save
|
50 |
+
save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value)
|
51 |
+
print(f"Sequentially combined video (from tensors) saved to {output_path}")
|
52 |
+
return output_path
|
53 |
+
except Exception as e:
|
54 |
+
print(f"Error in combine_videos_sequentially_from_tensors: {str(e)}")
|
55 |
+
import traceback
|
56 |
+
traceback.print_exc()
|
57 |
+
return None
|
modules/pipelines/worker.py
ADDED
@@ -0,0 +1,1150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
import einops
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import datetime
|
9 |
+
from PIL import Image
|
10 |
+
from PIL.PngImagePlugin import PngInfo
|
11 |
+
from diffusers_helper.models.mag_cache import MagCache
|
12 |
+
from diffusers_helper.utils import save_bcthw_as_mp4, generate_timestamp, resize_and_center_crop
|
13 |
+
from diffusers_helper.memory import cpu, gpu, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, unload_complete_models, load_model_as_complete
|
14 |
+
from diffusers_helper.thread_utils import AsyncStream
|
15 |
+
from diffusers_helper.gradio.progress_bar import make_progress_bar_html
|
16 |
+
from diffusers_helper.hunyuan import vae_decode
|
17 |
+
from modules.video_queue import JobStatus
|
18 |
+
from modules.prompt_handler import parse_timestamped_prompt
|
19 |
+
from modules.generators import create_model_generator
|
20 |
+
from modules.pipelines.video_tools import combine_videos_sequentially_from_tensors
|
21 |
+
from modules import DUMMY_LORA_NAME # Import the constant
|
22 |
+
from modules.llm_captioner import unload_captioning_model
|
23 |
+
from modules.llm_enhancer import unload_enhancing_model
|
24 |
+
from . import create_pipeline
|
25 |
+
|
26 |
+
import __main__ as studio_module # Get a reference to the __main__ module object
|
27 |
+
|
28 |
+
@torch.no_grad()
|
29 |
+
def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device, prompt_embedding_cache):
|
30 |
+
"""
|
31 |
+
Retrieves prompt embeddings from cache or encodes them if not found.
|
32 |
+
Stores encoded embeddings (on CPU) in the cache.
|
33 |
+
Returns embeddings moved to the target_device.
|
34 |
+
"""
|
35 |
+
from diffusers_helper.hunyuan import encode_prompt_conds, crop_or_pad_yield_mask
|
36 |
+
|
37 |
+
if prompt in prompt_embedding_cache:
|
38 |
+
print(f"Cache hit for prompt: {prompt[:60]}...")
|
39 |
+
llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt]
|
40 |
+
# Move cached embeddings (from CPU) to the target device
|
41 |
+
llama_vec = llama_vec_cpu.to(target_device)
|
42 |
+
llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None
|
43 |
+
clip_l_pooler = clip_l_pooler_cpu.to(target_device)
|
44 |
+
return llama_vec, llama_attention_mask, clip_l_pooler
|
45 |
+
else:
|
46 |
+
print(f"Cache miss for prompt: {prompt[:60]}...")
|
47 |
+
llama_vec, clip_l_pooler = encode_prompt_conds(
|
48 |
+
prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
|
49 |
+
)
|
50 |
+
llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
|
51 |
+
# Store CPU copies in cache
|
52 |
+
prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu())
|
53 |
+
# Return embeddings already on the target device (as encode_prompt_conds uses the model's device)
|
54 |
+
return llama_vec, llama_attention_mask, clip_l_pooler
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def worker(
|
58 |
+
model_type,
|
59 |
+
input_image,
|
60 |
+
end_frame_image, # The end frame image (numpy array or None)
|
61 |
+
end_frame_strength, # Influence of the end frame
|
62 |
+
prompt_text,
|
63 |
+
n_prompt,
|
64 |
+
seed,
|
65 |
+
total_second_length,
|
66 |
+
latent_window_size,
|
67 |
+
steps,
|
68 |
+
cfg,
|
69 |
+
gs,
|
70 |
+
rs,
|
71 |
+
use_teacache,
|
72 |
+
teacache_num_steps,
|
73 |
+
teacache_rel_l1_thresh,
|
74 |
+
use_magcache,
|
75 |
+
magcache_threshold,
|
76 |
+
magcache_max_consecutive_skips,
|
77 |
+
magcache_retention_ratio,
|
78 |
+
blend_sections,
|
79 |
+
latent_type,
|
80 |
+
selected_loras,
|
81 |
+
has_input_image,
|
82 |
+
lora_values=None,
|
83 |
+
job_stream=None,
|
84 |
+
output_dir=None,
|
85 |
+
metadata_dir=None,
|
86 |
+
input_files_dir=None, # Add input_files_dir parameter
|
87 |
+
input_image_path=None, # Add input_image_path parameter
|
88 |
+
end_frame_image_path=None, # Add end_frame_image_path parameter
|
89 |
+
resolutionW=640, # Add resolution parameter with default value
|
90 |
+
resolutionH=640,
|
91 |
+
lora_loaded_names=[],
|
92 |
+
input_video=None, # Add input_video parameter with default value of None
|
93 |
+
combine_with_source=None, # Add combine_with_source parameter
|
94 |
+
num_cleaned_frames=5, # Add num_cleaned_frames parameter with default value
|
95 |
+
save_metadata_checked=True # Add save_metadata_checked parameter
|
96 |
+
):
|
97 |
+
"""
|
98 |
+
Worker function for video generation.
|
99 |
+
"""
|
100 |
+
|
101 |
+
random_generator = torch.Generator("cpu").manual_seed(seed)
|
102 |
+
|
103 |
+
unload_enhancing_model()
|
104 |
+
unload_captioning_model()
|
105 |
+
|
106 |
+
# Filter out the dummy LoRA from selected_loras at the very beginning of the worker
|
107 |
+
actual_selected_loras_for_worker = []
|
108 |
+
if isinstance(selected_loras, list):
|
109 |
+
actual_selected_loras_for_worker = [lora for lora in selected_loras if lora != DUMMY_LORA_NAME]
|
110 |
+
if DUMMY_LORA_NAME in selected_loras and DUMMY_LORA_NAME in actual_selected_loras_for_worker: # Should not happen if filter works
|
111 |
+
print(f"Worker.py: Error - '{DUMMY_LORA_NAME}' was selected but not filtered out.")
|
112 |
+
elif DUMMY_LORA_NAME in selected_loras:
|
113 |
+
print(f"Worker.py: Filtered out '{DUMMY_LORA_NAME}' from selected LoRAs.")
|
114 |
+
elif selected_loras is not None: # If it's a single string (should not happen with multiselect dropdown)
|
115 |
+
if selected_loras != DUMMY_LORA_NAME:
|
116 |
+
actual_selected_loras_for_worker = [selected_loras]
|
117 |
+
selected_loras = actual_selected_loras_for_worker
|
118 |
+
print(f"Worker: Selected LoRAs for this worker: {selected_loras}")
|
119 |
+
|
120 |
+
# Import globals from the main module
|
121 |
+
from __main__ import high_vram, args, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, image_encoder, feature_extractor, prompt_embedding_cache, settings, stream
|
122 |
+
|
123 |
+
# Ensure any existing LoRAs are unloaded from the current generator
|
124 |
+
if studio_module.current_generator is not None:
|
125 |
+
print("Worker: Unloading LoRAs from studio_module.current_generator")
|
126 |
+
studio_module.current_generator.unload_loras()
|
127 |
+
import gc
|
128 |
+
gc.collect()
|
129 |
+
if torch.cuda.is_available():
|
130 |
+
torch.cuda.empty_cache()
|
131 |
+
|
132 |
+
stream_to_use = job_stream if job_stream is not None else stream
|
133 |
+
|
134 |
+
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
|
135 |
+
total_latent_sections = int(max(round(total_latent_sections), 1))
|
136 |
+
|
137 |
+
# --- Total progress tracking ---
|
138 |
+
total_steps = total_latent_sections * steps # Total diffusion steps over all segments
|
139 |
+
step_durations = [] # Rolling history of recent step durations for ETA
|
140 |
+
last_step_time = time.time()
|
141 |
+
|
142 |
+
# Parse the timestamped prompt with boundary snapping and reversing
|
143 |
+
# prompt_text should now be the original string from the job queue
|
144 |
+
prompt_sections = parse_timestamped_prompt(prompt_text, total_second_length, latent_window_size, model_type)
|
145 |
+
job_id = generate_timestamp()
|
146 |
+
|
147 |
+
# Initialize progress data with a clear starting message and dummy preview
|
148 |
+
dummy_preview = np.zeros((64, 64, 3), dtype=np.uint8)
|
149 |
+
initial_progress_data = {
|
150 |
+
'preview': dummy_preview,
|
151 |
+
'desc': 'Starting job...',
|
152 |
+
'html': make_progress_bar_html(0, 'Starting job...')
|
153 |
+
}
|
154 |
+
|
155 |
+
# Store initial progress data in the job object if using a job stream
|
156 |
+
if job_stream is not None:
|
157 |
+
try:
|
158 |
+
from __main__ import job_queue
|
159 |
+
job = job_queue.get_job(job_id)
|
160 |
+
if job:
|
161 |
+
job.progress_data = initial_progress_data
|
162 |
+
except Exception as e:
|
163 |
+
print(f"Error storing initial progress data: {e}")
|
164 |
+
|
165 |
+
# Push initial progress update to both streams
|
166 |
+
stream_to_use.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...'))))
|
167 |
+
|
168 |
+
# Push job ID to stream to ensure monitoring connection
|
169 |
+
stream_to_use.output_queue.push(('job_id', job_id))
|
170 |
+
stream_to_use.output_queue.push(('monitor_job', job_id))
|
171 |
+
|
172 |
+
# Always push to the main stream to ensure the UI is updated
|
173 |
+
from __main__ import stream as main_stream
|
174 |
+
if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use
|
175 |
+
print(f"Pushing initial progress update to main stream for job {job_id}")
|
176 |
+
main_stream.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...'))))
|
177 |
+
|
178 |
+
# Push job ID to main stream to ensure monitoring connection
|
179 |
+
main_stream.output_queue.push(('job_id', job_id))
|
180 |
+
main_stream.output_queue.push(('monitor_job', job_id))
|
181 |
+
|
182 |
+
try:
|
183 |
+
# Create a settings dictionary for the pipeline
|
184 |
+
pipeline_settings = {
|
185 |
+
"output_dir": output_dir,
|
186 |
+
"metadata_dir": metadata_dir,
|
187 |
+
"input_files_dir": input_files_dir,
|
188 |
+
"save_metadata": settings.get("save_metadata", True),
|
189 |
+
"gpu_memory_preservation": settings.get("gpu_memory_preservation", 6),
|
190 |
+
"mp4_crf": settings.get("mp4_crf", 16),
|
191 |
+
"clean_up_videos": settings.get("clean_up_videos", True),
|
192 |
+
"gradio_temp_dir": settings.get("gradio_temp_dir", "./gradio_temp"),
|
193 |
+
"high_vram": high_vram
|
194 |
+
}
|
195 |
+
|
196 |
+
# Create the appropriate pipeline for the model type
|
197 |
+
pipeline = create_pipeline(model_type, pipeline_settings)
|
198 |
+
|
199 |
+
# Create job parameters dictionary
|
200 |
+
job_params = {
|
201 |
+
'model_type': model_type,
|
202 |
+
'input_image': input_image,
|
203 |
+
'end_frame_image': end_frame_image,
|
204 |
+
'end_frame_strength': end_frame_strength,
|
205 |
+
'prompt_text': prompt_text,
|
206 |
+
'n_prompt': n_prompt,
|
207 |
+
'seed': seed,
|
208 |
+
'total_second_length': total_second_length,
|
209 |
+
'latent_window_size': latent_window_size,
|
210 |
+
'steps': steps,
|
211 |
+
'cfg': cfg,
|
212 |
+
'gs': gs,
|
213 |
+
'rs': rs,
|
214 |
+
'blend_sections': blend_sections,
|
215 |
+
'latent_type': latent_type,
|
216 |
+
'use_teacache': use_teacache,
|
217 |
+
'teacache_num_steps': teacache_num_steps,
|
218 |
+
'teacache_rel_l1_thresh': teacache_rel_l1_thresh,
|
219 |
+
'use_magcache': use_magcache,
|
220 |
+
'magcache_threshold': magcache_threshold,
|
221 |
+
'magcache_max_consecutive_skips': magcache_max_consecutive_skips,
|
222 |
+
'magcache_retention_ratio': magcache_retention_ratio,
|
223 |
+
'selected_loras': selected_loras,
|
224 |
+
'has_input_image': has_input_image,
|
225 |
+
'lora_values': lora_values,
|
226 |
+
'resolutionW': resolutionW,
|
227 |
+
'resolutionH': resolutionH,
|
228 |
+
'lora_loaded_names': lora_loaded_names,
|
229 |
+
'input_image_path': input_image_path,
|
230 |
+
'end_frame_image_path': end_frame_image_path,
|
231 |
+
'combine_with_source': combine_with_source,
|
232 |
+
'num_cleaned_frames': num_cleaned_frames,
|
233 |
+
'save_metadata_checked': save_metadata_checked # Ensure it's in job_params for internal use
|
234 |
+
}
|
235 |
+
|
236 |
+
# Validate parameters
|
237 |
+
is_valid, error_message = pipeline.validate_parameters(job_params)
|
238 |
+
if not is_valid:
|
239 |
+
raise ValueError(f"Invalid parameters: {error_message}")
|
240 |
+
|
241 |
+
# Prepare parameters
|
242 |
+
job_params = pipeline.prepare_parameters(job_params)
|
243 |
+
|
244 |
+
if not high_vram:
|
245 |
+
# Unload everything *except* the potentially active transformer
|
246 |
+
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae)
|
247 |
+
if studio_module.current_generator is not None and studio_module.current_generator.transformer is not None:
|
248 |
+
offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
|
249 |
+
|
250 |
+
|
251 |
+
# --- Model Loading / Switching ---
|
252 |
+
print(f"Worker starting for model type: {model_type}")
|
253 |
+
print(f"Worker: Before model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}")
|
254 |
+
|
255 |
+
# Create the appropriate model generator
|
256 |
+
new_generator = create_model_generator(
|
257 |
+
model_type,
|
258 |
+
text_encoder=text_encoder,
|
259 |
+
text_encoder_2=text_encoder_2,
|
260 |
+
tokenizer=tokenizer,
|
261 |
+
tokenizer_2=tokenizer_2,
|
262 |
+
vae=vae,
|
263 |
+
image_encoder=image_encoder,
|
264 |
+
feature_extractor=feature_extractor,
|
265 |
+
high_vram=high_vram,
|
266 |
+
prompt_embedding_cache=prompt_embedding_cache,
|
267 |
+
offline=args.offline,
|
268 |
+
settings=settings
|
269 |
+
)
|
270 |
+
|
271 |
+
# Update the global generator
|
272 |
+
# This modifies the 'current_generator' attribute OF THE '__main__' MODULE OBJECT
|
273 |
+
studio_module.current_generator = new_generator
|
274 |
+
print(f"Worker: AFTER model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}")
|
275 |
+
if studio_module.current_generator:
|
276 |
+
print(f"Worker: studio_module.current_generator.transformer is {type(studio_module.current_generator.transformer)}")
|
277 |
+
|
278 |
+
# Load the transformer model
|
279 |
+
studio_module.current_generator.load_model()
|
280 |
+
|
281 |
+
# Ensure the model has no LoRAs loaded
|
282 |
+
print(f"Ensuring {model_type} model has no LoRAs loaded")
|
283 |
+
studio_module.current_generator.unload_loras()
|
284 |
+
|
285 |
+
# Preprocess inputs
|
286 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Preprocessing inputs...'))))
|
287 |
+
processed_inputs = pipeline.preprocess_inputs(job_params)
|
288 |
+
|
289 |
+
# Update job_params with processed inputs
|
290 |
+
job_params.update(processed_inputs)
|
291 |
+
|
292 |
+
# Save the starting image directly to the output directory with full metadata
|
293 |
+
# Check both global settings and job-specific save_metadata_checked parameter
|
294 |
+
if settings.get("save_metadata") and job_params.get('save_metadata_checked', True) and job_params.get('input_image') is not None:
|
295 |
+
try:
|
296 |
+
# Import the save_job_start_image function from metadata_utils
|
297 |
+
from modules.pipelines.metadata_utils import save_job_start_image, create_metadata
|
298 |
+
|
299 |
+
# Create comprehensive metadata for the job
|
300 |
+
metadata_dict = create_metadata(job_params, job_id, settings)
|
301 |
+
|
302 |
+
# Save the starting image with metadata
|
303 |
+
save_job_start_image(job_params, job_id, settings)
|
304 |
+
|
305 |
+
print(f"Saved metadata and starting image for job {job_id}")
|
306 |
+
except Exception as e:
|
307 |
+
print(f"Error saving starting image and metadata: {e}")
|
308 |
+
traceback.print_exc()
|
309 |
+
|
310 |
+
# Pre-encode all prompts
|
311 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding all prompts...'))))
|
312 |
+
|
313 |
+
# THE FOLLOWING CODE SHOULD BE INSIDE THE TRY BLOCK
|
314 |
+
if not high_vram:
|
315 |
+
fake_diffusers_current_device(text_encoder, gpu)
|
316 |
+
load_model_as_complete(text_encoder_2, target_device=gpu)
|
317 |
+
|
318 |
+
# PROMPT BLENDING: Pre-encode all prompts and store in a list in order
|
319 |
+
unique_prompts = []
|
320 |
+
for section in prompt_sections:
|
321 |
+
if section.prompt not in unique_prompts:
|
322 |
+
unique_prompts.append(section.prompt)
|
323 |
+
|
324 |
+
encoded_prompts = {}
|
325 |
+
for prompt in unique_prompts:
|
326 |
+
# Use the helper function for caching and encoding
|
327 |
+
llama_vec, llama_attention_mask, clip_l_pooler = get_cached_or_encode_prompt(
|
328 |
+
prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache
|
329 |
+
)
|
330 |
+
encoded_prompts[prompt] = (llama_vec, llama_attention_mask, clip_l_pooler)
|
331 |
+
|
332 |
+
# PROMPT BLENDING: Build a list of (start_section_idx, prompt) for each prompt
|
333 |
+
prompt_change_indices = []
|
334 |
+
last_prompt = None
|
335 |
+
for idx, section in enumerate(prompt_sections):
|
336 |
+
if section.prompt != last_prompt:
|
337 |
+
prompt_change_indices.append((idx, section.prompt))
|
338 |
+
last_prompt = section.prompt
|
339 |
+
|
340 |
+
# Encode negative prompt
|
341 |
+
if cfg == 1:
|
342 |
+
llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = (
|
343 |
+
torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][0]),
|
344 |
+
torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][1]),
|
345 |
+
torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][2])
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
# Use the helper function for caching and encoding negative prompt
|
349 |
+
# Ensure n_prompt is a string
|
350 |
+
n_prompt_str = str(n_prompt) if n_prompt is not None else ""
|
351 |
+
llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = get_cached_or_encode_prompt(
|
352 |
+
n_prompt_str, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache
|
353 |
+
)
|
354 |
+
|
355 |
+
end_of_input_video_embedding = None # Video model end frame CLIP Vision embedding
|
356 |
+
# Process input image or video based on model type
|
357 |
+
if model_type == "Video" or model_type == "Video F1":
|
358 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...'))))
|
359 |
+
|
360 |
+
# Encode the video using the VideoModelGenerator
|
361 |
+
start_latent, input_image_np, video_latents, fps, height, width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np = studio_module.current_generator.video_encode(
|
362 |
+
video_path=job_params['input_image'], # For Video model, input_image contains the video path
|
363 |
+
resolution=job_params['resolutionW'],
|
364 |
+
no_resize=False,
|
365 |
+
vae_batch_size=16,
|
366 |
+
device=gpu,
|
367 |
+
input_files_dir=job_params['input_files_dir']
|
368 |
+
)
|
369 |
+
|
370 |
+
if end_of_input_video_image_np is not None:
|
371 |
+
try:
|
372 |
+
from modules.pipelines.metadata_utils import save_last_video_frame
|
373 |
+
save_last_video_frame(job_params, job_id, settings, end_of_input_video_image_np)
|
374 |
+
except Exception as e:
|
375 |
+
print(f"Error saving last video frame: {e}")
|
376 |
+
traceback.print_exc()
|
377 |
+
|
378 |
+
# RT_BORG: retained only until we make our final decisions on how to handle combining videos
|
379 |
+
# Only necessary to retain resized frames to produce a combined video with source frames of the right dimensions
|
380 |
+
#if combine_with_source:
|
381 |
+
# # Store input_frames_resized_np in job_params for later use
|
382 |
+
# job_params['input_frames_resized_np'] = input_frames_resized_np
|
383 |
+
|
384 |
+
# CLIP Vision encoding for the first frame
|
385 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
|
386 |
+
|
387 |
+
if not high_vram:
|
388 |
+
load_model_as_complete(image_encoder, target_device=gpu)
|
389 |
+
|
390 |
+
from diffusers_helper.clip_vision import hf_clip_vision_encode
|
391 |
+
image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
|
392 |
+
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
|
393 |
+
|
394 |
+
end_of_input_video_embedding = hf_clip_vision_encode(end_of_input_video_image_np, feature_extractor, image_encoder).last_hidden_state
|
395 |
+
|
396 |
+
# Store the input video pixels and latents for later use
|
397 |
+
input_video_pixels = input_video_pixels.cpu()
|
398 |
+
video_latents = video_latents.cpu()
|
399 |
+
|
400 |
+
# Store the full video latents in the generator instance for preparing clean latents
|
401 |
+
if hasattr(studio_module.current_generator, 'set_full_video_latents'):
|
402 |
+
studio_module.current_generator.set_full_video_latents(video_latents.clone())
|
403 |
+
print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
|
404 |
+
|
405 |
+
# For Video model, history_latents is initialized with the video_latents
|
406 |
+
history_latents = video_latents
|
407 |
+
|
408 |
+
# Store the last frame of the video latents as start_latent for the model
|
409 |
+
start_latent = video_latents[:, :, -1:].cpu()
|
410 |
+
print(f"Using last frame of input video as start_latent. Shape: {start_latent.shape}")
|
411 |
+
print(f"Placed last frame of video at position 0 in history_latents")
|
412 |
+
|
413 |
+
print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
|
414 |
+
|
415 |
+
# Store the number of frames in the input video for later use
|
416 |
+
input_video_frame_count = video_latents.shape[2]
|
417 |
+
else:
|
418 |
+
# Regular image processing
|
419 |
+
height = job_params['height']
|
420 |
+
width = job_params['width']
|
421 |
+
|
422 |
+
if not has_input_image and job_params.get('latent_type') == 'Noise':
|
423 |
+
# print("************************************************")
|
424 |
+
# print("** Using 'Noise' latent type for T2V workflow **")
|
425 |
+
# print("************************************************")
|
426 |
+
|
427 |
+
# Create a random latent to serve as the initial VAE context anchor.
|
428 |
+
# This provides a random starting point without visual bias.
|
429 |
+
start_latent = torch.randn(
|
430 |
+
(1, 16, 1, height // 8, width // 8),
|
431 |
+
generator=random_generator, device=random_generator.device
|
432 |
+
).to(device=gpu, dtype=torch.float32)
|
433 |
+
|
434 |
+
# Create a neutral black image to generate a valid "null" CLIP Vision embedding.
|
435 |
+
# This provides the model with a valid, in-distribution unconditional image prompt.
|
436 |
+
# RT_BORG: Clip doesn't understand noise at all. I also tried using
|
437 |
+
# image_encoder_last_hidden_state = torch.zeros((1, 257, 1152), device=gpu, dtype=studio_module.current_generator.transformer.dtype)
|
438 |
+
# to represent a "null" CLIP Vision embedding in the shape for the CLIP encoder,
|
439 |
+
# but the Video model wasn't trained to handle zeros, so using a neutral black image for CLIP.
|
440 |
+
|
441 |
+
black_image_np = np.zeros((height, width, 3), dtype=np.uint8)
|
442 |
+
|
443 |
+
if not high_vram:
|
444 |
+
load_model_as_complete(image_encoder, target_device=gpu)
|
445 |
+
|
446 |
+
from diffusers_helper.clip_vision import hf_clip_vision_encode
|
447 |
+
image_encoder_output = hf_clip_vision_encode(black_image_np, feature_extractor, image_encoder)
|
448 |
+
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
|
449 |
+
|
450 |
+
else:
|
451 |
+
input_image_np = job_params['input_image']
|
452 |
+
|
453 |
+
input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
|
454 |
+
input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
|
455 |
+
|
456 |
+
# Start image encoding with VAE
|
457 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
|
458 |
+
|
459 |
+
if not high_vram:
|
460 |
+
load_model_as_complete(vae, target_device=gpu)
|
461 |
+
|
462 |
+
from diffusers_helper.hunyuan import vae_encode
|
463 |
+
start_latent = vae_encode(input_image_pt, vae)
|
464 |
+
|
465 |
+
# CLIP Vision
|
466 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
|
467 |
+
|
468 |
+
if not high_vram:
|
469 |
+
load_model_as_complete(image_encoder, target_device=gpu)
|
470 |
+
|
471 |
+
from diffusers_helper.clip_vision import hf_clip_vision_encode
|
472 |
+
image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
|
473 |
+
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
|
474 |
+
|
475 |
+
# VAE encode end_frame_image if provided
|
476 |
+
end_frame_latent = None
|
477 |
+
# VAE encode end_frame_image resized to output dimensions, if provided
|
478 |
+
end_frame_output_dimensions_latent = None
|
479 |
+
end_clip_embedding = None # Video model end frame CLIP Vision embedding
|
480 |
+
|
481 |
+
# Models with end_frame_image processing
|
482 |
+
if (model_type == "Original with Endframe" or model_type == "Video") and job_params.get('end_frame_image') is not None:
|
483 |
+
print(f"Processing end frame for {model_type} model...")
|
484 |
+
end_frame_image = job_params['end_frame_image']
|
485 |
+
|
486 |
+
if not isinstance(end_frame_image, np.ndarray):
|
487 |
+
print(f"Warning: end_frame_image is not a numpy array (type: {type(end_frame_image)}). Attempting conversion or skipping.")
|
488 |
+
try:
|
489 |
+
end_frame_image = np.array(end_frame_image)
|
490 |
+
except Exception as e_conv:
|
491 |
+
print(f"Could not convert end_frame_image to numpy array: {e_conv}. Skipping end frame.")
|
492 |
+
end_frame_image = None
|
493 |
+
|
494 |
+
if end_frame_image is not None:
|
495 |
+
# Use the main job's target width/height (bucket dimensions) for the end frame
|
496 |
+
end_frame_np = job_params['end_frame_image']
|
497 |
+
|
498 |
+
if settings.get("save_metadata"):
|
499 |
+
Image.fromarray(end_frame_np).save(os.path.join(metadata_dir, f'{job_id}_end_frame_processed.png'))
|
500 |
+
|
501 |
+
end_frame_pt = torch.from_numpy(end_frame_np).float() / 127.5 - 1
|
502 |
+
end_frame_pt = end_frame_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W]
|
503 |
+
|
504 |
+
if not high_vram: load_model_as_complete(vae, target_device=gpu) # Ensure VAE is loaded
|
505 |
+
from diffusers_helper.hunyuan import vae_encode
|
506 |
+
end_frame_latent = vae_encode(end_frame_pt, vae)
|
507 |
+
|
508 |
+
# end_frame_output_dimensions_latent is sized like the start_latent and generated latents
|
509 |
+
end_frame_output_dimensions_np = resize_and_center_crop(end_frame_np, width, height)
|
510 |
+
end_frame_output_dimensions_pt = torch.from_numpy(end_frame_output_dimensions_np).float() / 127.5 - 1
|
511 |
+
end_frame_output_dimensions_pt = end_frame_output_dimensions_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W]
|
512 |
+
end_frame_output_dimensions_latent = vae_encode(end_frame_output_dimensions_pt, vae)
|
513 |
+
|
514 |
+
print("End frame VAE encoded.")
|
515 |
+
|
516 |
+
# Video Mode CLIP Vision encoding for end frame
|
517 |
+
if model_type == "Video":
|
518 |
+
if not high_vram: # Ensure image_encoder is on GPU for this operation
|
519 |
+
load_model_as_complete(image_encoder, target_device=gpu)
|
520 |
+
from diffusers_helper.clip_vision import hf_clip_vision_encode
|
521 |
+
end_clip_embedding = hf_clip_vision_encode(end_frame_np, feature_extractor, image_encoder).last_hidden_state
|
522 |
+
end_clip_embedding = end_clip_embedding.to(studio_module.current_generator.transformer.dtype)
|
523 |
+
# Need that dtype conversion for end_clip_embedding? I don't think so, but it was in the original PR.
|
524 |
+
|
525 |
+
if not high_vram: # Offload VAE and image_encoder if they were loaded
|
526 |
+
offload_model_from_device_for_memory_preservation(vae, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
|
527 |
+
offload_model_from_device_for_memory_preservation(image_encoder, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
|
528 |
+
|
529 |
+
# Dtype
|
530 |
+
for prompt_key in encoded_prompts:
|
531 |
+
llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[prompt_key]
|
532 |
+
llama_vec = llama_vec.to(studio_module.current_generator.transformer.dtype)
|
533 |
+
clip_l_pooler = clip_l_pooler.to(studio_module.current_generator.transformer.dtype)
|
534 |
+
encoded_prompts[prompt_key] = (llama_vec, llama_attention_mask, clip_l_pooler)
|
535 |
+
|
536 |
+
llama_vec_n = llama_vec_n.to(studio_module.current_generator.transformer.dtype)
|
537 |
+
clip_l_pooler_n = clip_l_pooler_n.to(studio_module.current_generator.transformer.dtype)
|
538 |
+
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(studio_module.current_generator.transformer.dtype)
|
539 |
+
|
540 |
+
# Sampling
|
541 |
+
stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
|
542 |
+
|
543 |
+
num_frames = latent_window_size * 4 - 3
|
544 |
+
|
545 |
+
# Initialize total_generated_latent_frames for Video model
|
546 |
+
total_generated_latent_frames = 0 # Default initialization for all model types
|
547 |
+
|
548 |
+
# Initialize history latents based on model type
|
549 |
+
if model_type != "Video" and model_type != "Video F1": # Skip for Video models as we already initialized it
|
550 |
+
history_latents = studio_module.current_generator.prepare_history_latents(height, width)
|
551 |
+
|
552 |
+
# For F1 model, initialize with start latent
|
553 |
+
if model_type == "F1":
|
554 |
+
history_latents = studio_module.current_generator.initialize_with_start_latent(history_latents, start_latent, has_input_image)
|
555 |
+
# If we had a real start image, it was just added to the history_latents
|
556 |
+
total_generated_latent_frames = 1 if has_input_image else 0
|
557 |
+
elif model_type == "Original" or model_type == "Original with Endframe":
|
558 |
+
total_generated_latent_frames = 0
|
559 |
+
|
560 |
+
history_pixels = None
|
561 |
+
|
562 |
+
# Get latent paddings from the generator
|
563 |
+
latent_paddings = studio_module.current_generator.get_latent_paddings(total_latent_sections)
|
564 |
+
|
565 |
+
# PROMPT BLENDING: Track section index
|
566 |
+
section_idx = 0
|
567 |
+
|
568 |
+
# Load LoRAs if selected
|
569 |
+
if selected_loras:
|
570 |
+
lora_folder_from_settings = settings.get("lora_dir")
|
571 |
+
studio_module.current_generator.load_loras(selected_loras, lora_folder_from_settings, lora_loaded_names, lora_values)
|
572 |
+
|
573 |
+
# --- Callback for progress ---
|
574 |
+
def callback(d):
|
575 |
+
nonlocal last_step_time, step_durations
|
576 |
+
|
577 |
+
# Check for cancellation signal
|
578 |
+
if stream_to_use.input_queue.top() == 'end':
|
579 |
+
print("Cancellation signal detected in callback")
|
580 |
+
return 'cancel' # Return a signal that will be checked in the sampler
|
581 |
+
|
582 |
+
now_time = time.time()
|
583 |
+
# Record duration between diffusion steps (skip first where duration may include setup)
|
584 |
+
if last_step_time is not None:
|
585 |
+
step_delta = now_time - last_step_time
|
586 |
+
if step_delta > 0:
|
587 |
+
step_durations.append(step_delta)
|
588 |
+
if len(step_durations) > 30: # Keep only recent 30 steps
|
589 |
+
step_durations.pop(0)
|
590 |
+
last_step_time = now_time
|
591 |
+
avg_step = sum(step_durations) / len(step_durations) if step_durations else 0.0
|
592 |
+
|
593 |
+
preview = d['denoised']
|
594 |
+
from diffusers_helper.hunyuan import vae_decode_fake
|
595 |
+
preview = vae_decode_fake(preview)
|
596 |
+
preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
|
597 |
+
preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
|
598 |
+
|
599 |
+
# --- Progress & ETA logic ---
|
600 |
+
# Current segment progress
|
601 |
+
current_step = d['i'] + 1
|
602 |
+
percentage = int(100.0 * current_step / steps)
|
603 |
+
|
604 |
+
# Total progress
|
605 |
+
total_steps_done = section_idx * steps + current_step
|
606 |
+
total_percentage = int(100.0 * total_steps_done / total_steps)
|
607 |
+
|
608 |
+
# ETA calculations
|
609 |
+
def fmt_eta(sec):
|
610 |
+
try:
|
611 |
+
return str(datetime.timedelta(seconds=int(sec)))
|
612 |
+
except Exception:
|
613 |
+
return "--:--"
|
614 |
+
|
615 |
+
segment_eta = (steps - current_step) * avg_step if avg_step else 0
|
616 |
+
total_eta = (total_steps - total_steps_done) * avg_step if avg_step else 0
|
617 |
+
|
618 |
+
segment_hint = f'Sampling {current_step}/{steps} ETA {fmt_eta(segment_eta)}'
|
619 |
+
total_hint = f'Total {total_steps_done}/{total_steps} ETA {fmt_eta(total_eta)}'
|
620 |
+
|
621 |
+
# For Video model, add the input video frame count when calculating current position
|
622 |
+
if model_type == "Video":
|
623 |
+
# Calculate the time position including the input video frames
|
624 |
+
input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time
|
625 |
+
current_pos = input_video_time + (total_generated_latent_frames * 4 - 3) / 30
|
626 |
+
# Original position is the remaining time to generate
|
627 |
+
original_pos = total_second_length - (total_generated_latent_frames * 4 - 3) / 30
|
628 |
+
else:
|
629 |
+
# For other models, calculate as before
|
630 |
+
current_pos = (total_generated_latent_frames * 4 - 3) / 30
|
631 |
+
original_pos = total_second_length - current_pos
|
632 |
+
|
633 |
+
# Ensure positions are not negative
|
634 |
+
if current_pos < 0: current_pos = 0
|
635 |
+
if original_pos < 0: original_pos = 0
|
636 |
+
|
637 |
+
hint = segment_hint # deprecated variable kept to minimise other code changes
|
638 |
+
desc = studio_module.current_generator.format_position_description(
|
639 |
+
total_generated_latent_frames,
|
640 |
+
current_pos,
|
641 |
+
original_pos,
|
642 |
+
current_prompt
|
643 |
+
)
|
644 |
+
|
645 |
+
# Create progress data dictionary
|
646 |
+
progress_data = {
|
647 |
+
'preview': preview,
|
648 |
+
'desc': desc,
|
649 |
+
'html': make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint)
|
650 |
+
}
|
651 |
+
|
652 |
+
# Store progress data in the job object if using a job stream
|
653 |
+
if job_stream is not None:
|
654 |
+
try:
|
655 |
+
from __main__ import job_queue
|
656 |
+
job = job_queue.get_job(job_id)
|
657 |
+
if job:
|
658 |
+
job.progress_data = progress_data
|
659 |
+
except Exception as e:
|
660 |
+
print(f"Error updating job progress data: {e}")
|
661 |
+
|
662 |
+
# Always push to the job-specific stream
|
663 |
+
stream_to_use.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint))))
|
664 |
+
|
665 |
+
# Always push to the main stream to ensure the UI is updated
|
666 |
+
# This is especially important for resumed jobs
|
667 |
+
from __main__ import stream as main_stream
|
668 |
+
if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use
|
669 |
+
main_stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint))))
|
670 |
+
|
671 |
+
# Also push job ID to main stream to ensure monitoring connection
|
672 |
+
if main_stream:
|
673 |
+
main_stream.output_queue.push(('job_id', job_id))
|
674 |
+
main_stream.output_queue.push(('monitor_job', job_id))
|
675 |
+
|
676 |
+
# MagCache / TeaCache Initialization Logic
|
677 |
+
magcache = None
|
678 |
+
# RT_BORG: I cringe at this, but refactoring to introduce an actual model class will fix it.
|
679 |
+
model_family = "F1" if "F1" in model_type else "Original"
|
680 |
+
|
681 |
+
if settings.get("calibrate_magcache"): # Calibration mode (forces MagCache on)
|
682 |
+
print("Setting Up MagCache for Calibration")
|
683 |
+
is_calibrating = settings.get("calibrate_magcache")
|
684 |
+
studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off
|
685 |
+
magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=is_calibrating, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio)
|
686 |
+
studio_module.current_generator.transformer.install_magcache(magcache)
|
687 |
+
elif use_magcache: # User selected MagCache
|
688 |
+
print("Setting Up MagCache")
|
689 |
+
magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=False, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio)
|
690 |
+
studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off
|
691 |
+
studio_module.current_generator.transformer.install_magcache(magcache)
|
692 |
+
elif use_teacache:
|
693 |
+
print("Setting Up TeaCache")
|
694 |
+
studio_module.current_generator.transformer.initialize_teacache(enable_teacache=True, num_steps=teacache_num_steps, rel_l1_thresh=teacache_rel_l1_thresh)
|
695 |
+
studio_module.current_generator.transformer.uninstall_magcache()
|
696 |
+
else:
|
697 |
+
print("No Transformer Cache in use")
|
698 |
+
studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False)
|
699 |
+
studio_module.current_generator.transformer.uninstall_magcache()
|
700 |
+
|
701 |
+
# --- Main generation loop ---
|
702 |
+
# `i_section_loop` will be our loop counter for applying end_frame_latent
|
703 |
+
for i_section_loop, latent_padding in enumerate(latent_paddings): # Existing loop structure
|
704 |
+
is_last_section = latent_padding == 0
|
705 |
+
latent_padding_size = latent_padding * latent_window_size
|
706 |
+
|
707 |
+
if stream_to_use.input_queue.top() == 'end':
|
708 |
+
stream_to_use.output_queue.push(('end', None))
|
709 |
+
return
|
710 |
+
|
711 |
+
# Calculate the current time position
|
712 |
+
if model_type == "Video":
|
713 |
+
# For Video model, add the input video time to the current position
|
714 |
+
input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time
|
715 |
+
current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds
|
716 |
+
if current_time_position < 0:
|
717 |
+
current_time_position = 0.01
|
718 |
+
else:
|
719 |
+
# For other models, calculate as before
|
720 |
+
current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds
|
721 |
+
if current_time_position < 0:
|
722 |
+
current_time_position = 0.01
|
723 |
+
|
724 |
+
# Find the appropriate prompt for this section
|
725 |
+
current_prompt = prompt_sections[0].prompt # Default to first prompt
|
726 |
+
for section in prompt_sections:
|
727 |
+
if section.start_time <= current_time_position and (section.end_time is None or current_time_position < section.end_time):
|
728 |
+
current_prompt = section.prompt
|
729 |
+
break
|
730 |
+
|
731 |
+
# PROMPT BLENDING: Find if we're in a blend window
|
732 |
+
blend_alpha = None
|
733 |
+
prev_prompt = current_prompt
|
734 |
+
next_prompt = current_prompt
|
735 |
+
|
736 |
+
# Only try to blend if blend_sections > 0 and we have prompt change indices and multiple sections
|
737 |
+
try:
|
738 |
+
blend_sections_int = int(blend_sections)
|
739 |
+
except ValueError:
|
740 |
+
blend_sections_int = 0 # Default to 0 if conversion fails, effectively disabling blending
|
741 |
+
print(f"Warning: blend_sections ('{blend_sections}') is not a valid integer. Disabling prompt blending for this section.")
|
742 |
+
if blend_sections_int > 0 and prompt_change_indices and len(prompt_sections) > 1:
|
743 |
+
for i, (change_idx, prompt) in enumerate(prompt_change_indices):
|
744 |
+
if section_idx < change_idx:
|
745 |
+
prev_prompt = prompt_change_indices[i - 1][1] if i > 0 else prompt
|
746 |
+
next_prompt = prompt
|
747 |
+
blend_start = change_idx
|
748 |
+
blend_end = change_idx + blend_sections
|
749 |
+
if section_idx >= change_idx and section_idx < blend_end:
|
750 |
+
blend_alpha = (section_idx - change_idx + 1) / blend_sections
|
751 |
+
break
|
752 |
+
elif section_idx == change_idx:
|
753 |
+
# At the exact change, start blending
|
754 |
+
if i > 0:
|
755 |
+
prev_prompt = prompt_change_indices[i - 1][1]
|
756 |
+
next_prompt = prompt
|
757 |
+
blend_alpha = 1.0 / blend_sections
|
758 |
+
else:
|
759 |
+
prev_prompt = prompt
|
760 |
+
next_prompt = prompt
|
761 |
+
blend_alpha = None
|
762 |
+
break
|
763 |
+
else:
|
                        # After last change, no blending
                        prev_prompt = current_prompt
                        next_prompt = current_prompt
                        blend_alpha = None

            # Get the encoded prompt for this section
            if blend_alpha is not None and prev_prompt != next_prompt:
                # Blend embeddings
                prev_llama_vec, prev_llama_attention_mask, prev_clip_l_pooler = encoded_prompts[prev_prompt]
                next_llama_vec, next_llama_attention_mask, next_clip_l_pooler = encoded_prompts[next_prompt]
                llama_vec = (1 - blend_alpha) * prev_llama_vec + blend_alpha * next_llama_vec
                llama_attention_mask = prev_llama_attention_mask  # usually same
                clip_l_pooler = (1 - blend_alpha) * prev_clip_l_pooler + blend_alpha * next_clip_l_pooler
                print(f"Blending prompts: '{prev_prompt[:30]}...' -> '{next_prompt[:30]}...', alpha={blend_alpha:.2f}")
            else:
                llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[current_prompt]

            original_time_position = total_second_length - current_time_position
            if original_time_position < 0:
                original_time_position = 0

            print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, '
                  f'time position: {current_time_position:.2f}s (original: {original_time_position:.2f}s), '
                  f'using prompt: {current_prompt[:60]}...')

            # Apply end_frame_latent to history_latents for models with Endframe support
            if (model_type == "Original with Endframe") and i_section_loop == 0 and end_frame_latent is not None:
                print(f"Applying end_frame_latent to history_latents with strength: {end_frame_strength}")
                actual_end_frame_latent_for_history = end_frame_latent.clone()
                if end_frame_strength != 1.0:  # Only multiply if not full strength
                    actual_end_frame_latent_for_history = actual_end_frame_latent_for_history * end_frame_strength

                # Ensure history_latents is on the correct device (usually CPU for this kind of modification if it's init'd there)
                # and that the assigned tensor matches its dtype.
                # The `studio_module.current_generator.prepare_history_latents` initializes it on CPU with float32.
                if history_latents.shape[2] >= 1:  # Check if the 'Depth_slots' dimension is sufficient
                    if model_type == "Original with Endframe":
                        # For Original model, apply to the beginning (position 0)
                        history_latents[:, :, 0:1, :, :] = actual_end_frame_latent_for_history.to(
                            device=history_latents.device,  # Assign to history_latents' current device
                            dtype=history_latents.dtype     # Match history_latents' dtype
                        )
                    elif model_type == "F1 with Endframe":
                        # For F1 model, apply to the end (last position)
                        history_latents[:, :, -1:, :, :] = actual_end_frame_latent_for_history.to(
                            device=history_latents.device,  # Assign to history_latents' current device
                            dtype=history_latents.dtype     # Match history_latents' dtype
                        )
                    print(f"End frame latent applied to history for {model_type} model.")
                else:
                    print("Warning: history_latents not shaped as expected for end_frame application.")

            # Video models use combined methods to prepare clean latents and indices
            if model_type == "Video":
                # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
                num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
                clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \
                    studio_module.current_generator.video_prepare_clean_latents_and_indices(end_frame_output_dimensions_latent, end_frame_strength, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames)
            elif model_type == "Video F1":
                # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
                num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
                clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \
                    studio_module.current_generator.video_f1_prepare_clean_latents_and_indices(latent_window_size, video_latents, history_latents, num_cleaned_frames)
            else:
                # Prepare indices using the generator
                clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices = studio_module.current_generator.prepare_indices(latent_padding_size, latent_window_size)

                # Prepare clean latents using the generator
                clean_latents, clean_latents_2x, clean_latents_4x = studio_module.current_generator.prepare_clean_latents(start_latent, history_latents)

            # Print debug info
            print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, latent_padding={latent_padding}")

            if not high_vram:
                # Unload VAE etc. before loading transformer
                unload_complete_models(vae, text_encoder, text_encoder_2, image_encoder)
                move_model_to_device_with_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
                if selected_loras:
                    studio_module.current_generator.move_lora_adapters_to_device(gpu)

            from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
            generated_latents = sample_hunyuan(
                transformer=studio_module.current_generator.transformer,
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=random_generator,
                prompt_embeds=llama_vec,
                prompt_embeds_mask=llama_attention_mask,
                prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n,
                negative_prompt_embeds_mask=llama_attention_mask_n,
                negative_prompt_poolers=clip_l_pooler_n,
                device=gpu,
                dtype=torch.bfloat16,
                image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                callback=callback,
            )

            # RT_BORG: Observe the MagCache skip patterns during dev.
            # RT_BORG: We need to use a real logger soon!
            # if magcache is not None and magcache.is_enabled:
            #     print(f"MagCache skipped: {len(magcache.steps_skipped_list)} of {steps} steps: {magcache.steps_skipped_list}")

            total_generated_latent_frames += int(generated_latents.shape[2])
            # Update history latents using the generator
            history_latents = studio_module.current_generator.update_history_latents(history_latents, generated_latents)

            if not high_vram:
                if selected_loras:
                    studio_module.current_generator.move_lora_adapters_to_device(cpu)
                offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
                load_model_as_complete(vae, target_device=gpu)

            # Get real history latents using the generator
            real_history_latents = studio_module.current_generator.get_real_history_latents(history_latents, total_generated_latent_frames)

            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = studio_module.current_generator.get_section_latent_frames(latent_window_size, is_last_section)
                overlapped_frames = latent_window_size * 4 - 3

                # Get current pixels using the generator
                current_pixels = studio_module.current_generator.get_current_pixels(real_history_latents, section_latent_frames, vae)

                # Update history pixels using the generator
                history_pixels = studio_module.current_generator.update_history_pixels(history_pixels, current_pixels, overlapped_frames)

            print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, history_pixels shape: {history_pixels.shape}")

            if not high_vram:
                unload_complete_models()

            output_filename = os.path.join(output_dir, f'{job_id}_{total_generated_latent_frames}.mp4')
            save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=settings.get("mp4_crf"))
            print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
            stream_to_use.output_queue.push(('file', output_filename))

            if is_last_section:
                break

            section_idx += 1  # PROMPT BLENDING: increment section index

            # We'll handle combining the videos after the entire generation is complete
            # This section intentionally left empty to remove the in-process combination
        # --- END Main generation loop ---

        magcache = studio_module.current_generator.transformer.magcache
        if magcache is not None:
            if magcache.is_calibrating:
                output_file = os.path.join(settings.get("output_dir"), "magcache_configuration.txt")
                print(f"MagCache calibration job complete. Appending stats to configuration file: {output_file}")
                magcache.append_calibration_to_file(output_file)
            elif magcache.is_enabled:
                print(f"MagCache ({100.0 * magcache.total_cache_hits / magcache.total_cache_requests:.2f}%) skipped {magcache.total_cache_hits} of {magcache.total_cache_requests} steps.")
            studio_module.current_generator.transformer.uninstall_magcache()
            magcache = None

        # Handle the results
        result = pipeline.handle_results(job_params, output_filename)

        # Unload all LoRAs after generation completed
        if selected_loras:
            print("Unloading all LoRAs after generation completed")
            studio_module.current_generator.unload_loras()
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    except Exception as e:
        traceback.print_exc()
        # Unload all LoRAs after error
        if studio_module.current_generator is not None and selected_loras:
            print("Unloading all LoRAs after error")
            studio_module.current_generator.unload_loras()
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        stream_to_use.output_queue.push(('error', f"Error during generation: {traceback.format_exc()}"))
        if not high_vram:
            # Ensure all models including the potentially active transformer are unloaded on error
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae,
                studio_module.current_generator.transformer if studio_module.current_generator else None
            )
    finally:
        # This finally block is associated with the main try block (starts around line 154)
        if settings.get("clean_up_videos"):
            try:
                video_files = [
                    f for f in os.listdir(output_dir)
                    if f.startswith(f"{job_id}_") and f.endswith(".mp4")
                ]
                print(f"Video files found for cleanup: {video_files}")
                if video_files:
                    def get_frame_count(filename):
                        try:
                            # Handles filenames like jobid_123.mp4
                            return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
                        except Exception:
                            return -1
                    video_files_sorted = sorted(video_files, key=get_frame_count)
                    print(f"Sorted video files: {video_files_sorted}")
                    final_video = video_files_sorted[-1]
                    for vf in video_files_sorted[:-1]:
                        full_path = os.path.join(output_dir, vf)
                        try:
                            os.remove(full_path)
                            print(f"Deleted intermediate video: {full_path}")
                        except Exception as e:
                            print(f"Failed to delete {full_path}: {e}")
            except Exception as e:
                print(f"Error during video cleanup: {e}")

        # Check if the user wants to combine the source video with the generated video
        # This is done after the video cleanup routine to ensure the combined video is not deleted
        # RT_BORG: Retain (but suppress) this original way to combine videos until the new combiner is proven.
        combine_v1 = False
        if combine_v1 and (model_type == "Video" or model_type == "Video F1") and combine_with_source and job_params.get('input_image_path'):
            print("Creating combined video with source and generated content...")
            try:
                input_video_path = job_params.get('input_image_path')
                if input_video_path and os.path.exists(input_video_path):
                    final_video_path_for_combine = None  # Use a different variable name to avoid conflict
                    video_files_for_combine = [
                        f for f in os.listdir(output_dir)
                        if f.startswith(f"{job_id}_") and f.endswith(".mp4") and "combined" not in f
                    ]

                    if video_files_for_combine:
                        def get_frame_count_for_combine(filename):  # Renamed to avoid conflict
                            try:
                                return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
                            except Exception:
                                return float('inf')

                        video_files_sorted_for_combine = sorted(video_files_for_combine, key=get_frame_count_for_combine)
                        if video_files_sorted_for_combine:  # Check if the list is not empty
                            final_video_path_for_combine = os.path.join(output_dir, video_files_sorted_for_combine[-1])

                    if final_video_path_for_combine and os.path.exists(final_video_path_for_combine):
                        combined_output_filename = os.path.join(output_dir, f'{job_id}_combined_v1.mp4')
                        combined_result = None
                        try:
                            if hasattr(studio_module.current_generator, 'combine_videos'):
                                print(f"Using VideoModelGenerator.combine_videos to create side-by-side comparison")
                                combined_result = studio_module.current_generator.combine_videos(
                                    source_video_path=input_video_path,
                                    generated_video_path=final_video_path_for_combine,  # Use the correct variable
                                    output_path=combined_output_filename
                                )

                                if combined_result:
                                    print(f"Combined video saved to: {combined_result}")
                                    stream_to_use.output_queue.push(('file', combined_result))
                                else:
                                    print("Failed to create combined video, falling back to direct ffmpeg method")
                                    combined_result = None
                            else:
                                print("VideoModelGenerator does not have combine_videos method. Using fallback method.")
                        except Exception as e_combine:  # Use a different exception variable name
                            print(f"Error in combine_videos method: {e_combine}")
                            print("Falling back to direct ffmpeg method")
                            combined_result = None

                        if not combined_result:
                            print("Using fallback method to combine videos")
                            from modules.toolbox.toolbox_processor import VideoProcessor
                            from modules.toolbox.message_manager import MessageManager

                            message_manager = MessageManager()
                            # Pass settings.settings if it exists, otherwise pass the settings object
                            video_processor_settings = settings.settings if hasattr(settings, 'settings') else settings
                            video_processor = VideoProcessor(message_manager, video_processor_settings)
                            ffmpeg_exe = video_processor.ffmpeg_exe

                            if ffmpeg_exe:
                                print(f"Using ffmpeg at: {ffmpeg_exe}")
                                import subprocess
                                temp_list_file = os.path.join(output_dir, f'{job_id}_filelist.txt')
                                with open(temp_list_file, 'w') as f:
                                    f.write(f"file '{input_video_path}'\n")
                                    f.write(f"file '{final_video_path_for_combine}'\n")  # Use the correct variable

                                ffmpeg_cmd = [
                                    ffmpeg_exe, "-y", "-f", "concat", "-safe", "0",
                                    "-i", temp_list_file, "-c", "copy", combined_output_filename
                                ]
                                print(f"Running ffmpeg command: {' '.join(ffmpeg_cmd)}")
                                subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
                                if os.path.exists(temp_list_file):
                                    os.remove(temp_list_file)
                                print(f"Combined video saved to: {combined_output_filename}")
                                stream_to_use.output_queue.push(('file', combined_output_filename))
                            else:
                                print("FFmpeg executable not found. Cannot combine videos.")
                    else:
                        print(f"Final video not found for combining with source: {final_video_path_for_combine}")
                else:
                    print(f"Input video path not found: {input_video_path}")
            except Exception as e_combine_outer:  # Use a different exception variable name
                print(f"Error combining videos: {e_combine_outer}")
                traceback.print_exc()

        # Combine input frames (resized and center cropped if needed) with final generated history_pixels tensor sequentially ---
        # This creates ID_combined.mp4
        # RT_BORG: Be sure to add this check if we decide to retain the processed input frames for "small" input videos
        # and job_params.get('input_frames_resized_np') is not None
        if (model_type == "Video" or model_type == "Video F1") and combine_with_source and history_pixels is not None:
            print(f"Creating combined video ({job_id}_combined.mp4) with processed input frames and generated history_pixels tensor...")
            try:
                # input_frames_resized_np = job_params.get('input_frames_resized_np')

                # RT_BORG: I cringe calling methods on BaseModelGenerator that only exist on VideoBaseGenerator, until we refactor
                input_frames_resized_np, fps, target_height, target_width = studio_module.current_generator.extract_video_frames(
                    is_for_encode=False,
                    video_path=job_params['input_image'],
                    resolution=job_params['resolutionW'],
                    no_resize=False,
                    input_files_dir=job_params['input_files_dir']
                )

                # history_pixels is (B, C, T, H, W), float32, [-1,1], on CPU
                if input_frames_resized_np is not None and history_pixels.numel() > 0:  # Check if history_pixels is not empty
                    combined_sequential_output_filename = os.path.join(output_dir, f'{job_id}_combined.mp4')

                    # fps variable should be from the video_encode call earlier.
                    input_video_fps_for_combine = fps
                    current_crf = settings.get("mp4_crf", 16)

                    # Call the new function from video_tools.py
                    combined_sequential_result_path = combine_videos_sequentially_from_tensors(
                        processed_input_frames_np=input_frames_resized_np,
                        generated_frames_pt=history_pixels,
                        output_path=combined_sequential_output_filename,
                        target_fps=input_video_fps_for_combine,
                        crf_value=current_crf
                    )
                    if combined_sequential_result_path:
                        stream_to_use.output_queue.push(('file', combined_sequential_result_path))
            except Exception as e:
                print(f"Error creating combined video ({job_id}_combined.mp4): {e}")
                traceback.print_exc()

        # Final verification of LoRA state
        if studio_module.current_generator and studio_module.current_generator.transformer:
            # Verify LoRA state
            has_loras = False
            if hasattr(studio_module.current_generator.transformer, 'peft_config'):
                adapter_names = list(studio_module.current_generator.transformer.peft_config.keys()) if studio_module.current_generator.transformer.peft_config else []
                if adapter_names:
                    has_loras = True
                    print(f"Transformer has LoRAs: {', '.join(adapter_names)}")
                else:
                    print(f"Transformer has no LoRAs in peft_config")
            else:
                print(f"Transformer has no peft_config attribute")

            # Check for any LoRA modules
            for name, module in studio_module.current_generator.transformer.named_modules():
                if hasattr(module, 'lora_A') and module.lora_A:
                    has_loras = True
                if hasattr(module, 'lora_B') and module.lora_B:
                    has_loras = True

            if not has_loras:
                print(f"No LoRA components found in transformer")

        stream_to_use.output_queue.push(('end', None))
        return
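Note: the blending branch above is a plain linear interpolation of the cached text embeddings for the previous and next prompt sections. A minimal sketch of that step, not part of the diff; the tensor shapes and the layout of the encoded_prompts tuples are illustrative assumptions, not the project's API:

import torch

def blend_prompt_embeddings(prev, nxt, alpha):
    """Linearly interpolate two (llama_vec, attention_mask, clip_pooler) tuples."""
    prev_vec, prev_mask, prev_pooler = prev
    next_vec, next_mask, next_pooler = nxt
    vec = (1 - alpha) * prev_vec + alpha * next_vec            # token embeddings
    pooler = (1 - alpha) * prev_pooler + alpha * next_pooler   # pooled CLIP embedding
    return vec, prev_mask, pooler                              # attention mask is reused as-is

# Hypothetical shapes, just to show the call:
prev = (torch.randn(1, 512, 4096), torch.ones(1, 512, dtype=torch.bool), torch.randn(1, 768))
nxt = (torch.randn(1, 512, 4096), torch.ones(1, 512, dtype=torch.bool), torch.randn(1, 768))
vec, mask, pooler = blend_prompt_embeddings(prev, nxt, alpha=0.25)  # 25% of the way to the next prompt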
modules/prompt_handler.py
ADDED
@@ -0,0 +1,164 @@
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PromptSection:
    """Represents a section of the prompt with specific timing information"""
    prompt: str
    start_time: float = 0  # in seconds
    end_time: Optional[float] = None  # in seconds, None means until the end


def snap_to_section_boundaries(prompt_sections: List[PromptSection], latent_window_size: int, fps: int = 30) -> List[PromptSection]:
    """
    Adjust timestamps to align with model's internal section boundaries

    Args:
        prompt_sections: List of PromptSection objects
        latent_window_size: Size of the latent window used in the model
        fps: Frames per second (default: 30)

    Returns:
        List of PromptSection objects with aligned timestamps
    """
    section_duration = (latent_window_size * 4 - 3) / fps  # Duration of one section in seconds

    aligned_sections = []
    for section in prompt_sections:
        # Snap start time to nearest section boundary
        aligned_start = round(section.start_time / section_duration) * section_duration

        # Snap end time to nearest section boundary
        aligned_end = None
        if section.end_time is not None:
            aligned_end = round(section.end_time / section_duration) * section_duration

        # Ensure minimum section length
        if aligned_end is not None and aligned_end <= aligned_start:
            aligned_end = aligned_start + section_duration

        aligned_sections.append(PromptSection(
            prompt=section.prompt,
            start_time=aligned_start,
            end_time=aligned_end
        ))

    return aligned_sections


def parse_timestamped_prompt(prompt_text: str, total_duration: float, latent_window_size: int = 9, generation_type: str = "Original") -> List[PromptSection]:
    """
    Parse a prompt with timestamps in the format [0s-2s: text] or [3s: text]

    Args:
        prompt_text: The input prompt text with optional timestamp sections
        total_duration: Total duration of the video in seconds
        latent_window_size: Size of the latent window used in the model
        generation_type: Type of generation ("Original" or "F1")

    Returns:
        List of PromptSection objects with timestamps aligned to section boundaries
        and reversed to account for reverse generation (only for Original type)
    """
    # Default prompt for the entire duration if no timestamps are found
    if "[" not in prompt_text or "]" not in prompt_text:
        return [PromptSection(prompt=prompt_text.strip())]

    sections = []
    # Find all timestamp sections [time: text]
    timestamp_pattern = r'\[(\d+(?:\.\d+)?s)(?:-(\d+(?:\.\d+)?s))?\s*:\s*(.*?)\]'
    regular_text = prompt_text

    for match in re.finditer(timestamp_pattern, prompt_text):
        start_time_str = match.group(1)
        end_time_str = match.group(2)
        section_text = match.group(3).strip()

        # Convert time strings to seconds
        start_time = float(start_time_str.rstrip('s'))
        end_time = float(end_time_str.rstrip('s')) if end_time_str else None

        sections.append(PromptSection(
            prompt=section_text,
            start_time=start_time,
            end_time=end_time
        ))

        # Remove the processed section from regular_text
        regular_text = regular_text.replace(match.group(0), "")

    # If there's any text outside of timestamp sections, use it as a default for the entire duration
    regular_text = regular_text.strip()
    if regular_text:
        sections.append(PromptSection(
            prompt=regular_text,
            start_time=0,
            end_time=None
        ))

    # Sort sections by start time
    sections.sort(key=lambda x: x.start_time)

    # Fill in end times if not specified
    for i in range(len(sections) - 1):
        if sections[i].end_time is None:
            sections[i].end_time = sections[i+1].start_time

    # Set the last section's end time to the total duration if not specified
    if sections and sections[-1].end_time is None:
        sections[-1].end_time = total_duration

    # Snap timestamps to section boundaries
    sections = snap_to_section_boundaries(sections, latent_window_size)

    # Only reverse timestamps for Original generation type
    if generation_type in ("Original", "Original with Endframe", "Video"):
        # Now reverse the timestamps to account for reverse generation
        reversed_sections = []
        for section in sections:
            reversed_start = total_duration - section.end_time if section.end_time is not None else 0
            reversed_end = total_duration - section.start_time
            reversed_sections.append(PromptSection(
                prompt=section.prompt,
                start_time=reversed_start,
                end_time=reversed_end
            ))

        # Sort the reversed sections by start time
        reversed_sections.sort(key=lambda x: x.start_time)
        return reversed_sections

    return sections


def get_section_boundaries(latent_window_size: int = 9, count: int = 10) -> str:
    """
    Calculate and format section boundaries for UI display

    Args:
        latent_window_size: Size of the latent window used in the model
        count: Number of boundaries to display

    Returns:
        Formatted string of section boundaries
    """
    section_duration = (latent_window_size * 4 - 3) / 30
    return ", ".join([f"{i*section_duration:.1f}s" for i in range(count)])


def get_quick_prompts() -> List[List[str]]:
    """
    Get a list of example timestamped prompts

    Returns:
        List of example prompts formatted for Gradio Dataset
    """
    prompts = [
        '[0s: The person waves hello] [2s: The person jumps up and down] [4s: The person does a spin]',
        '[0s: The person raises both arms slowly] [2s: The person claps hands enthusiastically]',
        '[0s: Person gives thumbs up] [1.1s: Person smiles and winks] [2.2s: Person shows two thumbs down]',
        '[0s: Person looks surprised] [1.1s: Person raises arms above head] [2.2s-3.3s: Person puts hands on hips]'
    ]
    return [[x] for x in prompts]
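Note: with the default latent_window_size of 9 at 30 fps, one generation section covers (9*4 - 3)/30 = 1.1 s, which is why the quick prompts above use 1.1 s / 2.2 s marks. A small usage sketch of the parser, not part of the diff, assuming the module is importable from the project root:

from modules.prompt_handler import parse_timestamped_prompt, get_section_boundaries

# Timestamps are rounded to multiples of the 1.1 s section duration.
print(get_section_boundaries(latent_window_size=9, count=4))  # "0.0s, 1.1s, 2.2s, 3.3s"

sections = parse_timestamped_prompt(
    "[0s: The person waves hello] [2s: The person jumps up and down]",
    total_duration=4.4,
    latent_window_size=9,
    generation_type="F1",  # F1 generates forward, so timestamps are not reversed
)
for s in sections:
    # 0.0s-2.2s: The person waves hello
    # 2.2s-4.4s: The person jumps up and down
    print(f"{s.start_time:.1f}s-{s.end_time:.1f}s: {s.prompt}")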
modules/settings.py
ADDED
@@ -0,0 +1,88 @@
import json
from pathlib import Path
from typing import Dict, Any, Optional
import os

class Settings:
    def __init__(self):
        # Get the project root directory (where settings.py is located)
        project_root = Path(__file__).parent.parent

        self.settings_file = project_root / ".framepack" / "settings.json"
        self.settings_file.parent.mkdir(parents=True, exist_ok=True)

        # Set default paths relative to project root
        self.default_settings = {
            "save_metadata": True,
            "gpu_memory_preservation": 6,
            "output_dir": str(project_root / "outputs"),
            "metadata_dir": str(project_root / "outputs"),
            "lora_dir": str(project_root / "loras"),
            "gradio_temp_dir": str(project_root / "temp"),
            "input_files_dir": str(project_root / "input_files"),  # New setting for input files
            "auto_save_settings": True,
            "gradio_theme": "default",
            "mp4_crf": 16,
            "clean_up_videos": True,
            "override_system_prompt": False,
            "auto_cleanup_on_startup": False,  # ADDED: New setting for startup cleanup
            "latents_display_top": False,  # NEW: Control latents preview position (False = right column, True = top of interface)
            "system_prompt_template": "{\"template\": \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\", \"crop_start\": 95}",
            "startup_model_type": "None",
            "startup_preset_name": None,
            "enhancer_prompt_template": """You are a creative assistant for a text-to-video generator. Your task is to take a user's prompt and make it more descriptive, vivid, and detailed. Focus on visual elements. Do not change the core action, but embellish it.

User prompt: "{text_to_enhance}"

Enhanced prompt:"""
        }
        self.settings = self.load_settings()

    def load_settings(self) -> Dict[str, Any]:
        """Load settings from file or return defaults"""
        if self.settings_file.exists():
            try:
                with open(self.settings_file, 'r') as f:
                    loaded_settings = json.load(f)
                    # Merge with defaults to ensure all settings exist
                    settings = self.default_settings.copy()
                    settings.update(loaded_settings)
                    return settings
            except Exception as e:
                print(f"Error loading settings: {e}")
                return self.default_settings.copy()
        return self.default_settings.copy()

    def save_settings(self, **kwargs):
        """Save settings to file. Accepts keyword arguments for any settings to update."""
        # Update self.settings with any provided keyword arguments
        self.settings.update(kwargs)
        # Ensure all default fields are present
        for k, v in self.default_settings.items():
            self.settings.setdefault(k, v)

        # Ensure directories exist for relevant fields
        for dir_key in ["output_dir", "metadata_dir", "lora_dir", "gradio_temp_dir"]:
            dir_path = self.settings.get(dir_key)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)

        # Save to file
        with open(self.settings_file, 'w') as f:
            json.dump(self.settings, f, indent=4)

    def get(self, key: str, default: Any = None) -> Any:
        """Get a setting value"""
        return self.settings.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a setting value"""
        self.settings[key] = value
        if self.settings.get("auto_save_settings", True):
            self.save_settings()

    def update(self, settings: Dict[str, Any]) -> None:
        """Update multiple settings at once"""
        self.settings.update(settings)
        if self.settings.get("auto_save_settings", True):
            self.save_settings()
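A brief usage sketch of the Settings class above, not part of the diff; it only exercises the documented get/set/update behavior with the defaults shown:

from modules.settings import Settings

settings = Settings()                      # loads .framepack/settings.json or falls back to the defaults
print(settings.get("mp4_crf"))             # 16 unless overridden in settings.json
print(settings.get("missing_key", "n/a"))  # returns the provided fallback for unknown keys

settings.set("clean_up_videos", False)     # written to disk immediately while auto_save_settings is True
settings.update({"gpu_memory_preservation": 8, "gradio_theme": "default"})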
modules/toolbox/RIFE/IFNet_HDv3.py
ADDED
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes),
    )


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 4, 4, 2, 1),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
            nn.PReLU(c // 2),
            nn.ConvTranspose2d(c // 2, 1, 4, 2, 1),
        )

    def forward(self, x, flow, scale=1):
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
        )
        flow = (
            F.interpolate(
                flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
            )
            * 1.0
            / scale
        )
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = (
            F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
            * scale
        )
        mask = F.interpolate(
            mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
        )
        return flow, mask


class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 4, c=90)
        self.block1 = IFBlock(7 + 4, c=90)
        self.block2 = IFBlock(7 + 4, c=90)
        self.block_tea = IFBlock(10 + 4, c=90)
        # self.contextnet = Contextnet()
        # self.unet = Unet()

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        if training == False:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        loss_cons = 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            f1, m1 = block[i](
                torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                scale=scale_list[i],
            )
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        """
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, 1:4] * 2 - 1
        """
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
        return flow_list, mask_list[2], merged
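The coarse-to-fine loop above runs the three IFBlocks at 1/4, 1/2 and full resolution, refining a bidirectional flow plus a blend mask at each pass. A minimal shape sketch, not part of the diff, assuming the package imports resolve from the project root and that H and W are multiples of 32 so the internal down/upsampling round-trips exactly:

import torch
import devicetorch
from modules.toolbox.RIFE.IFNet_HDv3 import IFNet

device = devicetorch.get(torch)
net = IFNet().eval().to(device)
img0 = torch.rand(1, 3, 256, 256, device=device)  # frame at t=0, values in [0, 1]
img1 = torch.rand(1, 3, 256, 256, device=device)  # frame at t=1
with torch.no_grad():
    flow_list, mask, merged = net(torch.cat((img0, img1), 1), scale_list=[4, 2, 1])
print(flow_list[-1].shape)  # (1, 4, 256, 256): forward + backward flow
print(mask.shape)           # (1, 1, 256, 256): sigmoid blend weights
print(merged[-1].shape)     # (1, 3, 256, 256): interpolated middle frame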
modules/toolbox/RIFE/RIFE_HDv3.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
import numpy as np
import itertools
from .warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from .IFNet_HDv3 import *
from .loss import *
import devicetorch
device = devicetorch.get(torch)


class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            if rank == -1:
                return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
            else:
                return param

        if rank <= 0:
            model_path = "{}/flownet.pkl".format(path)
            # Check PyTorch version to safely use weights_only
            from packaging import version
            use_weights_only = version.parse(torch.__version__) >= version.parse("1.13")

            load_kwargs = {}
            if not torch.cuda.is_available():
                load_kwargs['map_location'] = "cpu"

            if use_weights_only:
                # For modern PyTorch, be explicit and safe
                load_kwargs['weights_only'] = True
                # print(f"PyTorch >= 1.13 detected. Loading RIFE model with weights_only=True.")
                state_dict = torch.load(model_path, **load_kwargs)
            else:
                # For older PyTorch, load the old way
                print(f"PyTorch < 1.13 detected. Loading RIFE model using legacy method.")
                state_dict = torch.load(model_path, **load_kwargs)

            self.flownet.load_state_dict(convert(state_dict))

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group["lr"] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # loss_vgg = self.vgg(merged[2], gt)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            "mask": mask,
            "flow": flow[2][:, :2],
            "loss_l1": loss_l1,
            "loss_cons": loss_cons,
            "loss_smooth": loss_smooth,
        }
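A usage sketch for the inference path of the RIFE wrapper above, not part of the diff: load the pretrained flow net and synthesize the frame halfway between two inputs. The model directory is an assumption; it only needs to contain flownet.pkl.

import torch
import devicetorch
from modules.toolbox.RIFE.RIFE_HDv3 import Model

device = devicetorch.get(torch)
model = Model()
model.load_model("modules/toolbox/model_rife", rank=0)  # expects <path>/flownet.pkl
model.eval()

img0 = torch.rand(1, 3, 256, 256, device=device)  # frame t, values in [0, 1], H/W multiples of 32
img1 = torch.rand(1, 3, 256, 256, device=device)  # frame t+1
with torch.no_grad():
    mid = model.inference(img0, img1)              # frame at t+0.5, shape (1, 3, 256, 256)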
modules/toolbox/RIFE/__int__.py
ADDED
File without changes